├── .gitignore ├── DownloadProcessedWESADdata.py ├── Models ├── Fusions.py ├── Routines.py ├── TA_LSTM.py ├── TLSTM.py └── TorchModels.py ├── MultiWave.jpg ├── README.md ├── environment.yml ├── main.py └── utils ├── Dataset.py ├── ModelUtils.py ├── WaveletUtils.py ├── __init__.py └── pytorchtools.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # wandb 163 | wandb/ 164 | 165 | # Checkpoints 166 | Checkpoints/ 167 | 168 | # datasets 169 | datasets/ -------------------------------------------------------------------------------- /DownloadProcessedWESADdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | import gdown 4 | 5 | if __name__ == '__main__': 6 | file_id = '1ve9ChpHoQ9SpplZAu6e2gdXvWOJyBIP6' 7 | DataSetPath = 'datasets/WESAD' 8 | os.makedirs(DataSetPath, exist_ok=True) 9 | destination = os.path.join(DataSetPath, 'WESAD_processed.zip') 10 | 11 | url = f"https://drive.google.com/uc?id={file_id}" 12 | gdown.download(url, destination, quiet=False) 13 | 14 | with zipfile.ZipFile(destination, 'r') as zip_ref: 15 | zip_ref.extractall(DataSetPath) -------------------------------------------------------------------------------- /Models/Fusions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from Models.TorchModels import CVE, Attention, Transformer 6 | import wandb 7 | 8 | class switch(torch.autograd.Function): 9 | """ 10 | Custom autograd function to apply a mask during the forward pass. 11 | """ 12 | @staticmethod 13 | def forward(ctx, x, mask): 14 | return mask * x 15 | 16 | @staticmethod 17 | def backward(ctx, grad_output): 18 | return grad_output, None 19 | 20 | class gradswitch(torch.autograd.Function): 21 | """ 22 | Custom autograd function to apply a mask during the backward pass. 23 | """ 24 | @staticmethod 25 | def forward(ctx, x, mask): 26 | ctx.mask = mask 27 | return x 28 | 29 | @staticmethod 30 | def backward(ctx, grad_output): 31 | mask = ctx.mask 32 | return grad_output * mask, None 33 | 34 | class MaskedGradLinear(nn.Module): 35 | """ 36 | MaskedGradLinear is a neural network module that applies masks to gradients during the forward pass. 37 | """ 38 | def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True, masks=[0., 1., 1., 0., 0., 0., 0., 0.]): 39 | """ 40 | Initializes the MaskedGradLinear module. 41 | 42 | Parameters: 43 | input_size_all (list): List of input sizes for each component. 44 | hidden_size (int): Size of the hidden layer. 45 | out_size (int): Size of the output layer. 
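# --- Illustrative usage sketch (values below are assumptions, not repo code) ---
# `switch` masks activations in the forward pass but lets gradients through
# unchanged; `gradswitch` is an identity in the forward pass but masks gradients
# on the way back. A minimal check of that behavior:
import torch

x = torch.ones(3, requires_grad=True)
switch.apply(x, 0.0).sum().backward()
print(x.grad)  # tensor([1., 1., 1.]): forward was zeroed, gradient untouched

y = torch.ones(3, requires_grad=True)
gradswitch.apply(y, 0.0).sum().backward()
print(y.grad)  # tensor([0., 0., 0.]): forward untouched, gradient zeroed
# -------------------------------------------------------------------------------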
46 | useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False. 47 | bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True. 48 | masks (list, optional): List of masks for each component. Defaults to [0., 1., 1., 0., 0., 0., 0., 0.]. 49 | """ 50 | super(MaskedGradLinear, self).__init__() 51 | multiplier = 2 if bidirectional else 1 52 | self.useExtralin = useExtralin 53 | if not useExtralin: 54 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), out_size) 55 | else: 56 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), hidden_size) 57 | self.LastLinear = nn.Linear(hidden_size, out_size) 58 | self.relu = nn.ReLU() 59 | self.masks = masks 60 | Ms = [] 61 | for i in range(len(input_size_all)): 62 | if input_size_all[i] > 0: 63 | ModelLinear = nn.Linear(multiplier * input_size_all[i], 1) 64 | Ms.append(ModelLinear) 65 | self.CompOuts = nn.ModuleList(Ms) 66 | 67 | def forward(self, out): 68 | """ 69 | Defines the forward pass of the MaskedGradLinear module. 70 | 71 | Parameters: 72 | out (list): List of input tensors for each component. 73 | 74 | Returns: 75 | tuple: Output tensor and list of component outputs. 76 | """ 77 | outmasked = [] 78 | for i in range(len(out)): 79 | outmasked.append(out[i]) 80 | cont_emb = torch.cat(outmasked, -1) 81 | if self.useExtralin: 82 | op = self.relu(self.foreLinear(cont_emb)) 83 | op = self.LastLinear(op).squeeze(-1) 84 | else: 85 | op = self.foreLinear(cont_emb).squeeze(-1) 86 | Outs = [] 87 | for i in range(len(out)): 88 | o = self.CompOuts[i](gradswitch.apply(out[i], 0.0)) 89 | Outs.append(o.squeeze(-1)) 90 | return op, Outs 91 | 92 | class MaskedFusionSwitch(nn.Module): 93 | """ 94 | MaskedFusionSwitch is a neural network module that applies masks during the forward pass. 95 | """ 96 | def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True, masks=[1., 1., 1., 1., 1., 1., 1., 1.]): 97 | """ 98 | Initializes the MaskedFusionSwitch module. 99 | 100 | Parameters: 101 | input_size_all (list): List of input sizes for each component. 102 | hidden_size (int): Size of the hidden layer. 103 | out_size (int): Size of the output layer. 104 | useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False. 105 | bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True. 106 | masks (list, optional): List of masks for each component. Defaults to [1., 1., 1., 1., 1., 1., 1., 1.]. 107 | """ 108 | super(MaskedFusionSwitch, self).__init__() 109 | multiplier = 2 if bidirectional else 1 110 | self.useExtralin = useExtralin 111 | if not useExtralin: 112 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), out_size) 113 | else: 114 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), hidden_size) 115 | self.LastLinear = nn.Linear(hidden_size, out_size) 116 | self.relu = nn.ReLU() 117 | self.masks = masks 118 | Ms = [] 119 | for i in range(len(input_size_all)): 120 | if input_size_all[i] > 0: 121 | ModelLinear = nn.Linear(multiplier * input_size_all[i], 1) 122 | Ms.append(ModelLinear) 123 | self.CompOuts = nn.ModuleList(Ms) 124 | 125 | def forward(self, out): 126 | """ 127 | Defines the forward pass of the MaskedFusionSwitch module. 128 | 129 | Parameters: 130 | out (list): List of input tensors for each component. 131 | 132 | Returns: 133 | tuple: Output tensor and list of component outputs. 
134 | """ 135 | outmasked = [] 136 | for i in range(len(out)): 137 | outmasked.append(switch.apply(out[i], self.masks[i])) 138 | cont_emb = torch.cat(outmasked, -1) 139 | if self.useExtralin: 140 | op = self.relu(self.foreLinear(cont_emb)) 141 | op = self.LastLinear(op).squeeze(-1) 142 | else: 143 | op = self.foreLinear(cont_emb).squeeze(-1) 144 | Outs = [] 145 | for i in range(len(out)): 146 | o = self.CompOuts[i](out[i]) 147 | Outs.append(o.squeeze(-1)) 148 | return op, Outs 149 | 150 | class MaskedFusionGradSwitch(nn.Module): 151 | """ 152 | MaskedFusionGradSwitch is a neural network module that applies masks during the forward and backward passes. 153 | """ 154 | def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True, masks=[1., 1., 1., 1., 1., 1., 1., 1.], GradMasks=[1., 1., 1., 1., 1., 1., 1., 1.]): 155 | """ 156 | Initializes the MaskedFusionGradSwitch module. 157 | 158 | Parameters: 159 | input_size_all (list): List of input sizes for each component. 160 | hidden_size (int): Size of the hidden layer. 161 | out_size (int): Size of the output layer. 162 | useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False. 163 | bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True. 164 | masks (list, optional): List of masks for each component. Defaults to [1., 1., 1., 1., 1., 1., 1., 1.]. 165 | GradMasks (list, optional): List of gradient masks for each component. Defaults to [1., 1., 1., 1., 1., 1., 1., 1.]. 166 | """ 167 | super(MaskedFusionGradSwitch, self).__init__() 168 | multiplier = 2 if bidirectional else 1 169 | self.useExtralin = useExtralin 170 | if not useExtralin: 171 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), out_size) 172 | else: 173 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), hidden_size) 174 | self.LastLinear = nn.Linear(hidden_size, out_size) 175 | self.relu = nn.ReLU() 176 | self.masks = masks 177 | self.GradMasks = GradMasks 178 | Ms = [] 179 | for i in range(len(input_size_all)): 180 | if input_size_all[i] > 0: 181 | ModelLinear = nn.Linear(multiplier * input_size_all[i], 1) 182 | Ms.append(ModelLinear) 183 | self.CompOuts = nn.ModuleList(Ms) 184 | 185 | def forward(self, out): 186 | """ 187 | Defines the forward pass of the MaskedFusionGradSwitch module. 188 | 189 | Parameters: 190 | out (list): List of input tensors for each component. 191 | 192 | Returns: 193 | tuple: Output tensor and list of component outputs. 194 | """ 195 | outmasked = [] 196 | for i in range(len(out)): 197 | outmasked.append(switch.apply(out[i], self.masks[i])) 198 | cont_emb = torch.cat(outmasked, -1) 199 | if self.useExtralin: 200 | op = self.relu(self.foreLinear(cont_emb)) 201 | op = self.LastLinear(op).squeeze(-1) 202 | else: 203 | op = self.foreLinear(cont_emb).squeeze(-1) 204 | Outs = [] 205 | for i in range(len(out)): 206 | o = self.CompOuts[i](gradswitch.apply(out[i], self.GradMasks[i])) 207 | Outs.append(o.squeeze(-1)) 208 | return op, Outs 209 | 210 | 211 | class MaskedFusion(nn.Module): 212 | """ 213 | MaskedFusion is a neural network module that applies masks during the forward pass. 214 | """ 215 | def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True, masks=[0., 1., 1., 0., 0., 0., 0., 0.]): 216 | """ 217 | Initializes the MaskedFusion module. 218 | 219 | Parameters: 220 | input_size_all (list): List of input sizes for each component. 221 | hidden_size (int): Size of the hidden layer. 
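# --- Illustrative usage sketch (shapes and values are assumptions) ---
# MaskedFusionGradSwitch combines both switches: `masks` silences a branch in the
# fused forward pass, while `GradMasks` stops gradients from that branch's
# component head from flowing back into the branch embedding.
import torch

fusion = MaskedFusionGradSwitch([16, 16], hidden_size=32, out_size=1,
                                masks=[1., 0.], GradMasks=[1., 0.])
outs = [torch.randn(4, 32), torch.randn(4, 32)]  # batch 4, 2 * 16 (bidirectional)
op, comp_outs = fusion(outs)
print(op.shape, comp_outs[0].shape)  # torch.Size([4]) torch.Size([4])
# -------------------------------------------------------------------------------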
222 | out_size (int): Size of the output layer. 223 | useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False. 224 | bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True. 225 | masks (list, optional): List of masks for each component. Defaults to [0., 1., 1., 0., 0., 0., 0., 0.]. 226 | """ 227 | super(MaskedFusion, self).__init__() 228 | multiplier = 2 if bidirectional else 1 229 | self.useExtralin = useExtralin 230 | if not useExtralin: 231 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), out_size) 232 | else: 233 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), hidden_size) 234 | self.LastLinear = nn.Linear(hidden_size, out_size) 235 | self.relu = nn.ReLU() 236 | self.masks = masks 237 | Ms = [] 238 | for i in range(len(input_size_all)): 239 | if input_size_all[i] > 0: 240 | ModelLinear = nn.Linear(multiplier * input_size_all[i], 1) 241 | Ms.append(ModelLinear) 242 | self.CompOuts = nn.ModuleList(Ms) 243 | 244 | def forward(self, out): 245 | """ 246 | Defines the forward pass of the MaskedFusion module. 247 | 248 | Parameters: 249 | out (list): List of input tensors for each component. 250 | 251 | Returns: 252 | tuple: Output tensor and list of component outputs. 253 | """ 254 | outmasked = [] 255 | for i in range(len(out)): 256 | outmasked.append(self.masks[i] * out[i]) 257 | cont_emb = torch.cat(outmasked, -1) 258 | if self.useExtralin: 259 | op = self.relu(self.foreLinear(cont_emb)) 260 | op = self.LastLinear(op).squeeze(-1) 261 | else: 262 | op = self.foreLinear(cont_emb).squeeze(-1) 263 | Outs = [] 264 | for i in range(len(out)): 265 | o = self.CompOuts[i](out[i]) 266 | Outs.append(o.squeeze(-1)) 267 | return op, Outs 268 | 269 | class SigWeightedFusion(nn.Module): 270 | """ 271 | SigWeightedFusion is a neural network module that applies sigmoid-weighted fusion during the forward pass. 272 | """ 273 | def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True): 274 | """ 275 | Initializes the SigWeightedFusion module. 276 | 277 | Parameters: 278 | input_size_all (list): List of input sizes for each component. 279 | hidden_size (int): Size of the hidden layer. 280 | out_size (int): Size of the output layer. 281 | useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False. 282 | bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True. 283 | """ 284 | super(SigWeightedFusion, self).__init__() 285 | multiplier = 2 if bidirectional else 1 286 | self.useExtralin = useExtralin 287 | if not useExtralin: 288 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), out_size) 289 | else: 290 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), hidden_size) 291 | self.LastLinear = nn.Linear(hidden_size, out_size) 292 | self.relu = nn.ReLU() 293 | self.sig = nn.Sigmoid() 294 | self.weights = nn.Parameter(torch.zeros(len(input_size_all))) 295 | Ms = [] 296 | for i in range(len(input_size_all)): 297 | if input_size_all[i] > 0: 298 | ModelLinear = nn.Linear(multiplier * input_size_all[i], 1) 299 | Ms.append(ModelLinear) 300 | self.CompOuts = nn.ModuleList(Ms) 301 | 302 | def forward(self, out): 303 | """ 304 | Defines the forward pass of the SigWeightedFusion module. 305 | 306 | Parameters: 307 | out (list): List of input tensors for each component. 308 | 309 | Returns: 310 | tuple: Output tensor and list of component outputs. 
311 | """ 312 | outmasked = [] 313 | weights = self.sig(self.weights * 10.) 314 | for i in range(len(out)): 315 | outmasked.append(weights[i] * out[i]) 316 | cont_emb = torch.cat(outmasked, -1) 317 | if self.useExtralin: 318 | op = self.relu(self.foreLinear(cont_emb)) 319 | op = self.LastLinear(op).squeeze(-1) 320 | else: 321 | op = self.foreLinear(cont_emb).squeeze(-1) 322 | Outs = [] 323 | for i in range(len(out)): 324 | o = self.CompOuts[i](out[i]) 325 | Outs.append(o.squeeze(-1)) 326 | self.wandblog(weights) 327 | return op, Outs 328 | 329 | def wandblog(self, weights): 330 | """ 331 | Logs the weights to Weights & Biases. 332 | 333 | Parameters: 334 | weights (torch.Tensor): Weights tensor. 335 | """ 336 | weightdict = {} 337 | for i, w in enumerate(weights): 338 | weightdict['weights_Model' + str(i)] = w 339 | wandb.log(weightdict, commit=False) 340 | 341 | class GumbelWeightedFusion(nn.Module): 342 | """ 343 | GumbelWeightedFusion is a neural network module that applies Gumbel-softmax-weighted fusion during the forward pass. 344 | """ 345 | def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True): 346 | """ 347 | Initializes the GumbelWeightedFusion module. 348 | 349 | Parameters: 350 | input_size_all (list): List of input sizes for each component. 351 | hidden_size (int): Size of the hidden layer. 352 | out_size (int): Size of the output layer. 353 | useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False. 354 | bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True. 355 | """ 356 | super(GumbelWeightedFusion, self).__init__() 357 | multiplier = 2 if bidirectional else 1 358 | self.useExtralin = useExtralin 359 | self.tau = 100. 360 | if not useExtralin: 361 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), out_size) 362 | else: 363 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), hidden_size) 364 | self.LastLinear = nn.Linear(hidden_size, out_size) 365 | self.relu = nn.ReLU() 366 | self.sig = nn.Sigmoid() 367 | self.weights = nn.Parameter(torch.ones(len(input_size_all)) / len(input_size_all)) 368 | Ms = [] 369 | for i in range(len(input_size_all)): 370 | if input_size_all[i] > 0: 371 | ModelLinear = nn.Linear(multiplier * input_size_all[i], 1) 372 | Ms.append(ModelLinear) 373 | self.CompOuts = nn.ModuleList(Ms) 374 | 375 | def forward(self, out): 376 | """ 377 | Defines the forward pass of the GumbelWeightedFusion module. 378 | 379 | Parameters: 380 | out (list): List of input tensors for each component. 381 | 382 | Returns: 383 | tuple: Output tensor and list of component outputs. 384 | """ 385 | outmasked = [] 386 | weights = F.gumbel_softmax(self.weights, tau=self.tau, hard=False) 387 | for i in range(len(out)): 388 | outmasked.append(weights[i] * out[i]) 389 | cont_emb = torch.cat(outmasked, -1) 390 | if self.useExtralin: 391 | op = self.relu(self.foreLinear(cont_emb)) 392 | op = self.LastLinear(op).squeeze(-1) 393 | else: 394 | op = self.foreLinear(cont_emb).squeeze(-1) 395 | Outs = [] 396 | for i in range(len(out)): 397 | o = self.CompOuts[i](out[i]) 398 | Outs.append(o.squeeze(-1)) 399 | self.wandblog(weights) 400 | return op, Outs 401 | 402 | def wandblog(self, weights): 403 | """ 404 | Logs the weights to Weights & Biases. 405 | 406 | Parameters: 407 | weights (torch.Tensor): Weights tensor. 
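# --- Illustrative sketch of the Gumbel-softmax weighting used above ---
# High tau keeps the branch weights near-uniform; annealing tau toward zero
# pushes them toward a one-hot selection. The logits mirror the module's init.
import torch
import torch.nn.functional as F

logits = torch.ones(4) / 4
for tau in (100., 1., 0.1):
    print(tau, F.gumbel_softmax(logits, tau=tau, hard=False))
# -------------------------------------------------------------------------------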
408 |         """
409 |         weightdict = {}
410 |         for i, w in enumerate(weights):
411 |             weightdict['weights_Model' + str(i)] = w
412 |         weightdict['Tau'] = self.tau
413 |         wandb.log(weightdict, commit=False)
414 | 
415 | class AttentionFusion(nn.Module):
416 |     """
417 |     AttentionFusion is a neural network module that applies attention-based fusion during the forward pass.
418 |     """
419 |     def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True):
420 |         """
421 |         Initializes the AttentionFusion module.
422 | 
423 |         Parameters:
424 |             input_size_all (list): List of input sizes for each component.
425 |             hidden_size (int): Size of the hidden layer.
426 |             out_size (int): Size of the output layer.
427 |             useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False.
428 |             bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True.
429 |         """
430 |         super(AttentionFusion, self).__init__()
431 |         multiplier = 2 if bidirectional else 1
432 |         self.attn = Attention(hidden_size, multiplier * max(input_size_all))
433 |         self.useExtralin = useExtralin
434 |         if not useExtralin:
435 |             self.foreLinear = nn.Linear(multiplier * max(input_size_all), out_size)
436 |         else:
437 |             self.foreLinear = nn.Linear(multiplier * max(input_size_all), hidden_size)
438 |             self.LastLinear = nn.Linear(hidden_size, out_size)
439 |         self.relu = nn.ReLU()
440 |         Ms = []
441 |         for i in range(len(input_size_all)):
442 |             ModelLinear = nn.Linear(multiplier * input_size_all[i], 1)
443 |             Ms.append(ModelLinear)
444 |         self.CompOuts = nn.ModuleList(Ms)
445 | 
446 |     def forward(self, out):
447 |         """
448 |         Defines the forward pass of the AttentionFusion module.
449 | 
450 |         Parameters:
451 |             out (list): List of input tensors for each component.
452 | 
453 |         Returns:
454 |             tuple: Output tensor and list of component outputs.
455 |         """
456 |         cont_emb = torch.stack(out, -2)
457 |         masks = torch.ones([cont_emb.shape[0], cont_emb.shape[1]], device=cont_emb.device)
458 |         attn_weights = self.attn(cont_emb, masks)
459 |         op = torch.sum(cont_emb * attn_weights, dim=-2)
460 |         if self.useExtralin:
461 |             op = self.relu(self.foreLinear(op))  # feed the attention-pooled embedding, not the stacked cont_emb
462 |             op = self.LastLinear(op).squeeze(-1)
463 |         else:
464 |             op = self.foreLinear(op).squeeze(-1)  # was foreLinear(cont_emb), which discarded the attention pooling
465 |         Outs = []
466 |         for i in range(len(out)):
467 |             o = self.CompOuts[i](out[i])
468 |             Outs.append(o.squeeze(-1))
469 |         return op, Outs
470 | 
471 | 
472 | class AvgFusion(nn.Module):
473 |     """
474 |     AvgFusion is a neural network module that performs average fusion during the forward pass.
475 |     """
476 |     def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True):
477 |         """
478 |         Initializes the AvgFusion module.
479 | 
480 |         Parameters:
481 |             input_size_all (list): List of input sizes for each component.
482 |             hidden_size (int): Size of the hidden layer.
483 |             out_size (int): Size of the output layer.
484 |             useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False.
485 |             bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True.
486 | """ 487 | super(AvgFusion, self).__init__() 488 | multiplier = 2 if bidirectional else 1 489 | self.useExtralin = useExtralin 490 | if not useExtralin: 491 | self.foreLinear = nn.Linear(multiplier * max(input_size_all), out_size) 492 | else: 493 | self.foreLinear = nn.Linear(multiplier * max(input_size_all), hidden_size) 494 | self.LastLinear = nn.Linear(hidden_size, out_size) 495 | self.relu = nn.ReLU() 496 | Ms = [] 497 | for i in range(len(input_size_all)): 498 | ModelLinear = nn.Linear(multiplier * input_size_all[i], 1) 499 | Ms.append(ModelLinear) 500 | self.CompOuts = nn.ModuleList(Ms) 501 | 502 | def forward(self, out): 503 | """ 504 | Defines the forward pass of the AvgFusion module. 505 | 506 | Parameters: 507 | out (list): List of input tensors for each component. 508 | 509 | Returns: 510 | tuple: Output tensor and list of component outputs. 511 | """ 512 | cont_emb = torch.stack(out, -2) 513 | cont_emb = torch.mean(cont_emb, dim=-2) 514 | if self.useExtralin: 515 | op = self.relu(self.foreLinear(cont_emb)) 516 | op = self.LastLinear(op).squeeze(-1) 517 | else: 518 | op = self.foreLinear(cont_emb).squeeze(-1) 519 | Outs = [] 520 | for i in range(len(out)): 521 | o = self.CompOuts[i](out[i]) 522 | Outs.append(o.squeeze(-1)) 523 | return op, Outs 524 | 525 | class WeightedAvgFusion(nn.Module): 526 | """ 527 | WeightedAvgFusion is a neural network module that performs weighted average fusion during the forward pass. 528 | """ 529 | def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True): 530 | """ 531 | Initializes the WeightedAvgFusion module. 532 | 533 | Parameters: 534 | input_size_all (list): List of input sizes for each component. 535 | hidden_size (int): Size of the hidden layer. 536 | out_size (int): Size of the output layer. 537 | useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False. 538 | bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True. 539 | """ 540 | super(WeightedAvgFusion, self).__init__() 541 | multiplier = 2 if bidirectional else 1 542 | self.useExtralin = useExtralin 543 | if not useExtralin: 544 | self.foreLinear = nn.Linear(multiplier * max(input_size_all), out_size) 545 | else: 546 | self.foreLinear = nn.Linear(multiplier * max(input_size_all), hidden_size) 547 | self.LastLinear = nn.Linear(hidden_size, out_size) 548 | self.relu = nn.ReLU() 549 | self.weights = nn.Parameter(torch.ones([len(input_size_all)]) / len(input_size_all)) 550 | Ms = [] 551 | for i in range(len(input_size_all)): 552 | ModelLinear = nn.Linear(multiplier * input_size_all[i], 1) 553 | Ms.append(ModelLinear) 554 | self.CompOuts = nn.ModuleList(Ms) 555 | 556 | def forward(self, out): 557 | """ 558 | Defines the forward pass of the WeightedAvgFusion module. 559 | 560 | Parameters: 561 | out (list): List of input tensors for each component. 562 | 563 | Returns: 564 | tuple: Output tensor and list of component outputs. 
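# --- Illustrative usage sketch (shapes are assumptions) ---
# AvgFusion expects one embedding per branch, all of equal size; it stacks them,
# mean-pools across branches, and applies a single linear head.
import torch

fusion = AvgFusion([16, 16, 16], hidden_size=64, out_size=1)
outs = [torch.randn(4, 32) for _ in range(3)]  # batch 4, 2 * 16 (bidirectional)
op, comp_outs = fusion(outs)
print(op.shape)  # torch.Size([4])
# -------------------------------------------------------------------------------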
565 | """ 566 | weights = self.weights / torch.sum(self.weights) 567 | cont_emb = torch.stack(out, -1) 568 | cont_emb = torch.matmul(cont_emb, weights) 569 | if self.useExtralin: 570 | op = self.relu(self.foreLinear(cont_emb)) 571 | op = self.LastLinear(op).squeeze(-1) 572 | else: 573 | op = self.foreLinear(cont_emb).squeeze(-1) 574 | Outs = [] 575 | for i in range(len(out)): 576 | o = self.CompOuts[i](out[i]) 577 | Outs.append(o.squeeze(-1)) 578 | self.wandblog(weights) 579 | return op, Outs 580 | 581 | def wandblog(self, weights): 582 | """ 583 | Logs the weights to Weights & Biases. 584 | 585 | Parameters: 586 | weights (torch.Tensor): Weights tensor. 587 | """ 588 | weightdict = {} 589 | for i, w in enumerate(weights): 590 | weightdict['weights_Model' + str(i)] = w 591 | wandb.log(weightdict, commit=False) 592 | 593 | class WeightedAvgEnsemble(nn.Module): 594 | """ 595 | WeightedAvgEnsemble is a neural network module that performs weighted average ensemble during the forward pass. 596 | """ 597 | def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True): 598 | """ 599 | Initializes the WeightedAvgEnsemble module. 600 | 601 | Parameters: 602 | input_size_all (list): List of input sizes for each component. 603 | hidden_size (int): Size of the hidden layer. 604 | out_size (int): Size of the output layer. 605 | useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False. 606 | bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True. 607 | """ 608 | super(WeightedAvgEnsemble, self).__init__() 609 | multiplier = 2 if bidirectional else 1 610 | self.weights = nn.Parameter(torch.ones([len(input_size_all)]) / len(input_size_all)) 611 | Ms = [] 612 | for i in range(len(input_size_all)): 613 | ModelLinear = nn.Linear(multiplier * input_size_all[i], 1) 614 | Ms.append(ModelLinear) 615 | self.CompOuts = nn.ModuleList(Ms) 616 | 617 | def forward(self, out): 618 | """ 619 | Defines the forward pass of the WeightedAvgEnsemble module. 620 | 621 | Parameters: 622 | out (list): List of input tensors for each component. 623 | 624 | Returns: 625 | tuple: Output tensor and list of component outputs. 626 | """ 627 | weights = self.weights / torch.sum(self.weights) 628 | op = 0.0 629 | Outs = [] 630 | for i in range(len(out)): 631 | o = self.CompOuts[i](out[i]) 632 | Outs.append(o.squeeze(-1)) 633 | op += weights[i] * o.squeeze(-1) 634 | self.wandblog(weights) 635 | return op, Outs 636 | 637 | def wandblog(self, weights): 638 | """ 639 | Logs the weights to Weights & Biases. 640 | 641 | Parameters: 642 | weights (torch.Tensor): Weights tensor. 643 | """ 644 | weightdict = {} 645 | for i, w in enumerate(weights): 646 | weightdict['weights_Model' + str(i)] = w 647 | wandb.log(weightdict, commit=False) 648 | 649 | class LinearFusionSameFC(nn.Module): 650 | """ 651 | LinearFusionSameFC is a neural network module that performs linear fusion using the same fully connected layer for each component. 652 | """ 653 | def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True): 654 | """ 655 | Initializes the LinearFusionSameFC module. 656 | 657 | Parameters: 658 | input_size_all (list): List of input sizes for each component. 659 | hidden_size (int): Size of the hidden layer. 660 | out_size (int): Size of the output layer. 661 | useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False. 662 | bidirectional (bool, optional): Whether to use bidirectional layers. 
Defaults to True. 663 | """ 664 | super(LinearFusionSameFC, self).__init__() 665 | multiplier = 2 if bidirectional else 1 666 | self.useExtralin = useExtralin 667 | if not useExtralin: 668 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), out_size) 669 | else: 670 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), hidden_size) 671 | self.LastLinear = nn.Linear(hidden_size, out_size) 672 | self.relu = nn.ReLU() 673 | Ms = [] 674 | self.ModelLinear = nn.Linear(multiplier * max(input_size_all), out_size) 675 | 676 | def forward(self, out): 677 | """ 678 | Defines the forward pass of the LinearFusionSameFC module. 679 | 680 | Parameters: 681 | out (list): List of input tensors for each component. 682 | 683 | Returns: 684 | tuple: Output tensor and list of component outputs. 685 | """ 686 | cont_emb = torch.cat(out, -1) 687 | if self.useExtralin: 688 | op = self.relu(self.foreLinear(cont_emb)) 689 | op = self.LastLinear(op).squeeze(-1) 690 | else: 691 | op = self.foreLinear(cont_emb).squeeze(-1) 692 | Outs = [] 693 | for i in range(len(out)): 694 | o = self.ModelLinear(out[i]) 695 | Outs.append(o.squeeze(-1)) 696 | return op, Outs 697 | 698 | class LinearFusion(nn.Module): 699 | """ 700 | LinearFusion is a neural network module that performs linear fusion using separate fully connected layers for each component. 701 | """ 702 | def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True): 703 | """ 704 | Initializes the LinearFusion module. 705 | 706 | Parameters: 707 | input_size_all (list): List of input sizes for each component. 708 | hidden_size (int): Size of the hidden layer. 709 | out_size (int): Size of the output layer. 710 | useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False. 711 | bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True. 712 | """ 713 | super(LinearFusion, self).__init__() 714 | multiplier = 2 if bidirectional else 1 715 | self.useExtralin = useExtralin 716 | print('input_size_all ', sum(input_size_all)) 717 | if not useExtralin: 718 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), out_size) 719 | else: 720 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), hidden_size) 721 | self.LastLinear = nn.Linear(hidden_size, out_size) 722 | self.relu = nn.ReLU() 723 | Ms = [] 724 | for i in range(len(input_size_all)): 725 | if input_size_all[i] > 0: 726 | ModelLinear = nn.Linear(multiplier * input_size_all[i], out_size) 727 | Ms.append(ModelLinear) 728 | self.CompOuts = nn.ModuleList(Ms) 729 | 730 | def forward(self, out): 731 | """ 732 | Defines the forward pass of the LinearFusion module. 733 | 734 | Parameters: 735 | out (list): List of input tensors for each component. 736 | 737 | Returns: 738 | tuple: Output tensor and list of component outputs. 739 | """ 740 | cont_emb = torch.cat(out, -1) 741 | if self.useExtralin: 742 | op = self.relu(self.foreLinear(cont_emb)) 743 | op = self.LastLinear(op).squeeze(-1) 744 | else: 745 | op = self.foreLinear(cont_emb).squeeze(-1) 746 | Outs = [] 747 | for i in range(len(out)): 748 | o = self.CompOuts[i](out[i]) 749 | Outs.append(o.squeeze(-1)) 750 | return op, Outs 751 | 752 | class HieLinFusion(nn.Module): 753 | """ 754 | HieLinFusion is a neural network module that performs hierarchical linear fusion during the forward pass. 
755 |     """
756 |     def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True):
757 |         """
758 |         Initializes the HieLinFusion module.
759 | 
760 |         Parameters:
761 |             input_size_all (list): List of input sizes for each component.
762 |             hidden_size (int): Size of the hidden layer.
763 |             out_size (int): Size of the output layer.
764 |             useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False.
765 |             bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True.
766 |         """
767 |         super(HieLinFusion, self).__init__()
768 |         Ms = []
769 |         Mouts = []
770 |         multiplier = 2 if bidirectional else 1
771 |         for i in range(1, len(input_size_all)):
772 |             if i < len(input_size_all) - 1:
773 |                 model = nn.Sequential(nn.Linear(multiplier * (input_size_all[i-1] + input_size_all[i]), multiplier * input_size_all[i]),
774 |                                       nn.ReLU())
775 |             else:
776 |                 model = nn.Sequential(nn.Linear(multiplier * (input_size_all[i-1] + input_size_all[i]), out_size),  # was hardcoded 2 *; use multiplier so the unidirectional case is sized correctly
777 |                                       nn.ReLU())
778 |             Ms.append(model)
779 |             linearout = nn.Linear(multiplier * input_size_all[i-1], 1)
780 |             Mouts.append(linearout)
781 |         self.linears = nn.ModuleList(Ms)
782 |         self.CompOuts = nn.ModuleList(Mouts)
783 | 
784 |     def forward(self, out):
785 |         """
786 |         Defines the forward pass of the HieLinFusion module.
787 | 
788 |         Parameters:
789 |             out (list): List of input tensors for each component.
790 | 
791 |         Returns:
792 |             tuple: Output tensor and list of component outputs.
793 |         """
794 |         Outs = []
795 |         o = out[0]
796 |         for i in range(len(self.linears)):
797 |             op = self.CompOuts[i](o)
798 |             Outs.append(op.squeeze(-1))
799 |             o = self.linears[i](torch.cat((o, out[i+1]), -1))
800 |         return o, Outs
801 | 
802 | 
803 | class TransformerFusion(nn.Module):
804 |     """
805 |     TransformerFusion is a neural network module that applies transformer-based fusion during the forward pass.
806 |     """
807 |     def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True, dropout=0.0):
808 |         """
809 |         Initializes the TransformerFusion module.
810 | 
811 |         Parameters:
812 |             input_size_all (list): List of input sizes for each component.
813 |             hidden_size (int): Size of the hidden layer.
814 |             out_size (int): Size of the output layer.
815 |             useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False.
816 |             bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True.
817 |             dropout (float, optional): Dropout rate. Defaults to 0.0.
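# --- Illustrative usage sketch (shapes are assumptions) ---
# HieLinFusion folds branches in left to right: each step concatenates the running
# embedding with the next branch and projects it; the last step emits out_size.
import torch

fusion = HieLinFusion([8, 8, 8], hidden_size=32, out_size=1)
outs = [torch.randn(4, 16) for _ in range(3)]  # batch 4, 2 * 8 (bidirectional)
o, comp_outs = fusion(outs)
print(o.shape)  # torch.Size([4, 1])
# -------------------------------------------------------------------------------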
818 | """ 819 | super(TransformerFusion, self).__init__() 820 | multiplier = 2 if bidirectional else 1 821 | self.varEmebedding = nn.Embedding(len(input_size_all), multiplier * input_size_all[0]) # Assuming all models have same size 822 | d, N, hes = multiplier * input_size_all[0], 1, 1 823 | self.transformer = Transformer(d, N, hes, dk=None, dv=None, dff=None, dropout=dropout) 824 | self.attn = Attention(2 * d, d) 825 | self.useExtralin = useExtralin 826 | if not useExtralin: 827 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), out_size) 828 | else: 829 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), hidden_size) 830 | self.LastLinear = nn.Linear(hidden_size, out_size) 831 | self.relu = nn.ReLU() 832 | Ms = [] 833 | for i in range(len(input_size_all)): 834 | ModelLinear = nn.Linear(multiplier * input_size_all[i], 1) 835 | Ms.append(ModelLinear) 836 | self.CompOuts = nn.ModuleList(Ms) 837 | 838 | def forward(self, out): 839 | """ 840 | Defines the forward pass of the TransformerFusion module. 841 | 842 | Parameters: 843 | out (list): List of input tensors for each component. 844 | 845 | Returns: 846 | tuple: Output tensor and list of component outputs. 847 | """ 848 | comb_emb = torch.stack(out, -2) 849 | masks = torch.ones([comb_emb.shape[0], comb_emb.shape[1]], device=comb_emb.device) 850 | varis = torch.arange(len(out), device=comb_emb.device) 851 | varis_emb = self.varEmebedding(varis) 852 | comb_emb = varis_emb + comb_emb 853 | cont_emb = self.transformer(comb_emb, mask=masks) 854 | attn_weights = self.attn(cont_emb, mask=masks) 855 | op = torch.sum(cont_emb * attn_weights, dim=-2) 856 | if self.useExtralin: 857 | op = self.relu(self.foreLinear(op)) 858 | op = self.LastLinear(op).squeeze(-1) 859 | else: 860 | op = self.foreLinear(op).squeeze(-1) 861 | Outs = [] 862 | for i in range(len(out)): 863 | o = self.CompOuts[i](out[i]) 864 | Outs.append(o.squeeze(-1)) 865 | return op, Outs 866 | 867 | class trainableswitch(nn.Module): 868 | """ 869 | trainableswitch is a neural network module that applies a trainable mask during the forward pass. 870 | """ 871 | def __init__(self): 872 | """ 873 | Initializes the trainableswitch module. 874 | """ 875 | super(trainableswitch, self).__init__() 876 | self.W = nn.Parameter(torch.tensor(0.6)) 877 | self.activation = nn.ReLU() 878 | self.Mask = 0.6 879 | 880 | def forward(self, x): 881 | """ 882 | Defines the forward pass of the trainableswitch module. 883 | 884 | Parameters: 885 | x (torch.Tensor): Input tensor. 886 | 887 | Returns: 888 | torch.Tensor: Masked input tensor. 889 | """ 890 | self.Mask = self.activation(self.W - 0.5) 891 | return x * self.Mask 892 | 893 | class TrainableFusionSwitch(nn.Module): 894 | """ 895 | TrainableFusionSwitch is a neural network module that applies trainable masks during the forward pass. 896 | """ 897 | def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True): 898 | """ 899 | Initializes the TrainableFusionSwitch module. 900 | 901 | Parameters: 902 | input_size_all (list): List of input sizes for each component. 903 | hidden_size (int): Size of the hidden layer. 904 | out_size (int): Size of the output layer. 905 | useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False. 906 | bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True. 
907 | """ 908 | super(TrainableFusionSwitch, self).__init__() 909 | multiplier = 2 if bidirectional else 1 910 | self.useExtralin = useExtralin 911 | if not useExtralin: 912 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), out_size) 913 | else: 914 | self.foreLinear = nn.Linear(multiplier * sum(input_size_all), hidden_size) 915 | self.LastLinear = nn.Linear(hidden_size, out_size) 916 | self.relu = nn.ReLU() 917 | switches = [] 918 | Ms = [] 919 | for i in range(len(input_size_all)): 920 | if input_size_all[i] > 0: 921 | ModelLinear = nn.Linear(multiplier * input_size_all[i], 1) 922 | Ms.append(ModelLinear) 923 | self.CompOuts = nn.ModuleList(Ms) 924 | self.switchweights = nn.Parameter(torch.tensor([0.1 for _ in range(len(input_size_all))])) 925 | self.activation = nn.ReLU() 926 | 927 | def forward(self, out): 928 | """ 929 | Defines the forward pass of the TrainableFusionSwitch module. 930 | 931 | Parameters: 932 | out (list): List of input tensors for each component. 933 | 934 | Returns: 935 | tuple: Output tensor and list of component outputs. 936 | """ 937 | self.switchMasks = self.activation(self.switchweights) 938 | outmasked = [] 939 | for i in range(len(out)): 940 | outmasked.append(self.switchMasks[i] * out[i]) 941 | cont_emb = torch.cat(outmasked, -1) 942 | if self.useExtralin: 943 | op = self.relu(self.foreLinear(cont_emb)) 944 | op = self.LastLinear(op).squeeze(-1) 945 | else: 946 | op = self.foreLinear(cont_emb).squeeze(-1) 947 | Outs = [] 948 | for i in range(len(out)): 949 | o = self.CompOuts[i](out[i]) 950 | Outs.append(o.squeeze(-1)) 951 | self.wandblog() 952 | return op, Outs 953 | 954 | def l1_norm(self): 955 | """ 956 | Computes the L1 norm of the switch masks. 957 | 958 | Returns: 959 | torch.Tensor: L1 norm of the switch masks. 960 | """ 961 | L1_norm = torch.norm(self.switchMasks, 1) 962 | return L1_norm 963 | 964 | def wandblog(self): 965 | """ 966 | Logs the switch weights and masks to Weights & Biases. 967 | """ 968 | weightdict = {} 969 | for i, sw in enumerate(self.switchMasks): 970 | weightdict['weights_Model' + str(i)] = self.switchweights[i].cpu().detach() 971 | weightdict['Mask_Model' + str(i)] = sw.cpu().detach() 972 | wandb.log(weightdict, commit=False) 973 | 974 | class BranchLayer(nn.Module): 975 | """ 976 | BranchLayer is a neural network module that applies Gumbel-softmax-based branching during the forward pass. 977 | """ 978 | def __init__(self, num_ins): 979 | """ 980 | Initializes the BranchLayer module. 981 | 982 | Parameters: 983 | num_ins (int): Number of input branches. 984 | """ 985 | super(BranchLayer, self).__init__() 986 | self.prob = nn.Parameter(torch.ones(num_ins)) 987 | 988 | def forward(self, outs, temp): 989 | """ 990 | Defines the forward pass of the BranchLayer module. 991 | 992 | Parameters: 993 | outs (list): List of input tensors for each branch. 994 | temp (float): Temperature parameter for Gumbel-softmax. 995 | 996 | Returns: 997 | torch.Tensor: Combined output tensor. 998 | """ 999 | self.weights = F.gumbel_softmax(self.prob, tau=temp, hard=False) 1000 | outcomb = 0.0 1001 | for i, o in enumerate(outs): 1002 | outcomb += o * self.weights[i] 1003 | return outcomb 1004 | 1005 | class SimpleWaveBranch(nn.Module): 1006 | """ 1007 | SimpleWaveBranch is a neural network module that applies a simple wavelet-based branching during the forward pass. 1008 | """ 1009 | def __init__(self, input_size_all, hidden_size, out_size, useExtralin=False, bidirectional=True): 1010 | """ 1011 | Initializes the SimpleWaveBranch module. 
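# --- Illustrative training-loop sketch (loss and shapes are assumptions) ---
# The ReLU(w) masks of TrainableFusionSwitch can be driven exactly to zero by
# adding l1_norm() to the task loss. forward() logs to Weights & Biases, so this
# assumes an active run (wandb.init()).
import torch
import torch.nn.functional as F

fusion = TrainableFusionSwitch([16, 16], hidden_size=32, out_size=1)
outs = [torch.randn(4, 32), torch.randn(4, 32)]
op, comp_outs = fusion(outs)
loss = F.binary_cross_entropy_with_logits(op, torch.rand(4)) + 1e-3 * fusion.l1_norm()
loss.backward()
# -------------------------------------------------------------------------------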
1012 | 1013 | Parameters: 1014 | input_size_all (list): List of input sizes for each component. 1015 | hidden_size (int): Size of the hidden layer. 1016 | out_size (int): Size of the output layer. 1017 | useExtralin (bool, optional): Whether to use an extra linear layer. Defaults to False. 1018 | bidirectional (bool, optional): Whether to use bidirectional layers. Defaults to True. 1019 | """ 1020 | super(SimpleWaveBranch, self).__init__() 1021 | multiplier = 2 if bidirectional else 1 1022 | self.WaveLinear = nn.Linear(multiplier * sum(input_size_all[:-1]), hidden_size) 1023 | self.SimpleModelLinear = nn.Linear(multiplier * input_size_all[-1], hidden_size) 1024 | self.Branching = BranchLayer(2) 1025 | self.useExtralin = useExtralin 1026 | self.relu = nn.ReLU() 1027 | self.LastLinear = nn.Linear(hidden_size, out_size) 1028 | self.temp = 10 1029 | Ms = [] 1030 | for i in range(len(input_size_all)): 1031 | if input_size_all[i] > 0: 1032 | ModelLinear = nn.Linear(multiplier * input_size_all[i], 1) 1033 | Ms.append(ModelLinear) 1034 | self.CompOuts = nn.ModuleList(Ms) 1035 | 1036 | def forward(self, out): 1037 | """ 1038 | Defines the forward pass of the SimpleWaveBranch module. 1039 | 1040 | Parameters: 1041 | out (list): List of input tensors for each component. 1042 | 1043 | Returns: 1044 | tuple: Output tensor and list of component outputs. 1045 | """ 1046 | wave_out = torch.cat(out[:-1], -1) 1047 | wave_out = self.WaveLinear(wave_out) 1048 | Simp_out = self.SimpleModelLinear(out[-1]) 1049 | op = self.Branching([wave_out, Simp_out], self.temp) 1050 | op = self.LastLinear(op).squeeze(-1) 1051 | Outs = [] 1052 | for i in range(len(out)): 1053 | o = self.CompOuts[i](out[i]) 1054 | Outs.append(o.squeeze(-1)) 1055 | self.wandblog() 1056 | return op, Outs 1057 | 1058 | def wandblog(self): 1059 | """ 1060 | Logs the branch weights to Weights & Biases. 1061 | """ 1062 | weightdict = {} 1063 | weightdict['weights_Wave'] = self.Branching.weights[0].cpu().detach() 1064 | weightdict['weights_Simple'] = self.Branching.weights[1].cpu().detach() 1065 | weightdict['Tau'] = self.temp 1066 | wandb.log(weightdict, commit=False) 1067 | 1068 | def getFusion(fusionstr): 1069 | """ 1070 | Returns the appropriate fusion module based on the specified fusion string. 1071 | 1072 | Parameters: 1073 | fusionstr (str): Fusion module type. 1074 | 1075 | Returns: 1076 | nn.Module: The appropriate fusion module. 1077 | 1078 | Raises: 1079 | ValueError: If the fusion string is not recognized. 
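# --- Illustrative sketch of using this factory (sizes are assumptions) ---
FusionCls = getFusion('LinearFusion')
fusion = FusionCls([16, 16], hidden_size=32, out_size=1)
# getFusion('NoSuchFusion') would raise ValueError('Fusion value provided not found: ...')
# -------------------------------------------------------------------------------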
1080 | """ 1081 | if fusionstr == 'LinearFusion': 1082 | fusion = LinearFusion 1083 | elif fusionstr == 'AttentionFusion': 1084 | fusion = AttentionFusion 1085 | elif fusionstr == 'TransformerFusion': 1086 | fusion = TransformerFusion 1087 | elif fusionstr == 'HieLinFusion': 1088 | fusion = HieLinFusion 1089 | elif fusionstr == 'AvgFusion': 1090 | fusion = AvgFusion 1091 | elif fusionstr == 'WeightedAvgFusion': 1092 | fusion = WeightedAvgFusion 1093 | elif fusionstr == 'MaskedFusion': 1094 | fusion = MaskedFusion 1095 | elif fusionstr == 'MaskedFusionSwitch': 1096 | fusion = MaskedFusionSwitch 1097 | elif fusionstr == 'LinearFusionSameFC': 1098 | fusion = LinearFusionSameFC 1099 | elif fusionstr == 'MaskedGradLinear': 1100 | fusion = MaskedGradLinear 1101 | elif fusionstr == 'WeightedAvgEnsemble': 1102 | fusion = WeightedAvgEnsemble 1103 | elif fusionstr == 'SigWeightedFusion': 1104 | fusion = SigWeightedFusion 1105 | elif fusionstr == 'GumbelWeightedFusion': 1106 | fusion = GumbelWeightedFusion 1107 | elif fusionstr == 'MaskedFusionGradSwitch': 1108 | fusion = MaskedFusionGradSwitch 1109 | elif fusionstr == 'TrainableFusionSwitch': 1110 | fusion = TrainableFusionSwitch 1111 | elif fusionstr == 'SimpleWaveBranch': 1112 | fusion = SimpleWaveBranch 1113 | else: 1114 | raise ValueError('Fusion value provided not found: ' + fusionstr) 1115 | return fusion -------------------------------------------------------------------------------- /Models/TA_LSTM.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import Parameter 5 | 6 | class TA_LSTMCell(nn.Module): 7 | def __init__(self, input_size, hidden_size): 8 | super(TA_LSTMCell, self).__init__() 9 | self.input_size = input_size 10 | self.hidden_size = hidden_size 11 | self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size)) 12 | self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size)) 13 | self.bias_ih = Parameter(torch.randn(4 * hidden_size)) 14 | self.bias_hh = Parameter(torch.randn(4 * hidden_size)) 15 | self.W_decomp = Parameter(torch.randn(hidden_size, hidden_size)) 16 | self.b_decomp = Parameter(torch.randn(hidden_size)) 17 | 18 | def g(self,t): 19 | T = torch.zeros_like(t) 20 | T[t.nonzero(as_tuple=True)] = 1 / t[t.nonzero(as_tuple=True)] 21 | 22 | Ones = torch.ones([1, self.hidden_size], dtype=torch.float32).to(t.device) 23 | 24 | T = torch.mm(T, Ones) 25 | return T 26 | 27 | def forward(self, input, t, state): 28 | # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] 29 | hx, cx = state 30 | gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih + 31 | torch.mm(hx, self.weight_hh.t()) + self.bias_hh) 32 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 33 | 34 | T = self.g(t) 35 | 36 | C_ST = torch.tanh(torch.mm(cx, self.W_decomp) + self.b_decomp) 37 | C_ST_dis = T * C_ST 38 | # if T is 0, then the weight is one 39 | cx = cx - C_ST + C_ST_dis 40 | 41 | ingate = torch.sigmoid(ingate) 42 | forgetgate = torch.sigmoid(forgetgate) 43 | cellgate = torch.tanh(cellgate) 44 | outgate = torch.sigmoid(outgate) 45 | 46 | cy = (forgetgate * cx) + (ingate * cellgate) 47 | hy = outgate * torch.tanh(cy) 48 | 49 | return (hy, cy) 50 | 51 | class TA_LSTM(nn.Module): 52 | def __init__(self, input_size, hidden_size): 53 | super(TA_LSTM, self).__init__() 54 | self.TA_lstm = TA_LSTMCell(input_size, hidden_size) 55 | self.hidden_size = hidden_size 56 | def forward(self, X, 
time):
57 |         c = torch.zeros([X.shape[1], self.hidden_size], device=X.device)  # batch-sized state on the input's device; the original 1-D zeros broke torch.mm
58 |         h = torch.zeros([X.shape[1], self.hidden_size], device=X.device)
59 |         state = (h, c)
60 |         AllStates = []
61 |         for x, t in zip(X, time):
62 |             state = self.TA_lstm(x, t, state)
63 |             AllStates.append(state[0])
64 |         return torch.stack(AllStates), state  # all hidden states plus the final (h, c); the original ended without a return
--------------------------------------------------------------------------------
/Models/TLSTM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | class TLSTM(nn.Module):
6 |     def __init__(self, input_dim, output_dim, hidden_dim, fc_dim, train):
7 |         super(TLSTM, self).__init__()
8 |         self.input_dim = input_dim
9 |         self.hidden_dim = hidden_dim
10 |         self.fc_dim = fc_dim
11 |         self.istrain = train  # renamed: `self.train = train` shadowed nn.Module.train()
12 | 
13 |         self.Wi = nn.Linear(input_dim, hidden_dim)
14 |         self.Ui = nn.Linear(hidden_dim, hidden_dim)
15 |         self.bi = nn.Linear(hidden_dim, hidden_dim)
16 | 
17 |         self.Wf = nn.Linear(input_dim, hidden_dim)
18 |         self.Uf = nn.Linear(hidden_dim, hidden_dim)
19 |         self.bf = nn.Linear(hidden_dim, hidden_dim)
20 | 
21 |         self.Wog = nn.Linear(input_dim, hidden_dim)
22 |         self.Uog = nn.Linear(hidden_dim, hidden_dim)
23 |         self.bog = nn.Linear(hidden_dim, hidden_dim)
24 | 
25 |         self.Wc = nn.Linear(input_dim, hidden_dim)
26 |         self.Uc = nn.Linear(hidden_dim, hidden_dim)
27 |         self.bc = nn.Linear(hidden_dim, hidden_dim)
28 | 
29 |         self.W_decomp = nn.Linear(hidden_dim, hidden_dim)
30 |         self.b_decomp = nn.Linear(hidden_dim, hidden_dim)
31 | 
32 |         self.Wo = nn.Linear(hidden_dim, fc_dim)
33 |         self.bo = nn.Linear(fc_dim, fc_dim)
34 | 
35 |         self.W_softmax = nn.Linear(fc_dim, output_dim)
36 |         self.b_softmax = nn.Linear(output_dim, output_dim)
37 | 
38 |     def forward(self, input, labels, time, keep_prob, hidden, prev_cell):
39 |         # time decay
40 |         T = self.map_elapse_time(time)
41 | 
42 |         # Decompose the previous cell if there is an elapsed time
43 |         C_ST = torch.tanh(self.W_decomp(prev_cell) + self.b_decomp(prev_cell))  # b_decomp was added without being called
44 |         C_ST_dis = T * C_ST  # elementwise decay, as in TA_LSTMCell; torch.matmul mismatched the shapes
45 |         # if T is 0, then the weight is one
46 |         prev_cell = prev_cell - C_ST + C_ST_dis
47 | 
48 |         # input gate
49 |         i = torch.sigmoid(self.Wi(input) + self.Ui(hidden) + self.bi(hidden))
50 | 
51 |         # forget gate
52 |         f = torch.sigmoid(self.Wf(input) + self.Uf(hidden) + self.bf(hidden))
53 | 
54 |         # output gate
55 |         og = torch.sigmoid(self.Wog(input) + self.Uog(hidden) + self.bog(hidden))
56 | 
57 |         # candidate state
58 |         state = torch.tanh(self.Wc(input) + self.Uc(hidden) + self.bc(hidden))
59 |         c = f * prev_cell + i * state
60 | 
61 |         # new hidden state from the decomposed cell
62 |         hidden = og * torch.tanh(c)
63 | 
64 |         return hidden, c
65 | 
66 |     def map_elapse_time(self, t):
67 |         # Reconstructed: forward() calls this method, but it was missing from the
68 |         # file. Mirrors TA_LSTMCell.g: elapsed-time decay 1/t, broadcast to hidden_dim.
69 |         T = torch.zeros_like(t)
70 |         T[t.nonzero(as_tuple=True)] = 1 / t[t.nonzero(as_tuple=True)]
71 |         Ones = torch.ones([1, self.hidden_dim], dtype=torch.float32, device=t.device)
72 |         return torch.mm(T, Ones)
--------------------------------------------------------------------------------
/Models/TorchModels.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import wandb
4 | import numpy as np
5 | from torch.fft import fft
6 | 
7 | class CVE(nn.Module):
8 |     """
9 |     CVE (Continuous Value Embedding) is a neural network module that performs a linear transformation
10 |     followed by a non-linear activation and another linear transformation.
11 |     """
12 |     def __init__(self, hid_units, output_dim):
13 |         """
14 |         Initializes the CVE module.
15 | 
16 |         Parameters:
17 |             hid_units (int): Number of hidden units in the first linear layer.
18 |             output_dim (int): Dimension of the output.
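# --- Illustrative usage sketch (shapes are assumptions) ---
# TLSTM implements a single step; the caller unrolls it over time, passing the
# elapsed time between visits. `labels` and `keep_prob` are unused by forward().
import torch

cell = TLSTM(input_dim=8, output_dim=2, hidden_dim=16, fc_dim=8, train=True)
xs = torch.randn(10, 4, 8)   # (seq_len, batch, input_dim)
dts = torch.rand(10, 4, 1)   # elapsed time per step
h = torch.zeros(4, 16)
c = torch.zeros(4, 16)
for x_t, dt in zip(xs, dts):
    h, c = cell(x_t, None, dt, 1.0, h, c)
# -------------------------------------------------------------------------------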
19 | """ 20 | super(CVE, self).__init__() 21 | self.hid_units = hid_units 22 | self.output_dim = output_dim 23 | self.W1 = nn.Linear(1, hid_units, bias=True) # First linear layer 24 | self.W2 = nn.Linear(hid_units, output_dim, bias=False) # Second linear layer 25 | self.reset_parameters() # Initialize parameters with Xavier uniform distribution 26 | 27 | def forward(self, x): 28 | """ 29 | Defines the forward pass of the CVE module. 30 | 31 | Parameters: 32 | x (Tensor): Input tensor. 33 | 34 | Returns: 35 | Tensor: Output tensor after applying the linear transformations and activation. 36 | """ 37 | x = x.unsqueeze(-1) # Add an extra dimension to the input tensor 38 | x = self.W2(torch.tanh(self.W1(x))) # Apply first linear layer, tanh activation, and second linear layer 39 | return x 40 | 41 | def reset_parameters(self): 42 | """ 43 | Initializes the parameters of the linear layers using Xavier uniform distribution. 44 | """ 45 | nn.init.xavier_uniform_(self.W1.weight) # Initialize weights of the first linear layer 46 | nn.init.xavier_uniform_(self.W2.weight) # Initialize weights of the second linear layer 47 | torch.nn.init.zeros_(self.W1.bias) # Initialize biases of the first linear layer to zero 48 | 49 | def compute_output_shape(self, input_shape): 50 | """ 51 | Computes the output shape of the CVE module given the input shape. 52 | 53 | Parameters: 54 | input_shape (tuple): Shape of the input tensor. 55 | 56 | Returns: 57 | tuple: Shape of the output tensor. 58 | """ 59 | return input_shape + (self.output_dim,) 60 | 61 | 62 | class Attention(nn.Module): 63 | """ 64 | Attention is a neural network module that computes attention weights for input sequences. 65 | """ 66 | 67 | def __init__(self, hid_dim, d): 68 | """ 69 | Initializes the Attention module. 70 | 71 | Parameters: 72 | hid_dim (int): Number of hidden units in the attention mechanism. 73 | d (int): Dimension of the input features. 74 | """ 75 | super(Attention, self).__init__() 76 | self.hid_dim = hid_dim 77 | self.W = nn.Linear(d, self.hid_dim, bias=True) # Linear layer to transform input features 78 | self.u = nn.Linear(self.hid_dim, 1, bias=False) # Linear layer to compute attention scores 79 | self.softmax = nn.Softmax(-2) # Softmax activation to normalize attention scores 80 | self.reset_parameters() # Initialize parameters with Xavier uniform distribution 81 | 82 | def forward(self, x, mask=None, mask_value=-1e30): 83 | """ 84 | Defines the forward pass of the Attention module. 85 | 86 | Parameters: 87 | x (Tensor): Input tensor of shape (batch_size, sequence_length, feature_dim). 88 | mask (Tensor, optional): Mask tensor to apply to the attention weights. Defaults to None. 89 | mask_value (float, optional): Value to use for masked positions. Defaults to -1e30. 90 | 91 | Returns: 92 | Tensor: Attention weights of shape (batch_size, sequence_length, 1). 93 | """ 94 | if mask is None: 95 | mask = torch.ones([x.shape[0], x.shape[1]], device=x.device) # Create a mask of ones if not provided 96 | attn_weights = self.u(torch.tanh(self.W(x))) # Compute attention scores 97 | mask = mask.unsqueeze(-1) # Add an extra dimension to the mask 98 | attn_weights = mask * attn_weights + (1 - mask) * mask_value # Apply mask to attention scores 99 | attn_weights = self.softmax(attn_weights) # Normalize attention scores with softmax 100 | return attn_weights 101 | 102 | def compute_output_shape(self, input_shape): 103 | """ 104 | Computes the output shape of the Attention module given the input shape. 
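# --- Illustrative usage sketch (shapes are assumptions) ---
# Attention returns softmax-normalized weights over the sequence axis; the
# weighted pooling is left to the caller, as TransformerModel does further below.
import torch

attn = Attention(hid_dim=16, d=32)
x = torch.randn(4, 10, 32)
w = attn(x)                        # (4, 10, 1), sums to 1 over dim -2
pooled = torch.sum(x * w, dim=-2)  # (4, 32)
# -------------------------------------------------------------------------------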
105 | 
106 |         Parameters:
107 |             input_shape (tuple): Shape of the input tensor.
108 | 
109 |         Returns:
110 |             tuple: Shape of the output tensor.
111 |         """
112 |         return input_shape[:-1] + (1,)
113 | 
114 |     def reset_parameters(self):
115 |         """
116 |         Initializes the parameters of the linear layers using Xavier uniform distribution.
117 |         """
118 |         nn.init.xavier_uniform_(self.W.weight)  # Initialize weights of the first linear layer
119 |         nn.init.xavier_uniform_(self.u.weight)  # Initialize weights of the second linear layer
120 |         torch.nn.init.zeros_(self.W.bias)  # Initialize biases of the first linear layer to zero
121 | 
122 | class Transformer(nn.Module):
123 |     """
124 |     Transformer is a neural network module that implements the Transformer architecture.
125 |     """
126 | 
127 |     def __init__(self, d, N=2, h=8, dk=None, dv=None, dff=None, dropout=0):
128 |         """
129 |         Initializes the Transformer module.
130 | 
131 |         Parameters:
132 |             d (int): Dimension of the input features.
133 |             N (int): Number of layers. Defaults to 2.
134 |             h (int): Number of heads. Defaults to 8.
135 |             dk (int, optional): Dimension of the key vectors. Defaults to d // h.
136 |             dv (int, optional): Dimension of the value vectors. Defaults to d // h.
137 |             dff (int, optional): Dimension of the feed-forward network. Defaults to 2 * d.
138 |             dropout (float, optional): Dropout rate. Defaults to 0.
139 |         """
140 |         super(Transformer, self).__init__()
141 |         self.N, self.h, self.dk, self.dv, self.dff, self.dropout = N, h, dk, dv, dff, dropout
142 |         eps = torch.finfo(torch.float32).eps
143 |         self.epsilon = eps * eps
144 |         if self.dk is None:
145 |             self.dk = d // self.h
146 |         if self.dv is None:
147 |             self.dv = d // self.h
148 |         if self.dff is None:
149 |             self.dff = 2 * d
150 |         self.Wq = nn.Parameter(torch.empty((self.N, self.h, d, self.dk)))
151 |         self.Wk = nn.Parameter(torch.empty((self.N, self.h, d, self.dk)))
152 |         self.Wv = nn.Parameter(torch.empty((self.N, self.h, d, self.dv)))
153 |         self.Wo = nn.Parameter(torch.empty((self.N, self.dv * self.h, d)))
154 |         self.W1 = nn.Parameter(torch.empty((self.N, d, self.dff)))
155 |         self.b1 = nn.Parameter(torch.zeros((self.N, self.dff)))
156 |         self.W2 = nn.Parameter(torch.empty((self.N, self.dff, d)))
157 |         self.b2 = nn.Parameter(torch.zeros((self.N, d)))
158 |         self.gamma = nn.Parameter(torch.zeros((2 * self.N,)))
159 |         self.beta = nn.Parameter(torch.zeros((2 * self.N,)))
160 |         self.dropoutA = nn.Dropout(p=self.dropout)
161 |         self.dropoutproj = nn.Dropout(p=self.dropout)
162 |         self.dropoutffn = nn.Dropout(p=self.dropout)
163 |         self.reset_parameters()
164 | 
165 |     def forward(self, x, mask=None, mask_value=-1e30):  # was -1e-30, which is ~0 and would not mask anything
166 |         """
167 |         Defines the forward pass of the Transformer module.
168 | 
169 |         Parameters:
170 |             x (Tensor): Input tensor of shape (batch_size, sequence_length, feature_dim).
171 |             mask (Tensor, optional): Mask tensor to apply to the attention weights. Defaults to None.
172 |             mask_value (float, optional): Value to use for masked positions. Defaults to -1e30.
173 | 
174 |         Returns:
175 |             Tensor: Output tensor after applying the Transformer layers.
176 | """ 177 | if mask: 178 | mask = mask.unsqueeze(-2) 179 | else: 180 | mask = torch.ones([x.shape[0], x.shape[1], 1], device=x.device) 181 | for i in range(self.N): 182 | # Multi-Head Attention (MHA) 183 | mha_ops = [] 184 | for j in range(self.h): 185 | q = torch.matmul(x, self.Wq[i, j, :, :]) 186 | k = torch.matmul(x, self.Wk[i, j, :, :]).permute((0, 2, 1)) 187 | v = torch.matmul(x, self.Wv[i, j, :, :]) 188 | A = torch.matmul(q, k) 189 | # Mask unobserved steps. 190 | A = mask * A + (1 - mask) * mask_value 191 | # Mask for attention dropout. 192 | A = self.dropoutA(A) 193 | A = nn.Softmax(dim=-1)(A) 194 | mha_ops.append(torch.matmul(A, v)) 195 | conc = torch.cat(mha_ops, dim=-1) 196 | proj = torch.matmul(conc, self.Wo[i, :, :]) 197 | # Dropout. 198 | proj = self.dropoutproj(proj) 199 | # Add & Layer Normalization (LN) 200 | x = x + proj 201 | mean = torch.mean(x, dim=-1, keepdim=True) 202 | variance = torch.mean((x - mean) ** 2, dim=-1, keepdim=True) 203 | std = torch.sqrt(variance + self.epsilon) 204 | x = (x - mean) / std 205 | x = x * self.gamma[2 * i] + self.beta[2 * i] 206 | # Feed-Forward Network (FFN) 207 | ffn_op = torch.matmul(nn.ReLU()(torch.matmul(x, self.W1[i, :, :]) + self.b1[i, :]), self.W2[i, :, :]) + self.b2[i, :, ] 208 | # Dropout. 209 | ffn_op = self.dropoutffn(ffn_op) 210 | # Add & Layer Normalization (LN) 211 | x = x + ffn_op 212 | mean = torch.mean(x, dim=-1, keepdim=True) 213 | variance = torch.mean((x - mean) ** 2, dim=-1, keepdim=True) 214 | std = torch.sqrt(variance + self.epsilon) 215 | x = (x - mean) / std 216 | x = x * self.gamma[2 * i + 1] + self.beta[2 * i + 1] 217 | return x 218 | 219 | def compute_output_shape(self, input_shape): 220 | """ 221 | Computes the output shape of the Transformer module given the input shape. 222 | 223 | Parameters: 224 | input_shape (tuple): Shape of the input tensor. 225 | 226 | Returns: 227 | tuple: Shape of the output tensor. 228 | """ 229 | return input_shape 230 | 231 | def reset_parameters(self): 232 | """ 233 | Initializes the parameters of the linear layers using Xavier uniform distribution. 234 | """ 235 | nn.init.xavier_uniform_(self.Wq) 236 | nn.init.xavier_uniform_(self.Wk) 237 | nn.init.xavier_uniform_(self.Wv) 238 | nn.init.xavier_uniform_(self.Wo) 239 | nn.init.xavier_uniform_(self.W1) 240 | torch.nn.init.zeros_(self.b1) 241 | nn.init.xavier_uniform_(self.W2) 242 | torch.nn.init.zeros_(self.b2) 243 | torch.nn.init.ones_(self.gamma) 244 | torch.nn.init.zeros_(self.beta) 245 | 246 | class RNNModel(nn.Module): 247 | """ 248 | RNNModel is a neural network module that implements a recurrent neural network (RNN) using LSTM. 249 | """ 250 | 251 | def __init__(self, dropout, h, d, numfeats, numLayers=1, bidirectional=True): 252 | """ 253 | Initializes the RNNModel module. 254 | 255 | Parameters: 256 | dropout (float): Dropout rate. 257 | h (int): Number of hidden units. 258 | d (int): Dimension of the input features. 259 | numfeats (int): Number of input features. 260 | numLayers (int, optional): Number of LSTM layers. Defaults to 1. 261 | bidirectional (bool, optional): Whether to use bidirectional LSTM. Defaults to True. 262 | """ 263 | super(RNNModel, self).__init__() 264 | self.lstm = nn.LSTM(numfeats, h, numLayers, batch_first=True, bidirectional=bidirectional, dropout=dropout) 265 | self.DO1 = nn.Dropout(p=dropout) 266 | 267 | def forward(self, data): 268 | """ 269 | Defines the forward pass of the RNNModel module. 
270 | 271 | Parameters: 272 | data (Tensor): Input tensor of shape (batch_size, sequence_length, feature_dim). 273 | 274 | Returns: 275 | Tensor: Output tensor after applying the LSTM and dropout layers. 276 | """ 277 | out, _ = self.lstm(data) 278 | out = self.DO1(out[:, -1, :]) 279 | return out 280 | 281 | def load(self, checkpath): 282 | """ 283 | Loads the model parameters from a checkpoint file. 284 | 285 | Parameters: 286 | checkpath (str): Path to the checkpoint file. 287 | """ 288 | self.load_state_dict(torch.load(checkpath)) 289 | 290 | class TransformerModel(nn.Module): 291 | """ 292 | TransformerModel is a neural network module that implements a Transformer-based model. 293 | """ 294 | 295 | def __init__(self, dropout, h, d, numfeats, numLayers=1, NumHeads=10, bidirectional=True): 296 | """ 297 | Initializes the TransformerModel module. 298 | 299 | Parameters: 300 | dropout (float): Dropout rate. 301 | h (int): Number of hidden units. 302 | d (int): Dimension of the input features. 303 | numfeats (int): Number of input features. 304 | numLayers (int, optional): Number of Transformer layers. Defaults to 1. 305 | NumHeads (int, optional): Number of attention heads. Defaults to 10. 306 | bidirectional (bool, optional): Whether to use bidirectional LSTM. Defaults to True. 307 | """ 308 | super(TransformerModel, self).__init__() 309 | self.linearFirst = nn.Linear(numfeats, h) 310 | self.transformer = Transformer(d=h, N=numLayers, h=NumHeads, dk=None, dv=None, dff=None, dropout=dropout) 311 | self.attn = Attention(2 * h, h) 312 | self.DO1 = nn.Dropout(p=dropout) 313 | 314 | def forward(self, data): 315 | """ 316 | Defines the forward pass of the TransformerModel module. 317 | 318 | Parameters: 319 | data (Tensor): Input tensor of shape (batch_size, sequence_length, feature_dim). 320 | 321 | Returns: 322 | Tensor: Output tensor after applying the Transformer, attention, and dropout layers. 323 | """ 324 | out = self.linearFirst(data) 325 | out = self.transformer(out) 326 | attn_weights = self.attn(out) 327 | out = torch.sum(out * attn_weights, dim=-2) 328 | out = self.DO1(out) 329 | return out 330 | 331 | def load(self, checkpath): 332 | """ 333 | Loads the model parameters from a checkpoint file. 334 | 335 | Parameters: 336 | checkpath (str): Path to the checkpoint file. 337 | """ 338 | self.load_state_dict(torch.load(checkpath)) 339 | 340 | 341 | class BlockFCNConv(nn.Module): 342 | """ 343 | BlockFCNConv is a convolutional block used in the Fully Convolutional Network (FCN). 344 | """ 345 | def __init__(self, in_channel=1, out_channel=128, kernel_size=8, momentum=0.99, epsilon=0.001, squeeze=False): 346 | """ 347 | Initializes the BlockFCNConv module. 348 | 349 | Parameters: 350 | in_channel (int): Number of input channels. 351 | out_channel (int): Number of output channels. 352 | kernel_size (int): Size of the convolutional kernel. 353 | momentum (float): Momentum for batch normalization. 354 | epsilon (float): Epsilon for batch normalization. 355 | squeeze (bool): Whether to apply squeeze operation. 356 | """ 357 | super().__init__() 358 | self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=kernel_size, padding=kernel_size//2) 359 | self.batch_norm = nn.BatchNorm1d(num_features=out_channel, eps=epsilon, momentum=momentum) 360 | self.relu = nn.ReLU() 361 | 362 | def forward(self, x): 363 | """ 364 | Defines the forward pass of the BlockFCNConv module. 365 | 366 | Parameters: 367 | x (Tensor): Input tensor of shape (batch_size, num_variables, time_steps). 
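        Example (illustrative; the even kernel with padding=kernel_size//2 slightly changes the length):
            >>> blk = BlockFCNConv(in_channel=3, out_channel=128, kernel_size=8)
            >>> blk(torch.randn(4, 3, 100)).shape  # 100 + 2*4 - 8 + 1 = 101 time steps
            torch.Size([4, 128, 101])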
368 | 369 | Returns: 370 | Tensor: Output tensor after applying convolution, batch normalization, and ReLU activation. 371 | """ 372 | x = self.conv(x) 373 | x = self.batch_norm(x) 374 | y = self.relu(x) 375 | return y 376 | 377 | class FCN(nn.Module): 378 | """ 379 | FCN is a Fully Convolutional Network for time series data. 380 | """ 381 | def __init__(self, dropout, h, d, numfeats, kernels=[8, 5, 3], kernelsizemult=1.0, mom=0.99, eps=0.001): 382 | """ 383 | Initializes the FCN module. 384 | 385 | Parameters: 386 | dropout (float): Dropout rate. 387 | h (int): Number of hidden units. 388 | d (int): Dimension of the input features. 389 | numfeats (int): Number of input features. 390 | kernels (list): List of kernel sizes for convolutional layers. 391 | kernelsizemult (float): Multiplier for kernel sizes. 392 | mom (float): Momentum for batch normalization. 393 | eps (float): Epsilon for batch normalization. 394 | """ 395 | super().__init__() 396 | kernels = np.array(kernels) * kernelsizemult 397 | kernels = kernels.astype(int) 398 | channels = [h, 2*h, h] 399 | self.conv1 = BlockFCNConv(numfeats, channels[0], kernels[0], momentum=mom, epsilon=eps, squeeze=True) 400 | self.conv2 = BlockFCNConv(channels[0], channels[1], kernels[1], momentum=mom, epsilon=eps, squeeze=True) 401 | self.conv3 = BlockFCNConv(channels[1], channels[2], kernels[2], momentum=mom, epsilon=eps) 402 | self.DO1 = nn.Dropout(p=dropout) 403 | 404 | def forward(self, x): 405 | """ 406 | Defines the forward pass of the FCN module. 407 | 408 | Parameters: 409 | x (Tensor): Input tensor of shape (batch_size, sequence_length, feature_dim). 410 | 411 | Returns: 412 | Tensor: Output tensor after applying convolutional layers and dropout. 413 | """ 414 | x = x.transpose(1, 2) 415 | x = self.conv1(x) 416 | x = self.conv2(x) 417 | x = self.conv3(x) 418 | x = self.DO1(x) 419 | y = x.mean(axis=-1) 420 | return y 421 | 422 | class FCN_perchannel(nn.Module): 423 | """ 424 | FCN_perchannel is a Fully Convolutional Network that processes each channel separately. 425 | """ 426 | def __init__(self, dropout, h, d, numfeats, kernels=[8, 5, 3], kernelsizemult=1.0, mom=0.99, eps=0.001, input_size_perchannel=1): 427 | """ 428 | Initializes the FCN_perchannel module. 429 | 430 | Parameters: 431 | dropout (float): Dropout rate. 432 | h (int): Number of hidden units. 433 | d (int): Dimension of the input features. 434 | numfeats (int): Number of input features. 435 | kernels (list): List of kernel sizes for convolutional layers. 436 | kernelsizemult (float): Multiplier for kernel sizes. 437 | mom (float): Momentum for batch normalization. 438 | eps (float): Epsilon for batch normalization. 439 | input_size_perchannel (int): Size of the input for each channel. 
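        Example (hypothetical sizes; the input is one (batch, time) tensor per channel):
            >>> m = FCN_perchannel(dropout=0.1, h=16, d=32, numfeats=3)
            >>> x = [torch.randn(8, 100) for _ in range(3)]
            >>> m(x).shape
            torch.Size([8, 16])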
440 | """ 441 | super().__init__() 442 | kernels = np.array(kernels) * kernelsizemult 443 | kernels = kernels.astype(int) 444 | channels = [h, 2*h, h] 445 | self.FCNs = [] 446 | self.finalAct = nn.ReLU() 447 | for i in range(numfeats): 448 | fcn_convs = nn.Sequential( 449 | BlockFCNConv(input_size_perchannel, channels[0], kernels[0], momentum=mom, epsilon=eps, squeeze=True), 450 | BlockFCNConv(channels[0], channels[1], kernels[1], momentum=mom, epsilon=eps, squeeze=True), 451 | BlockFCNConv(channels[1], channels[2], kernels[2], momentum=mom, epsilon=eps), 452 | nn.Dropout(p=dropout)) 453 | self.FCNs.append(fcn_convs) 454 | self.FCNs = nn.ModuleList(self.FCNs) 455 | self.LastLinear = nn.Linear(numfeats * h, h) 456 | print('numfeats', numfeats) 457 | 458 | def forward(self, x): 459 | """ 460 | Defines the forward pass of the FCN_perchannel module. 461 | 462 | Parameters: 463 | x (Tensor): Input tensor of shape (batch_size, sequence_length, feature_dim). 464 | 465 | Returns: 466 | Tensor: Output tensor after applying convolutional layers and dropout. 467 | """ 468 | Outs = [] 469 | for i in range(len(x)): 470 | if len(x[i].shape) == 2: 471 | Out = self.FCNs[i](x[i][:, None, :]) 472 | else: 473 | Out = self.FCNs[i](x[i]) 474 | Out = Out.mean(axis=-1) 475 | Outs.append(Out) 476 | y = torch.cat(Outs, 1) 477 | y = self.LastLinear(y) 478 | return y 479 | 480 | class Transformer_perchannel(nn.Module): 481 | """ 482 | Transformer_perchannel is a Transformer-based model that processes each channel separately. 483 | """ 484 | def __init__(self, dropout, h, d, numfeats, numLayers=1, NumHeads=10, bidirectional=True): 485 | """ 486 | Initializes the Transformer_perchannel module. 487 | 488 | Parameters: 489 | dropout (float): Dropout rate. 490 | h (int): Number of hidden units. 491 | d (int): Dimension of the input features. 492 | numfeats (int): Number of input features. 493 | numLayers (int, optional): Number of Transformer layers. Defaults to 1. 494 | NumHeads (int, optional): Number of attention heads. Defaults to 10. 495 | bidirectional (bool, optional): Whether to use bidirectional LSTM. Defaults to True. 496 | """ 497 | super().__init__() 498 | self.Transformers = [] 499 | self.finalAct = nn.ReLU() 500 | for i in range(numfeats): 501 | transformermodel = TransformerModel(dropout=dropout, h=h, d=d, numfeats=1, numLayers=numLayers, NumHeads=NumHeads, bidirectional=bidirectional) 502 | self.Transformers.append(transformermodel) 503 | self.Transformers = nn.ModuleList(self.Transformers) 504 | self.LastLinear = nn.Linear(numfeats * h, h) 505 | 506 | def forward(self, x): 507 | """ 508 | Defines the forward pass of the Transformer_perchannel module. 509 | 510 | Parameters: 511 | x (Tensor): Input tensor of shape (batch_size, sequence_length, feature_dim). 512 | 513 | Returns: 514 | Tensor: Output tensor after applying Transformer layers and dropout. 515 | """ 516 | Outs = [] 517 | for i in range(len(x)): 518 | Out = self.Transformers[i](x[i][:, :, None]) 519 | Outs.append(Out) 520 | y = torch.cat(Outs, 1) 521 | y = self.LastLinear(y) 522 | return y 523 | 524 | class CNNAttnModel(nn.Module): 525 | """ 526 | CNNAttnModel is a neural network module that combines CNN and attention mechanisms. 527 | """ 528 | def __init__(self, dropout, h, d, numfeats, kernelsize=3, numLayers=1, bidirectional=True): 529 | """ 530 | Initializes the CNNAttnModel module. 531 | 532 | Parameters: 533 | dropout (float): Dropout rate. 534 | h (int): Number of hidden units. 535 | d (int): Dimension of the input features. 
536 |             numfeats (int): Number of input features.
537 |             kernelsize (int, optional): Size of the convolutional kernel. Defaults to 3.
538 |             numLayers (int, optional): Number of CNN layers. Defaults to 1.
539 |             bidirectional (bool, optional): Whether to use bidirectional LSTM. Defaults to True.
540 |         """
541 |         super(CNNAttnModel, self).__init__()
542 |         self.cnn_layers = nn.Sequential(
543 |             nn.Conv1d(numfeats, h, kernel_size=kernelsize, padding=kernelsize//2),
544 |             nn.BatchNorm1d(h),
545 |             nn.ReLU(inplace=True),
546 |             nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
547 |         )
548 |         for i in range(1, numLayers):
549 |             self.cnn_layers.add_module("conv_" + str(i), nn.Conv1d(h, h, kernel_size=kernelsize, padding=kernelsize//2))
550 |             self.cnn_layers.add_module("batchnorm_" + str(i), nn.BatchNorm1d(h))
551 |             self.cnn_layers.add_module("relu_" + str(i), nn.ReLU(inplace=True))
552 |             self.cnn_layers.add_module("maxpool_" + str(i), nn.MaxPool1d(kernel_size=2, stride=2, padding=1))
553 |         self.DO1 = nn.Dropout(p=dropout)
554 |         self.attn = Attention(2*h, h)
555 | 
556 |     def forward(self, data):
557 |         """
558 |         Defines the forward pass of the CNNAttnModel module.
559 | 
560 |         Parameters:
561 |             data (Tensor): Input tensor of shape (batch_size, sequence_length, feature_dim).
562 | 
563 |         Returns:
564 |             Tensor: Output tensor after applying CNN and attention layers.
565 |         """
566 |         out = self.cnn_layers(data.transpose(1, 2))
567 |         attn_weights = self.attn(out.transpose(1, 2))
568 |         out = torch.sum(out.transpose(1, 2) * attn_weights, dim=-2)
569 |         out = self.DO1(out)  # Apply dropout to the attention-pooled representation
570 |         return out
571 | 
572 |     def load(self, checkpath):
573 |         """
574 |         Loads the model parameters from a checkpoint file.
575 | 
576 |         Parameters:
577 |             checkpath (str): Path to the checkpoint file.
578 |         """
579 |         self.load_state_dict(torch.load(checkpath))
580 | 
581 | class CNNAttn_perchannel(nn.Module):
582 |     """
583 |     CNNAttn_perchannel is a CNN and attention-based model that processes each channel separately.
584 |     """
585 |     def __init__(self, dropout, h, d, numfeats, kernelsize=3, numLayers=1, bidirectional=True):
586 |         """
587 |         Initializes the CNNAttn_perchannel module.
588 | 
589 |         Parameters:
590 |             dropout (float): Dropout rate.
591 |             h (int): Number of hidden units.
592 |             d (int): Dimension of the input features.
593 |             numfeats (int): Number of input features.
594 |             kernelsize (int, optional): Size of the convolutional kernel. Defaults to 3.
595 |             numLayers (int, optional): Number of CNN layers. Defaults to 1.
596 |             bidirectional (bool, optional): Whether to use bidirectional LSTM. Defaults to True.
597 |         """
598 |         super(CNNAttn_perchannel, self).__init__()
599 |         self.CNNmodels = []
600 |         self.finalAct = nn.ReLU()
601 |         for i in range(numfeats):
602 |             cnnmodel = CNNAttnModel(dropout=dropout, h=h, d=d, numfeats=1, kernelsize=kernelsize, numLayers=numLayers, bidirectional=bidirectional)
603 |             self.CNNmodels.append(cnnmodel)
604 |         self.CNNmodels = nn.ModuleList(self.CNNmodels)
605 |         self.LastLinear = nn.Linear(numfeats * h, h)
606 | 
607 |     def forward(self, x):
608 |         """
609 |         Defines the forward pass of the CNNAttn_perchannel module.
610 | 
611 |         Parameters:
612 |             x (list): List of per-channel input tensors, each of shape (batch_size, sequence_length).
613 | 
614 |         Returns:
615 |             Tensor: Output tensor after applying CNN and attention layers.
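        Example (illustrative sizes):
            >>> m = CNNAttn_perchannel(dropout=0.0, h=16, d=32, numfeats=2)
            >>> x = [torch.randn(8, 100) for _ in range(2)]
            >>> m(x).shape
            torch.Size([8, 16])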
616 | """ 617 | Outs = [] 618 | for i in range(len(x)): 619 | Out = self.CNNmodels[i](x[i][:, :, None]) 620 | Outs.append(Out) 621 | y = torch.cat(Outs, 1) 622 | y = self.LastLinear(y) 623 | return y 624 | 625 | class CNNLSTMModel(nn.Module): 626 | """ 627 | CNNLSTMModel is a neural network module that combines CNN and LSTM layers. 628 | """ 629 | def __init__(self, dropout, h, d, numfeats, kernelsize=3, numLayers=1, bidirectional=True): 630 | """ 631 | Initializes the CNNLSTMModel module. 632 | 633 | Parameters: 634 | dropout (float): Dropout rate. 635 | h (int): Number of hidden units. 636 | d (int): Dimension of the input features. 637 | numfeats (int): Number of input features. 638 | kernelsize (int, optional): Size of the convolutional kernel. Defaults to 3. 639 | numLayers (int, optional): Number of CNN layers. Defaults to 1. 640 | bidirectional (bool, optional): Whether to use bidirectional LSTM. Defaults to True. 641 | """ 642 | super(CNNLSTMModel, self).__init__() 643 | self.cnn_layers = nn.Sequential( 644 | nn.Conv1d(numfeats, h, kernel_size=kernelsize, padding=kernelsize//2), 645 | nn.BatchNorm1d(h), 646 | nn.ReLU(inplace=True), 647 | nn.MaxPool1d(kernel_size=2, stride=2, padding=1), 648 | ) 649 | for i in range(1, numLayers): 650 | self.cnn_layers.add_module("conv_" + str(i), nn.Conv1d(h, h, kernel_size=kernelsize, padding=kernelsize//2)) 651 | self.cnn_layers.add_module("batchnorm_" + str(i), nn.BatchNorm1d(h)) 652 | self.cnn_layers.add_module("relu_" + str(i), nn.ReLU(inplace=True)) 653 | self.cnn_layers.add_module("maxpool_" + str(i), nn.MaxPool1d(kernel_size=2, stride=2, padding=1)) 654 | self.DO1 = nn.Dropout(p=dropout) 655 | self.lstm = nn.LSTM(h, h, 1, batch_first=True, bidirectional=bidirectional, dropout=dropout) 656 | 657 | def forward(self, data): 658 | """ 659 | Defines the forward pass of the CNNLSTMModel module. 660 | 661 | Parameters: 662 | data (Tensor): Input tensor of shape (batch_size, sequence_length, feature_dim). 663 | 664 | Returns: 665 | Tensor: Output tensor after applying CNN and LSTM layers. 666 | """ 667 | out = self.cnn_layers(data.transpose(1, 2)) 668 | out = self.DO1(out.transpose(1, 2)) 669 | out, _ = self.lstm(out) 670 | out = out[:, -1, :] 671 | return out 672 | 673 | def load(self, checkpath): 674 | """ 675 | Loads the model parameters from a checkpoint file. 676 | 677 | Parameters: 678 | checkpath (str): Path to the checkpoint file. 679 | """ 680 | self.load_state_dict(torch.load(checkpath)) 681 | 682 | 683 | class TA_LSTMCell(nn.Module): 684 | """ 685 | TA_LSTMCell is a custom LSTM cell that incorporates time-aware mechanisms. 686 | """ 687 | def __init__(self, input_size, hidden_size): 688 | """ 689 | Initializes the TA_LSTMCell module. 690 | 691 | Parameters: 692 | input_size (int): Dimension of the input features. 693 | hidden_size (int): Number of hidden units. 694 | """ 695 | super(TA_LSTMCell, self).__init__() 696 | self.input_size = input_size 697 | self.hidden_size = hidden_size 698 | self.weight_ih = nn.Parameter(torch.randn(4 * hidden_size, input_size)) 699 | self.weight_hh = nn.Parameter(torch.randn(4 * hidden_size, hidden_size)) 700 | self.bias_ih = nn.Parameter(torch.randn(4 * hidden_size)) 701 | self.bias_hh = nn.Parameter(torch.randn(4 * hidden_size)) 702 | self.W_decomp = nn.Parameter(torch.randn(hidden_size, hidden_size)) 703 | self.b_decomp = nn.Parameter(torch.randn(hidden_size)) 704 | 705 | def g(self, t): 706 | """ 707 | Computes the time-aware gating mechanism. 708 | 709 | Parameters: 710 | t (Tensor): Time intervals. 
711 | 712 | Returns: 713 | Tensor: Time-aware gating values. 714 | """ 715 | T = torch.zeros_like(t).to(t.device) 716 | T[t.nonzero(as_tuple=True)] = 1 / t[t.nonzero(as_tuple=True)] 717 | 718 | Ones = torch.ones([1, self.hidden_size], dtype=torch.float32).to(t.device) 719 | T = torch.mm(T, Ones) 720 | return T 721 | 722 | def forward(self, input, t, state): 723 | """ 724 | Defines the forward pass of the TA_LSTMCell module. 725 | 726 | Parameters: 727 | input (Tensor): Input tensor. 728 | t (Tensor): Time intervals. 729 | state (Tuple[Tensor, Tensor]): Previous hidden and cell states. 730 | 731 | Returns: 732 | Tuple[Tensor, Tensor]: Updated hidden and cell states. 733 | """ 734 | hx, cx = state 735 | gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih + 736 | torch.mm(hx, self.weight_hh.t()) + self.bias_hh) 737 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 738 | 739 | T = self.g(t) 740 | 741 | C_ST = torch.tanh(torch.mm(cx, self.W_decomp) + self.b_decomp) 742 | C_ST_dis = T * C_ST 743 | cx = cx - C_ST + C_ST_dis 744 | 745 | ingate = torch.sigmoid(ingate) 746 | forgetgate = torch.sigmoid(forgetgate) 747 | cellgate = torch.tanh(cellgate) 748 | outgate = torch.sigmoid(outgate) 749 | 750 | cy = (forgetgate * cx) + (ingate * cellgate) 751 | hy = outgate * torch.tanh(cy) 752 | 753 | return (hy, cy) 754 | 755 | class TA_LSTM(nn.Module): 756 | """ 757 | TA_LSTM is a custom LSTM module that incorporates time-aware mechanisms. 758 | """ 759 | def __init__(self, input_size, hidden_size): 760 | """ 761 | Initializes the TA_LSTM module. 762 | 763 | Parameters: 764 | input_size (int): Dimension of the input features. 765 | hidden_size (int): Number of hidden units. 766 | """ 767 | super(TA_LSTM, self).__init__() 768 | self.TA_lstm = TA_LSTMCell(input_size, hidden_size) 769 | self.hidden_size = hidden_size 770 | 771 | def forward(self, X, time): 772 | """ 773 | Defines the forward pass of the TA_LSTM module. 774 | 775 | Parameters: 776 | X (Tensor): Input tensor of shape (batch_size, sequence_length, feature_dim). 777 | time (Tensor): Time intervals. 778 | 779 | Returns: 780 | Tensor: Output tensor after applying the TA_LSTM layers. 781 | """ 782 | time = time[None, :, None].repeat(X.shape[0], 1, 1).float().to(X.device) 783 | c = torch.zeros([X.shape[0], self.hidden_size]).to(X.device) 784 | h = torch.zeros([X.shape[0], self.hidden_size]).to(X.device) 785 | state = (h, c) 786 | AllStates = [] 787 | time_steps = X.shape[1] 788 | for i in range(time_steps): 789 | state = self.TA_lstm(X[:, i, :], time[:, i, :], state) 790 | AllStates.append(state[0]) 791 | return state[0] 792 | 793 | def get_Model(Comp, dropout, multiplier, h, d, NumFeats, numLayers, bidirectional, config, input_size_perchannel=1): 794 | """ 795 | Returns the appropriate model based on the specified component type. 796 | 797 | Parameters: 798 | Comp (str): Component type. 799 | dropout (float): Dropout rate. 800 | multiplier (int): Multiplier for hidden units. 801 | h (int): Number of hidden units. 802 | d (int): Dimension of the input features. 803 | NumFeats (int): Number of input features. 804 | numLayers (int): Number of layers. 805 | bidirectional (bool): Whether to use bidirectional LSTM. 806 | config (dict): Configuration dictionary. 807 | input_size_perchannel (int, optional): Size of the input for each channel. Defaults to 1. 808 | 809 | Returns: 810 | nn.Module: The appropriate model. 
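        Example (hypothetical call; only the config keys read by the chosen branch are needed):
            >>> cfg = {'NumLayers': 1, 'NumHeads': 4, 'CNNKernelSize': 7, 'FCNKernelMult': 1.0}
            >>> m = get_Model('BiLSTM', dropout=0.1, multiplier=2, h=32, d=64,
            ...               NumFeats=8, numLayers=1, bidirectional=True, config=cfg)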
811 | """ 812 | if (Comp == 'LSTM') or (Comp == 'BiLSTM'): 813 | model = RNNModel(dropout, h, d, NumFeats, numLayers=config['NumLayers'], bidirectional=bidirectional) 814 | elif Comp == 'Transformer': 815 | model = TransformerModel(dropout, multiplier*h, d, NumFeats, numLayers=config['NumLayers'], NumHeads=config['NumHeads'], bidirectional=bidirectional) 816 | elif Comp == 'CNNAttn': 817 | model = CNNAttnModel(dropout, multiplier*h, d, NumFeats, kernelsize=config['CNNKernelSize'], numLayers=config['NumLayers'], bidirectional=bidirectional) 818 | elif Comp == 'CNNLSTM': 819 | model = CNNLSTMModel(dropout, h, d, NumFeats, kernelsize=config['CNNKernelSize'], numLayers=config['NumLayers'], bidirectional=bidirectional) 820 | elif Comp == 'FCN': 821 | model = FCN(dropout, h, d, NumFeats, kernelsizemult=config['FCNKernelMult']) 822 | elif Comp == 'FCN_perchannel': 823 | model = FCN_perchannel(dropout, h, d, NumFeats, kernelsizemult=config['FCNKernelMult'], input_size_perchannel=input_size_perchannel) 824 | elif Comp == 'Transformer_perchannel': 825 | model = Transformer_perchannel(dropout, multiplier*h, d, NumFeats, numLayers=config['NumLayers'], NumHeads=config['NumHeads'], bidirectional=bidirectional) 826 | elif Comp == 'CNNAttn_perchannel': 827 | model = CNNAttn_perchannel(dropout, multiplier*h, d, NumFeats, kernelsize=config['CNNKernelSize'], numLayers=config['NumLayers'], bidirectional=bidirectional) 828 | elif Comp == 'TLSTM': 829 | model = TA_LSTM(NumFeats, h) 830 | else: 831 | raise ValueError('Comp value provided not found: ' + Comp) 832 | return model 833 | 834 | class Modelfreq(nn.Module): 835 | """ 836 | Modelfreq is a neural network module that combines multiple frequency models. 837 | """ 838 | def __init__(self, dropout, hs, d, numfreqs, NumFeats, Fusion, Comp='LSTM', classification=True, bidirectional=True, UseExtraLinear=False, config=None, regularized=True): 839 | """ 840 | Initializes the Modelfreq module. 841 | 842 | Parameters: 843 | dropout (float): Dropout rate. 844 | hs (list): List of hidden units for each frequency model. 845 | d (int): Dimension of the input features. 846 | numfreqs (int): Number of frequency models. 847 | NumFeats (list): List of input features for each frequency model. 848 | Fusion (nn.Module): Fusion module to combine frequency models. 849 | Comp (str, optional): Component type. Defaults to 'LSTM'. 850 | classification (bool, optional): Whether the task is classification. Defaults to True. 851 | bidirectional (bool, optional): Whether to use bidirectional LSTM. Defaults to True. 852 | UseExtraLinear (bool, optional): Whether to use extra linear layer in fusion. Defaults to False. 853 | config (dict, optional): Configuration dictionary. Defaults to None. 854 | regularized (bool, optional): Whether to apply regularization. Defaults to True. 
855 | """ 856 | super(Modelfreq, self).__init__() 857 | self.numfreqs = numfreqs 858 | j = 0 859 | self.modelindx = {} 860 | Ms = [] 861 | for i in range(numfreqs): 862 | multiplier = 2 if bidirectional else 1 863 | self.modelindx[i] = j 864 | if hs[i] != 0: 865 | Ms.append(get_Model(Comp, dropout, multiplier, hs[i], d, NumFeats[i], config['NumLayers'], bidirectional, config)) 866 | # if (Comp == 'LSTM') or (Comp == 'BiLSTM'): 867 | # Ms.append(RNNModel(dropout, hs[i], d, NumFeats[i], numLayers=config['NumLayers'], bidirectional=bidirectional)) 868 | # elif Comp == 'Transformer': 869 | # Ms.append(TransformerModel(dropout, multiplier*hs[i], d, NumFeats[i], numLayers=config['NumLayers'], NumHeads=config['NumHeads'], bidirectional=bidirectional)) 870 | # elif Comp == 'CNNAttn': 871 | # Ms.append(CNNAttnModel(dropout, multiplier*hs[i], d, NumFeats[i], kernelsize=config['CNNKernelSize'], numLayers=config['NumLayers'], bidirectional=bidirectional)) 872 | # elif Comp == 'CNNLSTM': 873 | # Ms.append(CNNLSTMModel(dropout, hs[i], d, NumFeats[i], kernelsize=config['CNNKernelSize'], numLayers=config['NumLayers'], bidirectional=bidirectional)) 874 | # elif Comp == 'FCN': 875 | # Ms.append(FCN(dropout, hs[i], d, NumFeats[i], kernelsizemult=config['FCNKernelMult'])) 876 | # elif Comp == 'FCN_perchannel': 877 | # Ms.append(FCN_perchannel(dropout, hs[i], d, NumFeats[i], kernelsizemult=config['FCNKernelMult'])) 878 | # elif Comp == 'Transformer_perchannel': 879 | # Ms.append(Transformer_perchannel(dropout, multiplier*hs[i], d, NumFeats[i], numLayers=config['NumLayers'], NumHeads=config['NumHeads'], bidirectional=bidirectional)) 880 | # elif Comp == 'CNNAttn_perchannel': 881 | # Ms.append(CNNAttn_perchannel(dropout, multiplier*hs[i], d, NumFeats[i], kernelsize=config['CNNKernelSize'], numLayers=config['NumLayers'], bidirectional=bidirectional)) 882 | # elif Comp == 'TLSTM': 883 | # Ms.append(TA_LSTM(NumFeats[i], hs[i])) 884 | # else: 885 | # raise ValueError('Comp value provided not found: ' + Comp) 886 | j += 1 887 | print(self.modelindx, j) 888 | self.freqmodels = nn.ModuleList(Ms) 889 | self.hs = hs 890 | NumClasses = 1 891 | if 'NumClasses' in config: 892 | NumClasses = config['NumClasses'] 893 | self.fusion = Fusion(hs, d, NumClasses, bidirectional=bidirectional, useExtralin=UseExtraLinear) 894 | self.finalAct = None 895 | if NumClasses == 1 and classification: 896 | print('Using Sigmoid') 897 | self.finalAct = nn.Sigmoid() 898 | 899 | self.FeatIdxs = None 900 | self.regularized = regularized 901 | if 'times' in config: 902 | self.times = config['times'] 903 | else: 904 | self.times = None 905 | 906 | def forward(self, data): 907 | """ 908 | Defines the forward pass of the Modelfreq module. 909 | 910 | Parameters: 911 | data (list): List of input tensors for each frequency model. 912 | 913 | Returns: 914 | Tuple[Tensor, list]: Output tensor and list of component outputs. 
915 | """ 916 | out = [] 917 | compouts = [] 918 | for i in range(self.numfreqs): 919 | if self.hs[i] != 0: 920 | if self.FeatIdxs is not None: 921 | if self.regularized: 922 | data[i] = data[i][:, :, self.FeatIdxs[i]] 923 | else: 924 | data[i] = np.array(data[i])[self.FeatIdxs[i].numpy()] 925 | if self.times: 926 | o = self.freqmodels[self.modelindx[i]](data[i], self.times[i]) 927 | # print('o shape', o.shape) 928 | else: 929 | o = self.freqmodels[self.modelindx[i]](data[i]) 930 | out.append(o) 931 | op, compouts = self.fusion(out) 932 | if self.finalAct: 933 | op = self.finalAct(op) 934 | for i in range(len(compouts)): 935 | compouts[i] = self.finalAct(compouts[i]) 936 | # op = self.LastLinear(op) 937 | # op = op.squeeze() 938 | self.wandblog() 939 | return op, compouts 940 | 941 | def wandblog(self): 942 | """ 943 | Logs mask weights to Weights & Biases. 944 | """ 945 | WeightDict = {} 946 | if self.FeatIdxs is not None: 947 | for i, Ms in enumerate(self.FeatIdxs): 948 | for j, m in enumerate(Ms): 949 | WeightDict['FeatIndxs_' + str(i) + '_' + str(j)] = m.int() 950 | wandb.log(WeightDict, commit=False) 951 | 952 | class model_FFT(nn.Module): 953 | """ 954 | model_FFT is a neural network module that combines FFT and base models. 955 | """ 956 | def __init__(self, dropout, hs, d, numfreqs, NumFeats, Fusion, Comp='LSTM', classification=True, bidirectional=True, UseExtraLinear=False, config=None, regularized=True): 957 | """ 958 | Initializes the model_FFT module. 959 | 960 | Parameters: 961 | dropout (float): Dropout rate. 962 | hs (list): List of hidden units for each frequency model. 963 | d (int): Dimension of the input features. 964 | numfreqs (int): Number of frequency models. 965 | NumFeats (list): List of input features for each frequency model. 966 | Fusion (nn.Module): Fusion module to combine frequency models. 967 | Comp (str, optional): Component type. Defaults to 'LSTM'. 968 | classification (bool, optional): Whether the task is classification. Defaults to True. 969 | bidirectional (bool, optional): Whether to use bidirectional LSTM. Defaults to True. 970 | UseExtraLinear (bool, optional): Whether to use extra linear layer in fusion. Defaults to False. 971 | config (dict, optional): Configuration dictionary. Defaults to None. 972 | regularized (bool, optional): Whether to apply regularization. Defaults to True. 973 | """ 974 | super(model_FFT, self).__init__() 975 | self.regularized = regularized 976 | multiplier = 2 if bidirectional else 1 977 | NumClasses = 1 978 | if 'NumClasses' in config: 979 | NumClasses = config['NumClasses'] 980 | if hs[-1] != 0: 981 | self.basemodel = get_Model(Comp, dropout, multiplier, hs[-1], d, NumFeats[-1], config['NumLayers'], bidirectional, config) 982 | numF = NumFeats[-1] 983 | if regularized: 984 | numF = NumFeats[-1] * 2 985 | self.fftmodel = get_Model(Comp, dropout, multiplier, hs[-1], d, numF, config['NumLayers'], bidirectional, config, input_size_perchannel=2) 986 | self.fc1 = nn.Linear(hs[-1] * 2, d) 987 | self.firstAct = nn.ReLU() 988 | self.fc2 = nn.Linear(d, NumClasses) 989 | self.finalAct = None 990 | if NumClasses == 1 and classification: 991 | print('Using Sigmoid') 992 | self.finalAct = nn.Sigmoid() 993 | 994 | def forward(self, data): 995 | """ 996 | Defines the forward pass of the model_FFT module. 997 | 998 | Parameters: 999 | data (list): List of input tensors for each frequency model. 1000 | 1001 | Returns: 1002 | Tuple[Tensor, list]: Output tensor and list of component outputs. 
1003 | """ 1004 | x = data[-1] 1005 | if self.regularized: 1006 | print(f'x.shape {x.shape}') 1007 | fft_out = fft(x, dim=-2) 1008 | print(f'fft_out.shape {fft_out.shape}') 1009 | # fft_out is of shape (batch_size, sequence_length, input_dim) 1010 | real_part = torch.real(fft_out) 1011 | imag_part = torch.imag(fft_out) 1012 | # real_part and imag_part are of shape (batch_size, sequence_length, input_dim) 1013 | print(f'real_part.shape {real_part.shape}') 1014 | fft_in = torch.cat((real_part, imag_part), dim=-1) 1015 | print(f'fft_in.shape {fft_in.shape}') 1016 | else: 1017 | print(f'x shapes {[z.shape for z in x]}') 1018 | fft_outs = [fft(z, dim=-1) for z in x] 1019 | fft_in = [torch.stack((torch.real(fft_out), torch.imag(fft_out)), dim=1) for fft_out in fft_outs] 1020 | print(f'fft_in shapes {[z.shape for z in fft_in]}') 1021 | op = self.basemodel(x) 1022 | fft_out = self.fftmodel(fft_in) 1023 | print(f'op.shape {op.shape}, fft_out.shape {fft_out.shape}') 1024 | concat_out = torch.cat((op, fft_out), dim=-1) 1025 | print(f'concat_out.shape {concat_out.shape}') 1026 | fc1_out = self.firstAct(self.fc1(concat_out)) 1027 | print(f'fc1_out.shape {fc1_out.shape}') 1028 | fc2_out = self.fc2(fc1_out) 1029 | if self.finalAct: 1030 | fc2_out = self.finalAct(fc2_out) 1031 | print(f'fc2_out.shape {fc2_out.shape}') 1032 | return fc2_out, [] 1033 | 1034 | class Modelfreq_featMasks(nn.Module): 1035 | """ 1036 | Modelfreq_featMasks is a neural network module that combines multiple frequency models with feature masks. 1037 | """ 1038 | def __init__(self, dropout, hs, d, numfreqs, NumFeats, Fusion, Comp='LSTM', classification=True, bidirectional=True, UseExtraLinear=False, config=None, MaskWeightInit=0.5, regularized=True): 1039 | """ 1040 | Initializes the Modelfreq_featMasks module. 1041 | 1042 | Parameters: 1043 | dropout (float): Dropout rate. 1044 | hs (list): List of hidden units for each frequency model. 1045 | d (int): Dimension of the input features. 1046 | numfreqs (int): Number of frequency models. 1047 | NumFeats (list): List of input features for each frequency model. 1048 | Fusion (nn.Module): Fusion module to combine frequency models. 1049 | Comp (str, optional): Component type. Defaults to 'LSTM'. 1050 | classification (bool, optional): Whether the task is classification. Defaults to True. 1051 | bidirectional (bool, optional): Whether to use bidirectional LSTM. Defaults to True. 1052 | UseExtraLinear (bool, optional): Whether to use extra linear layer in fusion. Defaults to False. 1053 | config (dict, optional): Configuration dictionary. Defaults to None. 1054 | MaskWeightInit (float, optional): Initial weight for feature masks. Defaults to 0.5. 1055 | regularized (bool, optional): Whether to apply regularization. Defaults to True. 
1056 | """ 1057 | super(Modelfreq_featMasks, self).__init__() 1058 | self.numfreqs = numfreqs 1059 | j = 0 1060 | self.modelindx = {} 1061 | FeatMaskWeights = [] 1062 | Ms = [] 1063 | multiplier = 2 if bidirectional else 1 1064 | self.FeatIdxs = [] 1065 | for i in range(numfreqs): 1066 | self.modelindx[i] = j 1067 | self.FeatIdxs.append(torch.ones(NumFeats[i]).bool()) 1068 | if hs[i] != 0: 1069 | Ms.append(get_Model(Comp, dropout, multiplier, hs[i], d, NumFeats[i], config['NumLayers'], bidirectional, config)) 1070 | FeatMaskWeights.append(nn.Parameter(torch.tensor([MaskWeightInit for _ in range(NumFeats[i])]))) 1071 | j += 1 1072 | print(self.modelindx, j) 1073 | self.FeatMaskWeights = nn.ParameterList(FeatMaskWeights) 1074 | self.freqmodels = nn.ModuleList(Ms) 1075 | self.hs = hs 1076 | NumClasses = 1 1077 | if 'NumClasses' in config: 1078 | NumClasses = config['NumClasses'] 1079 | print('NumClasses', NumClasses) 1080 | self.fusion = Fusion(hs, d, NumClasses, bidirectional=bidirectional, useExtralin=UseExtraLinear) 1081 | self.activation = nn.ReLU() 1082 | self.finalAct = None 1083 | self.regularized = regularized 1084 | if 'times' in config: 1085 | self.times = config['times'] 1086 | else: 1087 | self.times = None 1088 | if NumClasses == 1 and classification: 1089 | print('Using Sigmoid') 1090 | self.finalAct = nn.Sigmoid() 1091 | 1092 | def forward(self, data): 1093 | """ 1094 | Defines the forward pass of the Modelfreq_featMasks module. 1095 | 1096 | Parameters: 1097 | data (list): List of input tensors for each frequency model. 1098 | 1099 | Returns: 1100 | Tuple[Tensor, list]: Output tensor and list of component outputs. 1101 | """ 1102 | out = [] 1103 | compouts = [] 1104 | FeatMasks = [] 1105 | for i in range(self.numfreqs): 1106 | if self.hs[i] != 0: 1107 | if self.FeatIdxs is not None: 1108 | if self.regularized: 1109 | data[i] = data[i][:, :, self.FeatIdxs[i]] 1110 | else: 1111 | data[i] = np.array(data[i])[self.FeatIdxs[i].numpy()] 1112 | FeatMask = self.activation(self.FeatMaskWeights[self.modelindx[i]]) 1113 | if self.regularized: 1114 | D = data[i] * FeatMask 1115 | else: 1116 | D = [data[i][j] * FeatMask[j] for j in range(len(data[i]))] 1117 | if self.times: 1118 | o = self.freqmodels[self.modelindx[i]](D, self.times[i]) 1119 | else: 1120 | o = self.freqmodels[self.modelindx[i]](D) 1121 | # o = self.freqmodels[self.modelindx[i]](D) 1122 | out.append(o) 1123 | FeatMasks.append(FeatMask) 1124 | op, compouts = self.fusion(out) 1125 | if self.finalAct: 1126 | op = self.finalAct(op) 1127 | for i in range(len(compouts)): 1128 | compouts[i] = self.finalAct(compouts[i]) 1129 | self.FeatMasks = FeatMasks 1130 | self.wandblog(FeatMasks) 1131 | return op, compouts 1132 | 1133 | def l1_norm(self): 1134 | """ 1135 | Computes the L1 norm of the feature masks. 1136 | 1137 | Returns: 1138 | float: L1 norm of the feature masks. 1139 | """ 1140 | norm = 0.0 1141 | for fm in self.FeatMasks: 1142 | norm += torch.norm(fm, 1) 1143 | return norm 1144 | 1145 | def wandblog(self, Masks): 1146 | """ 1147 | Logs feature masks and weights to Weights & Biases. 1148 | 1149 | Parameters: 1150 | Masks (list): List of feature masks. 
1151 | """ 1152 | WeightDict = {} 1153 | for i, Ms in enumerate(Masks): 1154 | for j, m in enumerate(Ms): 1155 | WeightDict['FeatMask_' + str(i) + '_' + str(j)] = m 1156 | WeightDict['FeatWeight_' + str(i) + '_' + str(j)] = self.FeatMaskWeights[i][j] 1157 | WeightDict['ModelNorm_' + str(i)] = torch.norm(Ms, 1) 1158 | WeightDict['FeatIndxs_' + str(i) + '_' + str(j)] = self.FeatIdxs[i][j].int() 1159 | wandb.log(WeightDict, commit=False) 1160 | 1161 | 1162 | def getFreqModel(config): 1163 | """ 1164 | Returns the appropriate frequency model based on the specified configuration. 1165 | 1166 | Parameters: 1167 | config (dict): Configuration dictionary. 1168 | 1169 | Returns: 1170 | nn.Module: The appropriate frequency model. 1171 | 1172 | Raises: 1173 | ValueError: If the model type is not recognized. 1174 | """ 1175 | regularized = True 1176 | if 'regularized' in config: 1177 | regularized = config['regularized'] 1178 | 1179 | if config['model'] == 'Modelfreq': 1180 | modelfreq = Modelfreq( 1181 | config['dropout'], 1182 | hs=config['hs'], 1183 | d=config['d'], 1184 | Comp=config['Comp'], 1185 | numfreqs=config['NumComps'], 1186 | NumFeats=config['NumFeats'], 1187 | classification=config['Classification'], 1188 | Fusion=config['Fusion'], 1189 | UseExtraLinear=config['UseExtraLinear'], 1190 | bidirectional=config['bidirectional'], 1191 | config=config, 1192 | regularized=regularized 1193 | ) 1194 | elif config['model'] == 'Modelfreq_featMasks': 1195 | modelfreq = Modelfreq_featMasks( 1196 | config['dropout'], 1197 | hs=config['hs'], 1198 | d=config['d'], 1199 | Comp=config['Comp'], 1200 | numfreqs=config['NumComps'], 1201 | NumFeats=config['NumFeats'], 1202 | classification=config['Classification'], 1203 | Fusion=config['Fusion'], 1204 | UseExtraLinear=config['UseExtraLinear'], 1205 | bidirectional=config['bidirectional'], 1206 | config=config, 1207 | MaskWeightInit=config['InitMaskW'], 1208 | regularized=regularized 1209 | ) 1210 | elif config['model'] == 'model_FFT': 1211 | modelfreq = model_FFT( 1212 | config['dropout'], 1213 | hs=config['hs'], 1214 | d=config['d'], 1215 | Comp=config['Comp'], 1216 | numfreqs=config['NumComps'], 1217 | NumFeats=config['NumFeats'], 1218 | classification=config['Classification'], 1219 | Fusion=config['Fusion'], 1220 | UseExtraLinear=config['UseExtraLinear'], 1221 | bidirectional=config['bidirectional'], 1222 | config=config, 1223 | regularized=regularized 1224 | ) 1225 | else: 1226 | raise ValueError('Model type not found: ' + config['model']) 1227 | 1228 | return modelfreq -------------------------------------------------------------------------------- /MultiWave.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Information-Fusion-Lab-Umass/MultiWave/ca3003e9d72603e0d25c6a1df3c0632d4df09ffd/MultiWave.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiWave 2 | 3 | The code for the paper: Deznabi, Iman, and Madalina Fiterau. "MultiWave: Multiresolution Deep Architectures through Wavelet Decomposition for Multivariate Time Series Prediction." Conference on Health, Inference, and Learning. PMLR, 2023. https://proceedings.mlr.press/v209/deznabi23a.html 4 | 5 | Please cite this paper if you use the code in this repository as part of a published research project. 
6 | 7 | ## Overview 8 | 9 | MultiWave implements multiresolution deep architectures for time series prediction using wavelet decomposition. The repository includes code for data processing, model definitions, training routines, and evaluation utilities. 10 | 11 | ![The full architecture of MultiWave](MultiWave.jpg) 12 | 13 | ## Abstract 14 | 15 | The analysis of multivariate time series data is challenging due to the various frequencies of signal changes that can occur over both short and long terms. Furthermore, standard deep learning models are often unsuitable for such datasets, as signals are typically sampled at different rates. To address these issues, we introduce MultiWave, a novel framework that enhances deep learning time series models by incorporating components that operate at the intrinsic frequencies of signals. MultiWave uses wavelets to decompose each signal into subsignals of varying frequencies and groups them into frequency bands. Each frequency band is handled by a different component of our model. A gating mechanism combines the output of the components to produce sparse models that use only specific signals at specific frequencies. Our experiments demonstrate that MultiWave accurately identifies informative frequency bands and improves the performance of various deep learning models, including LSTM, Transformer, and CNN-based models, for a wide range of applications. It attains top performance in stress and affect detection from wearables. It also increases the AUC of the best-performing model by 5\% for in-hospital COVID-19 mortality prediction from patient blood samples and for human activity recognition from accelerometer and gyroscope data. We show that MultiWave consistently identifies critical features and their frequency components, thus providing valuable insights into the applications studied. 16 | 17 | ## Repository Structure 18 | 19 | - **MultiWave/main.py**: Entry point for training and evaluation. See [main.py](main.py). 20 | - **MultiWave/DownloadProcessedWESADdata.py**: Script to download and extract the processed WESAD dataset. 21 | - **MultiWave/Models/**: Contains model definitions, fusion functions, and training routines: 22 | - [`Fusions.py`](Models/Fusions.py) 23 | - [`Routines.py`](Models/Routines.py) 24 | - Other model-specific classes and wrappers. 25 | - **MultiWave/utils/**: Provides utility functions for dataset handling, model training, loss computation, and wavelet transformations: 26 | - [`Dataset.py`](utils/Dataset.py) 27 | - [`ModelUtils.py`](utils/ModelUtils.py) 28 | - [`pytorchtools.py`](utils/pytorchtools.py) 29 | - [`WaveletUtils.py`](utils/WaveletUtils.py) 30 | 31 | ## Installation 32 | 33 | 1. **Clone the repository:** 34 | ```sh 35 | git clone https://github.com/username/MultiWave.git 36 | cd MultiWave 37 | ``` 38 | 39 | 2. **Install required packages**: Make sure you have conda installed. Then, create and activate a new environment with the dependencies: 40 | ```sh 41 | conda env create -f environment.yml 42 | conda activate MultiWave 43 | ``` 44 | 45 | 3. **Set up dataset**: Run the provided script to download and extract the WESAD dataset: 46 | ```sh 47 | python DownloadProcessedWESADdata.py 48 | ``` 49 | 50 | ## Usage 51 | **Training a Model** 52 | To train a model, run the main script with appropriate arguments. 
For example: 53 | ```sh 54 | python main.py --hs "[8, 8, 8, 8, 8, 0]" --d 32 --seed 123 --Fusion LinearFusion --Routine FeatNormLossWrapper --SubRoutine OnlyLastLoss --UseExtraLinear False --epochstotrain -1 --LW 0.1 --InitWs 0.5 --InitTemp 10.0 --Model Modelfreq_featMasks --Comp FCN_perchannel --NumLayers 1 --WaveletType db1 --LR 0.001 55 | ``` 56 | Refer to the argument definitions in main.py for available options. 57 | 58 | ## Evaluation 59 | Model evaluation and logging are integrated with WandB. Once training completes, performance metrics such as accuracy, AUC, and confusion matrices (for classification) or MSE, MAE, and R2 (for regression) are printed and logged. 60 | 61 | ## Code Details 62 | 63 | ### Models and Routines 64 | The `Models` folder defines various routines (e.g., `LossSwitches`, `CosineLosses`, `ResetModuleWrapper`) which handle training dynamics, loss weighting, and model optimization. These routines also interface with WandB for logging, as shown in the `wandblog` methods. 65 | 66 | ### Utilities 67 | The `utils` folder contains helper functions: 68 | - **Dataset.py**: Handles data loading and tensor conversion. 69 | - **ModelUtils.py**: Contains training loops, loss functions, and evaluation functions. 70 | - **WaveletUtils.py**: Provides utilities for wavelet transformations. 71 | 72 | ### Logging with WandB 73 | Logging is integrated across training and evaluation routines. Functions such as `wandblog` and `wandbLossLogs` are used to track training progress and model performance. 74 | 75 | ## Citation 76 | If you use this code in your research, please cite: 77 | 78 | Deznabi, Iman, and Madalina Fiterau. "MultiWave: Multiresolution Deep Architectures through Wavelet Decomposition for Multivariate Time Series Prediction." Conference on Health, Inference, and Learning. PMLR, 2023. 
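For convenience, here is a BibTeX entry assembled from the reference above (only fields stated there are included):

```bibtex
@inproceedings{deznabi2023multiwave,
  title     = {MultiWave: Multiresolution Deep Architectures through Wavelet Decomposition for Multivariate Time Series Prediction},
  author    = {Deznabi, Iman and Fiterau, Madalina},
  booktitle = {Conference on Health, Inference, and Learning},
  publisher = {PMLR},
  year      = {2023},
  url       = {https://proceedings.mlr.press/v209/deznabi23a.html}
}
```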
79 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: MultiWave 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | - conda-forge 7 | - soumith 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=5.1=1_gnu 11 | - _pytorch_select=0.2=gpu_0 12 | - appdirs=1.4.4=pyhd3eb1b0_0 13 | - blas=1.0=mkl 14 | - bottleneck=1.3.5=py38h7deecbd_0 15 | - brotli=1.0.9=h5eee18b_9 16 | - brotli-bin=1.0.9=h5eee18b_9 17 | - brotli-python=1.0.9=py38h6a678d5_8 18 | - bzip2=1.0.8=h5eee18b_6 19 | - c-ares=1.19.1=h5eee18b_0 20 | - ca-certificates=2024.12.31=h06a4308_0 21 | - certifi=2024.8.30=py38h06a4308_0 22 | - cffi=1.17.1=py38h1fdaa30_0 23 | - charset-normalizer=3.3.2=pyhd3eb1b0_0 24 | - click=8.1.7=py38h06a4308_0 25 | - contourpy=1.0.5=py38hdb19cb5_0 26 | - cuda-cudart=12.4.127=0 27 | - cuda-cupti=12.4.127=0 28 | - cuda-libraries=12.4.1=0 29 | - cuda-nvrtc=12.4.127=0 30 | - cuda-nvtx=12.4.127=0 31 | - cuda-opencl=12.4.127=0 32 | - cuda-runtime=12.4.1=0 33 | - cudatoolkit=9.2=0 34 | - cudnn=7.6.5=cuda9.2_0 35 | - cycler=0.11.0=pyhd3eb1b0_0 36 | - cyrus-sasl=2.1.28=h52b45da_1 37 | - dbus=1.13.18=hb2f20db_0 38 | - docker-pycreds=0.4.0=pyhd3eb1b0_0 39 | - expat=2.6.4=h6a678d5_0 40 | - ffmpeg=4.3=hf484d3e_0 41 | - fontconfig=2.14.1=h55d465d_3 42 | - fonttools=4.51.0=py38h5eee18b_0 43 | - freetype=2.12.1=h4a9f257_0 44 | - giflib=5.2.2=h5eee18b_0 45 | - gitdb=4.0.7=pyhd3eb1b0_0 46 | - gitpython=3.1.43=py38h06a4308_0 47 | - glib=2.78.4=h6a678d5_0 48 | - glib-tools=2.78.4=h6a678d5_0 49 | - gmp=6.3.0=h6a678d5_0 50 | - gmpy2=2.1.2=py38heeb90bb_0 51 | - gnutls=3.6.15=he1e5248_0 52 | - gst-plugins-base=1.14.1=h6a678d5_1 53 | - gstreamer=1.14.1=h5eee18b_1 54 | - icu=73.1=h6a678d5_0 55 | - idna=3.7=py38h06a4308_0 56 | - intel-openmp=2023.1.0=hdb19cb5_46306 57 | - jinja2=3.1.4=py38h06a4308_0 58 | - joblib=1.4.2=py38h06a4308_0 59 | - jpeg=9e=h5eee18b_3 60 | - kiwisolver=1.4.4=py38h6a678d5_0 61 | - krb5=1.20.1=h143b758_1 62 | - lame=3.100=h7b6447c_0 63 | - lcms2=2.16=hb9589c4_0 64 | - ld_impl_linux-64=2.40=h12ee557_0 65 | - lerc=4.0.0=h6a678d5_0 66 | - libabseil=20240116.2=cxx17_h6a678d5_0 67 | - libbrotlicommon=1.0.9=h5eee18b_9 68 | - libbrotlidec=1.0.9=h5eee18b_9 69 | - libbrotlienc=1.0.9=h5eee18b_9 70 | - libclang=14.0.6=default_hc6dbbc7_2 71 | - libclang13=14.0.6=default_he11475f_2 72 | - libcublas=12.4.5.8=0 73 | - libcufft=11.2.1.3=0 74 | - libcufile=1.9.1.3=0 75 | - libcups=2.4.2=h2d74bed_1 76 | - libcurand=10.3.5.147=0 77 | - libcurl=8.11.1=hc9e6f67_0 78 | - libcusolver=11.6.1.9=0 79 | - libcusparse=12.3.1.170=0 80 | - libdeflate=1.22=h5eee18b_0 81 | - libedit=3.1.20230828=h5eee18b_0 82 | - libev=4.33=h7f8727e_1 83 | - libffi=3.4.4=h6a678d5_1 84 | - libgcc-ng=11.2.0=h1234567_1 85 | - libgfortran-ng=7.5.0=ha8ba4b0_17 86 | - libgfortran4=7.5.0=ha8ba4b0_17 87 | - libglib=2.78.4=hdc74915_0 88 | - libgomp=11.2.0=h1234567_1 89 | - libiconv=1.16=h5eee18b_3 90 | - libidn2=2.3.4=h5eee18b_0 91 | - libjpeg-turbo=2.0.0=h9bf148f_0 92 | - libllvm14=14.0.6=hecde1de_4 93 | - libnghttp2=1.57.0=h2d74bed_0 94 | - libnpp=12.2.5.30=0 95 | - libnvfatbin=12.4.127=0 96 | - libnvjitlink=12.4.127=0 97 | - libnvjpeg=12.3.1.117=0 98 | - libpng=1.6.39=h5eee18b_0 99 | - libpq=17.2=hdbd6064_0 100 | - libprotobuf=4.25.3=he621ea3_0 101 | - libssh2=1.11.1=h251f7ec_0 102 | - libstdcxx-ng=11.2.0=h1234567_1 103 | - libtasn1=4.19.0=h5eee18b_0 104 | - libtiff=4.5.1=hffd6297_1 105 | - 
libunistring=0.9.10=h27cfd23_0 106 | - libuuid=1.41.5=h5eee18b_0 107 | - libuv=1.48.0=h5eee18b_0 108 | - libwebp=1.3.2=h11a3e52_0 109 | - libwebp-base=1.3.2=h5eee18b_1 110 | - libxcb=1.15=h7f8727e_0 111 | - libxkbcommon=1.0.1=h097e994_2 112 | - libxml2=2.13.5=hfdd30dd_0 113 | - llvm-openmp=14.0.6=h9e868ea_0 114 | - lz4-c=1.9.4=h6a678d5_1 115 | - markupsafe=2.1.3=py38h5eee18b_0 116 | - matplotlib=3.6.2=py38h06a4308_0 117 | - matplotlib-base=3.6.2=py38h945d387_0 118 | - mkl=2020.2=256 119 | - mkl-service=2.3.0=py38he904b0f_0 120 | - mkl_fft=1.3.0=py38h54f3939_0 121 | - mkl_random=1.1.1=py38h0573a6f_0 122 | - mpc=1.3.1=h5eee18b_0 123 | - mpfr=4.2.1=h5eee18b_0 124 | - mpmath=1.3.0=py38h06a4308_0 125 | - mysql=8.4.0=h29a9f33_1 126 | - ncurses=6.4=h6a678d5_0 127 | - nettle=3.7.3=hbbd107a_1 128 | - networkx=3.1=py38h06a4308_0 129 | - ninja=1.12.1=h06a4308_0 130 | - ninja-base=1.12.1=hdb19cb5_0 131 | - numexpr=2.7.3=py38hb2eb853_0 132 | - openh264=2.1.1=h4ff587b_0 133 | - openjpeg=2.5.2=he7f1fd0_0 134 | - openldap=2.6.4=h42fbc30_0 135 | - openssl=3.0.15=h5eee18b_0 136 | - packaging=24.1=py38h06a4308_0 137 | - pandas=1.4.4=py38h6a678d5_0 138 | - pcre2=10.42=hebb0a14_1 139 | - pillow=10.4.0=py38h5eee18b_0 140 | - pip=24.2=py38h06a4308_0 141 | - platformdirs=3.10.0=py38h06a4308_0 142 | - ply=3.11=py38_0 143 | - protobuf=4.25.3=py38h12ddb61_0 144 | - psutil=5.9.0=py38h5eee18b_0 145 | - pycparser=2.21=pyhd3eb1b0_0 146 | - pyparsing=3.1.2=py38h06a4308_0 147 | - pyqt=5.15.10=py38h6a678d5_0 148 | - pyqt5-sip=12.13.0=py38h5eee18b_0 149 | - pysocks=1.7.1=py38h06a4308_0 150 | - python=3.8.20=he870216_0 151 | - python-dateutil=2.9.0post0=py38h06a4308_2 152 | - pytorch=2.4.1=py3.8_cuda12.4_cudnn9.1.0_0 153 | - pytorch-cuda=12.4=hc786d27_7 154 | - pytorch-mutex=1.0=cuda 155 | - pytz=2024.1=py38h06a4308_0 156 | - pywavelets=1.4.1=py38h5eee18b_0 157 | - pyyaml=6.0.2=py38h5eee18b_0 158 | - qt-main=5.15.2=hb6262e9_11 159 | - readline=8.2=h5eee18b_0 160 | - requests=2.32.3=py38h06a4308_0 161 | - scikit-learn=1.2.1=py38h6a678d5_0 162 | - scipy=1.6.2=py38h91f5cce_0 163 | - seaborn=0.12.2=py38h06a4308_0 164 | - sentry-sdk=1.9.0=py38h06a4308_0 165 | - setproctitle=1.2.2=py38h27cfd23_1004 166 | - setuptools=75.1.0=py38h06a4308_0 167 | - sip=6.7.12=py38h6a678d5_0 168 | - six=1.16.0=pyhd3eb1b0_1 169 | - smmap=4.0.0=pyhd3eb1b0_0 170 | - sqlite=3.45.3=h5eee18b_0 171 | - sympy=1.13.3=py38h06a4308_0 172 | - threadpoolctl=3.5.0=py38h2f386ee_0 173 | - tk=8.6.14=h39e8969_0 174 | - tomli=2.0.1=py38h06a4308_0 175 | - torchaudio=2.4.1=py38_cu124 176 | - torchtriton=3.0.0=py38 177 | - torchvision=0.20.0=py38_cu124 178 | - tornado=6.4.1=py38h5eee18b_0 179 | - typing_extensions=4.11.0=py38h06a4308_0 180 | - unicodedata2=15.1.0=py38h5eee18b_0 181 | - urllib3=2.2.3=py38h06a4308_0 182 | - wandb=0.16.6=pyhd8ed1ab_1 183 | - wheel=0.44.0=py38h06a4308_0 184 | - xz=5.4.6=h5eee18b_1 185 | - yaml=0.2.5=h7b6447c_0 186 | - zlib=1.2.13=h5eee18b_1 187 | - zstd=1.5.6=hc292b87_0 188 | - pip: 189 | - beautifulsoup4==4.13.3 190 | - configparser==7.1.0 191 | - filelock==3.16.1 192 | - gdown==5.2.0 193 | - kymatio==0.2.1 194 | - numpy==1.24.4 195 | - soupsieve==2.6 196 | - tqdm==4.67.1 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from utils.Dataset import get_RNNdataloader 2 | import numpy as np 3 | import torch 4 | 5 | import pickle 6 | 7 | from utils.ModelUtils import train, get_n_params, set_seed, boolean_string, 
converttoTensor
8 | from Models.TorchModels import getFreqModel
9 | 
10 | import wandb, copy, argparse, os
11 | 
12 | from Models.Fusions import getFusion
13 | from Models.Routines import getRoutine
14 | 
15 | from utils.WaveletUtils import getRNNFreqGroups_mr
16 | from sklearn.utils.class_weight import compute_class_weight
17 | 
18 | import warnings
19 | from sklearn.exceptions import UndefinedMetricWarning
20 | 
21 | if __name__ == "__main__":
22 |     warnings.filterwarnings("ignore", category=UndefinedMetricWarning)
23 |     parser = argparse.ArgumentParser()
24 |     parser.add_argument('--seed', help='The seed used for random seed generator', type=int, default=-1)
25 |     parser.add_argument('--ExtraTag', help='extra tag for wandb', type=str, default='')
26 |     parser.add_argument('--WandB', help='Use wandb', type=boolean_string, default='True')
27 |     parser.add_argument('--WandBEntity', help='The wandb entity (user or team) to log runs under', type=str, default='')
28 |     parser.add_argument('--d', help='dimension for all models', type=int, default=64)
29 |     parser.add_argument('--hs', help='dimension for freq models', type=str, default=str([0,0,0,0,0,32]))
30 |     parser.add_argument('--checkPath', help='path for checkpoint', type=str, default='')
31 |     parser.add_argument('--Routine', help='The loss routine options: OnlyLastLoss, AllLosses, LowToHighFreq, CosineLosses', type=str, default='OnlyLastLoss')
32 |     parser.add_argument('--SubRoutine', help='The routine for CosineSimilarityWrapper options: OnlyLastLoss, AllLosses, LowToHighFreq, CosineLosses', type=str, default='OnlyLastLossWithWarming')
33 |     parser.add_argument('--epochstotrain', help='epochs to train on Routine', type=int, default=10)
34 |     parser.add_argument('--UseExtraLinear', help='Use extra linear in fusion', type=boolean_string, default='False')
35 |     parser.add_argument('--Fusion', help='The fusion options: LinearFusion, TransformerFusion, HieLinFusion, AttentionFusion', type=str, default='LinearFusion')
36 |     parser.add_argument('--InitTemp', help='Initial temp for TempLossWrapper', type=float, default=10.0)
37 |     parser.add_argument('--InitWs', help='Initial W multiplier for switch weights in NormLossWrapper', type=float, default=1.0)
38 |     parser.add_argument('--LW', help='The norm weight in loss', type=float, default=2.0)
39 |     parser.add_argument('--Model', help='The model for combining components', type=str, default='Modelfreq')
40 |     parser.add_argument('--Comp', help='The model for components', type=str, default='BiLSTM')
41 |     parser.add_argument('--NumLayers', help='number of layers for components', type=int, default=1)
42 |     parser.add_argument('--WaveletType', help='the type of wavelet', type=str, default='db1')
43 |     parser.add_argument('--LR', help='learning rate', type=float, default=0.0001)
44 |     parser.add_argument('--fold', help='The fold for running', type=int, default=1)
45 |     args = parser.parse_args()
46 | 
47 |     # Convert the hidden unit string into a list of integers;
48 |     # accepts both "[8,8,8]" and "[8, 8, 8]".
49 |     hs = [int(v) for v in args.hs.strip('[]').split(',')]
50 |     # Load the dataset for the specified fold from a pickle file.
51 |     with open('datasets/WESAD/WESAD_data_fold' + str(args.fold) + '.pkl', 'rb') as f:
52 |         (X_train, X_val, X_test, Y_train, Y_val, Y_test, times_train, times_val, times_test) = pickle.load(f)
53 | 
54 |     # Convert the training, validation, and test input data into tensors.
55 | X_train = converttoTensor(X_train) 56 | X_val = converttoTensor(X_val) 57 | X_test = converttoTensor(X_test) 58 | 59 | # Convert labels to tensors and extract class indices (labels are stored one-hot); the test labels need the same treatment so that evaluation sees 1-D class indices. 60 | Y_train = torch.tensor(Y_train) 61 | Y_val = torch.tensor(Y_val); Y_test = torch.tensor(Y_test) 62 | _, Y_train = Y_train.max(-1); _, Y_val = Y_val.max(-1); _, Y_test = Y_test.max(-1) 63 | 64 | # Determine whether to apply regularization. 65 | regularize = True 66 | if 'perchannel' in args.Comp: 67 | # Disable regularization if the component type is 'perchannel'. 68 | regularize = False 69 | 70 | # Apply frequency grouping to the input data using wavelet decomposition. 71 | X_train_freq = getRNNFreqGroups_mr(X_train, times_train, maxlevels=len(hs)-2, imputation='forward', waveletType=args.WaveletType, regularize=regularize) 72 | X_val_freq = getRNNFreqGroups_mr(X_val, times_val, maxlevels=len(hs)-2, imputation='forward', waveletType=args.WaveletType, regularize=regularize) 73 | X_test_freq = getRNNFreqGroups_mr(X_test, times_test, maxlevels=len(hs)-2, imputation='forward', waveletType=args.WaveletType, regularize=regularize) 74 | 75 | ExtraTags = args.ExtraTag.split(',') 76 | # ExtraTags += ["fold" + str(args.fold)] 77 | seed = None 78 | if args.seed > -1: 79 | seed = args.seed 80 | set_seed(seed) 81 | print(' ... run starting ...', args) 82 | 83 | RoutineParams = {'LW': args.LW, 'InitWs': args.InitWs, 'InitTemp': args.InitTemp} 84 | routine = getRoutine(args.Routine, NumComps=len(hs), epochstrain=args.epochstotrain, OtherParams=RoutineParams) 85 | subroutine = getRoutine(args.SubRoutine, NumComps=len(hs), epochstrain=args.epochstotrain, OtherParams=RoutineParams) 86 | 87 | fusion = getFusion(args.Fusion) 88 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # fall back to CPU when no GPU is present 89 | 90 | bidirectional = False 91 | Comp = args.Comp 92 | if args.Comp == 'BiLSTM': 93 | bidirectional = True 94 | 95 | class_weights = compute_class_weight('balanced', classes=np.unique(Y_train.numpy()), y=Y_train.numpy()) 96 | class_weights = torch.tensor(class_weights).to(device).float() 97 | if regularize: 98 | NumFeats = [x.shape[-1] for x in X_train_freq] 99 | else: 100 | NumFeats = [len(x) for x in X_train_freq] 101 | config = {'model': args.Model, 102 | 'NumComps': len(hs), 103 | 'NumFeats': NumFeats, 104 | 'd': args.d, 105 | 'hs': hs,  # per-component hidden sizes ('hs' = hidden size) 106 | 'dropout': 0.0, 107 | 'lr': args.LR, 108 | 'patience': 15, 109 | 'batch_size': 16, 110 | 'seed': seed, 111 | 'class_weights': class_weights, 112 | 'Fusion': fusion, 113 | 'RoutineEpochs': args.epochstotrain, 114 | 'UseExtraLinear': args.UseExtraLinear, 115 | 'LossRoutine': type(routine).__name__, 116 | 'SubRoutine': type(subroutine).__name__, 117 | 'Comp': Comp, 118 | 'ExtraTags': ExtraTags, 119 | 'NumLayers': args.NumLayers, 120 | 'bidirectional': bidirectional, 121 | 'NumHeads': 3, 122 | 'InitMaskW': args.InitWs, 123 | 'Classification': True, 124 | 'CNNKernelSize': 7, 125 | 'WaveletType': args.WaveletType, 126 | 'fold': args.fold, 127 | 'NumClasses': 3, 128 | 'FCNKernelMult': 1.0, 129 | 'regularized': regularize} 130 | config.update(RoutineParams) 131 | routine.setConfig(config) 132 | subroutine.setConfig(config) 133 | 134 | modelfreq = getFreqModel(config) 135 | modelfreq = modelfreq.to(device) 136 | numparams = get_n_params(modelfreq) 137 | 138 | tags = [ 139 | config['model'] 140 | ] 141 | if args.ExtraTag != '': 142 | tags += ExtraTags 143 | if args.WandB: 144 | os.environ["WANDB__SERVICE_WAIT"] = "300" 145 | wandb.init( 146 | project="WESAD", 147 | config=copy.deepcopy(config), 148 | entity=args.WandBEntity or None,  # an empty string is not a valid wandb entity 149 | tags=tags, 150
| ) 151 | if args.WandB: wandb.log({'num_params': numparams}) 152 | 153 | train_dataloaderFreq = get_RNNdataloader(X_train_freq, Y_train, config['batch_size'], shuffle=True, freq=True, regularized=regularize) 154 | val_dataloaderFreq = get_RNNdataloader(X_val_freq, Y_val, 128, shuffle=False, freq=True, regularized=regularize) 155 | test_dataloaderFreq = get_RNNdataloader(X_test_freq, Y_test, 128, shuffle=False, freq=True, regularized=regularize) 156 | 157 | optimizer = torch.optim.Adam(modelfreq.parameters(), lr=config['lr']) 158 | # optimizer = torch.optim.RMSprop(modelfreq.parameters(), lr=config['lr']) 159 | ChckPointfolder = 'Checkpoints/WESAD/' 160 | os.makedirs(ChckPointfolder, exist_ok=True) 161 | ChckPointPath = os.path.join(ChckPointfolder, 'CurrentChck_' + (wandb.run.id if args.WandB else 'local'))  # wandb.run is None when logging is disabled 162 | routine.setModelOptimizer(optimizer, modelfreq) 163 | subroutine.setModelOptimizer(optimizer, modelfreq) 164 | routine.SetSubRoutine(subroutine) 165 | train(modelfreq, device, train_dataloaderFreq, val_dataloaderFreq, test_dataloaderFreq, optimizer, 1000, 166 | LossRoutine=routine, class_weights=config['class_weights'], patience=config['patience'], checkpointPath=ChckPointPath, 167 | usewandb=args.WandB, convertdirectly=False, classification=config['Classification'], scaler=None, Yscaled=False, NumClasses=config['NumClasses']) 168 | -------------------------------------------------------------------------------- /utils/Dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | class EventData(torch.utils.data.Dataset): 3 | """ Event stream dataset. """ 4 | 5 | # fore_demo, fore_times_ip, fore_values_ip, fore_varis_ip 6 | def __init__(self, data, data_op, freq=True, hascovs=True): 7 | """ 8 | data is a tuple (feats, times, values, covs), with covs read only when hascovs=True; 9 | with freq=True, each of feats/times/values is a list with one entry per frequency component. data_op holds the labels. 10 | """ 11 | self.hascovs = hascovs 12 | self.freq = freq 13 | self.labels = data_op 14 | 15 | if hascovs: 16 | self.covs = data[3] 17 | self.times = data[1] 18 | self.values = data[2] 19 | self.feats = data[0] 20 | if freq: 21 | self.length = data[0][0].shape[0] 22 | else: 23 | self.length = data[0].shape[0] 24 | 25 | def __len__(self): 26 | return self.length 27 | 28 | def __getitem__(self, idx): 29 | """ Returns one sample: per-frequency (times, values, feats) tuples when freq=True, otherwise a single (times, values, feats[, covs]) tuple """ 30 | if self.freq: 31 | out = [] 32 | for i in range(len(self.times)): 33 | out.append((self.times[i][idx], self.values[i][idx], self.feats[i][idx])) 34 | if self.hascovs: 35 | out.append(self.covs[idx]) 36 | else: 37 | if self.hascovs: 38 | out = (self.times[idx], self.values[idx], self.feats[idx], self.covs[idx]) 39 | else: 40 | out = (self.times[idx], self.values[idx], self.feats[idx]) 41 | return out, self.labels[idx] 42 | def get_dataloader(data, data_out, batch_size, shuffle=True, freq=True, hascovs=True): 43 | """ Prepare dataloader. """ 44 | 45 | ds = EventData(data, data_out, freq=freq, hascovs=hascovs) 46 | dl = torch.utils.data.DataLoader( 47 | ds, 48 | num_workers=2, 49 | batch_size=batch_size, 50 | shuffle=shuffle 51 | ) 52 | return dl 53 | 54 | class RNNData(torch.utils.data.Dataset): 55 | """ Fixed-step sequence dataset for RNN models.
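Example (an illustrative sketch, not part of the original file; shapes and label values are dummy assumptions):

    import torch
    X = torch.randn(100, 64, 8)         # (num_samples, seq_len, num_feats)
    y = torch.randint(0, 3, (100,))     # integer class labels
    ds = RNNData(X, y, freq=False)
    x0, y0 = ds[0]                      # x0 has shape (64, 8)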
""" 56 | def __init__(self, data, data_op, freq=False, regularized=True): 57 | """ 58 | Data should be a list of event streams; each event stream is a list of dictionaries; 59 | each dictionary contains: time_since_start, time_since_last_event, type_event 60 | """ 61 | self.freq = freq 62 | self.data = data 63 | self.labels = data_op 64 | self.regularized = regularized 65 | 66 | def __len__(self): 67 | if self.freq: 68 | if self.regularized: 69 | return self.data[0].shape[0] 70 | else: 71 | return self.data[0][0].shape[0] 72 | else: 73 | return self.data.shape[0] 74 | 75 | def __getitem__(self, idx): 76 | """ Each returned element is a list, which represents an event stream """ 77 | if self.freq: 78 | if self.regularized: 79 | out = [d[idx, :, :] for d in self.data] 80 | else: 81 | out = [[d[idx, :] for d in d_arr] for d_arr in self.data] 82 | return out, self.labels[idx] 83 | else: 84 | return self.data[idx,:,:], self.labels[idx] 85 | def get_RNNdataloader(data, data_out, batch_size, shuffle=True, freq=False, regularized=True): 86 | """ Prepare dataloader. """ 87 | 88 | ds = RNNData(data, data_out, freq=freq, regularized=regularized) 89 | dl = torch.utils.data.DataLoader( 90 | ds, 91 | num_workers=2, 92 | batch_size=batch_size, 93 | shuffle=shuffle 94 | ) 95 | return dl -------------------------------------------------------------------------------- /utils/ModelUtils.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import precision_recall_fscore_support, accuracy_score 2 | from sklearn.metrics import confusion_matrix, roc_auc_score, precision_recall_curve, auc 3 | import torch 4 | import torch.nn as nn 5 | from utils.pytorchtools import EarlyStopping 6 | import numpy as np 7 | import wandb 8 | from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score 9 | import random 10 | 11 | def converttoTensor(data): 12 | """ 13 | Converts a list of numpy arrays to a list of PyTorch tensors. 14 | 15 | Parameters: 16 | data (list): List of numpy arrays. 17 | 18 | Returns: 19 | list: List of PyTorch tensors. 20 | """ 21 | alldata = [] 22 | for d in data: 23 | alldata.append(torch.tensor(d).float()) 24 | return alldata 25 | 26 | def set_seed(seed): 27 | """ 28 | Sets the random seed for reproducibility. 29 | 30 | Parameters: 31 | seed (int): The seed value to set. 32 | """ 33 | print('setting seed to', seed) 34 | torch.manual_seed(seed) 35 | np.random.seed(seed) 36 | random.seed(seed) 37 | torch.cuda.manual_seed(seed) 38 | torch.backends.cudnn.deterministic = True 39 | torch.backends.cudnn.benchmark = False 40 | 41 | def boolean_string(s): 42 | """ 43 | Converts a string to a boolean value. 44 | 45 | Parameters: 46 | s (str): Input string ('True' or 'False'). 47 | 48 | Returns: 49 | bool: Boolean value corresponding to the input string. 50 | 51 | Raises: 52 | ValueError: If the input string is not 'True' or 'False'. 53 | """ 54 | if s not in {'False', 'True'}: 55 | raise ValueError('Not a valid boolean string') 56 | return s == 'True' 57 | 58 | def get_n_params(model): 59 | """ 60 | Computes the number of parameters in a PyTorch model. 61 | 62 | Parameters: 63 | model (torch.nn.Module): The PyTorch model. 64 | 65 | Returns: 66 | int: Number of parameters in the model. 
67 | """ 68 | pp = 0 69 | for p in list(model.parameters()): 70 | nn = 1 71 | for s in list(p.size()): 72 | nn = nn * s 73 | pp += nn 74 | return pp 75 | 76 | def wandblog(prfs, epoch, classification=True, postfix=''): 77 | """ 78 | Logs performance metrics to Weights & Biases. 79 | 80 | Parameters: 81 | prfs (dict): Dictionary of performance metrics. 82 | epoch (int): Current epoch number. 83 | classification (bool, optional): Whether the task is classification. Defaults to True. 84 | postfix (str, optional): Postfix for metric names. Defaults to ''. 85 | """ 86 | if classification: 87 | for prf_k in prfs: 88 | wandb.log({'auc_' + prf_k + postfix: prfs[prf_k]['roc_macro'], 89 | 'auc_weighted_' + prf_k + postfix: prfs[prf_k]['roc_weighted'], 90 | 'fscore_' + prf_k + postfix: prfs[prf_k]['fscore_macro'], 91 | 'confMat_' + prf_k + postfix: prfs[prf_k]['confusionMatrix'], 92 | 'minrp_' + prf_k + postfix: prfs[prf_k]['minrp'], 93 | 'pr_auc_' + prf_k + postfix: prfs[prf_k]['pr_auc'], 94 | 'accuracy_' + prf_k + postfix: prfs[prf_k]['accuracy'], 95 | 'loss_' + prf_k + postfix: prfs[prf_k]['loss'], 96 | 'epoch': epoch}) 97 | else: 98 | for prf_k in prfs: 99 | wandb.log({'mse_' + prf_k + postfix: prfs[prf_k]['mse'], 100 | 'mae_' + prf_k + postfix: prfs[prf_k]['mae'], 101 | 'r2_' + prf_k + postfix: prfs[prf_k]['r2'], 102 | 'epoch': epoch}) 103 | 104 | def wandbLossLogs(complosses, epoch): 105 | """ 106 | Logs component losses to Weights & Biases. 107 | 108 | Parameters: 109 | complosses (dict): Dictionary of component losses. 110 | epoch (int): Current epoch number. 111 | """ 112 | LossDict = {} 113 | for K in complosses: 114 | for i, L in enumerate(complosses[K]): 115 | LossDict['CompLoss_' + K + '_Model_' + str(i)] = L 116 | LossDict['epoch'] = epoch 117 | wandb.log(LossDict) 118 | 119 | def weighted_binary_cross_entropy(output, target, weights=None): 120 | """ 121 | Computes the weighted binary cross-entropy loss. 122 | 123 | Parameters: 124 | output (torch.Tensor): Model output. 125 | target (torch.Tensor): Ground truth labels. 126 | weights (list, optional): Weights for positive and negative classes. Defaults to None. 127 | 128 | Returns: 129 | torch.Tensor: Computed loss. 130 | """ 131 | if weights is not None: 132 | assert len(weights) == 2 133 | output = torch.clamp(output, min=1e-7, max=1-1e-7) 134 | loss = weights[1] * (target * torch.log(output)) + \ 135 | weights[0] * ((1 - target) * torch.log(1 - output)) 136 | else: 137 | loss = target * torch.log(output) + (1 - target) * torch.log(1 - output) 138 | 139 | return torch.neg(torch.mean(loss)) 140 | 141 | def weighted_cross_entropy(output, target, weights): 142 | """ 143 | Computes the weighted cross-entropy loss. 144 | 145 | Parameters: 146 | output (torch.Tensor): Model output. 147 | target (torch.Tensor): Ground truth labels. 148 | weights (torch.Tensor): Weights for each class. 149 | 150 | Returns: 151 | torch.Tensor: Computed loss. 152 | """ 153 | return nn.CrossEntropyLoss(weight=weights)(output, target) 154 | 155 | def getLoss(output, labels, class_weights=None, classification=True, NumClasses=1): 156 | """ 157 | Computes the loss based on the task type and class weights. 158 | 159 | Parameters: 160 | output (torch.Tensor): Model output. 161 | labels (torch.Tensor): Ground truth labels. 162 | class_weights (torch.Tensor, optional): Weights for each class. Defaults to None. 163 | classification (bool, optional): Whether the task is classification. Defaults to True. 164 | NumClasses (int, optional): Number of classes. Defaults to 1. 
165 | 166 | Returns: 167 | torch.Tensor: Computed loss. 168 | """ 169 | if classification: 170 | if NumClasses == 1: 171 | loss = weighted_binary_cross_entropy(output, labels, class_weights) 172 | else: 173 | loss = weighted_cross_entropy(output, labels, class_weights) 174 | else: 175 | loss = nn.MSELoss()(output, labels) 176 | return loss 177 | 178 | def train_step(epoch, model, device, train_loader, optimizer, class_weights, NumClasses=1, LossRoutine=None, classification=True, convertdirectly=False, scaler=None, Yscaled=False): 179 | """ 180 | Performs a single training step. 181 | 182 | Parameters: 183 | epoch (int): Current epoch number. 184 | model (torch.nn.Module): The model to train. 185 | device (torch.device): Device to use for computation. 186 | train_loader (DataLoader): DataLoader for the training data. 187 | optimizer (torch.optim.Optimizer): Optimizer for updating model parameters. 188 | class_weights (list): Weights for each class. 189 | NumClasses (int, optional): Number of classes. Defaults to 1. 190 | LossRoutine (object, optional): Custom loss routine. Defaults to None. 191 | classification (bool, optional): Whether the task is classification. Defaults to True. 192 | convertdirectly (bool, optional): Whether to convert input directly to device. Defaults to False. 193 | scaler (object, optional): Scaler for input data. Defaults to None. 194 | Yscaled (bool, optional): Whether the labels are scaled. Defaults to False. 195 | 196 | Returns: 197 | dict: Training performance metrics. 198 | list: Component losses. 199 | """ 200 | model.train() 201 | correct = 0 202 | for batch_idx, (batch_in, labels) in enumerate(train_loader): 203 | labels = labels.to(device) 204 | if convertdirectly: 205 | batch_in = batch_in.to(device) 206 | else: 207 | for i in range(len(batch_in)): 208 | if type(batch_in[i]) == list: 209 | for j in range(len(batch_in[i])): 210 | batch_in[i][j] = batch_in[i][j].to(device) 211 | else: 212 | batch_in[i] = batch_in[i].to(device) 213 | optimizer.zero_grad() 214 | output, compouts = model(batch_in) 215 | CompLosses = [] 216 | for compout in compouts: 217 | L = getLoss(compout, labels, class_weights=class_weights, classification=classification, NumClasses=NumClasses) 218 | CompLosses.append(L) 219 | loss = getLoss(output, labels, class_weights=class_weights, classification=classification, NumClasses=NumClasses) 220 | if LossRoutine: 221 | loss = LossRoutine.getLoss(loss, CompLosses, epoch) 222 | if LossRoutine and LossRoutine.hascustombackward: 223 | LossRoutine.Backward(loss, CompLosses) 224 | else: 225 | loss.backward() 226 | optimizer.step() 227 | if batch_idx % 100 == 0: # Print loss every 100 batches 228 | print('batch_idx: {}\tLoss: {:.6f}'.format(batch_idx, loss.item())) 229 | train_prf, CompLosses = test(model, device, train_loader, class_weights=class_weights, classification=classification, convertdirectly=convertdirectly, scaler=scaler, Yscaled=Yscaled, NumClasses=NumClasses) 230 | return train_prf, CompLosses 231 | 232 | def train(model, device, train_loader, val_loader, test_loader, optimizer, epochs, NumClasses=1, LossRoutine=None, class_weights=[1.0, 1.0], classification=True, patience=5, checkpointPath='Checkpoints/CurrentChck', usewandb=True, convertdirectly=False, scaler=None, Yscaled=False, closewandb=True, wandbpostfix=''): 233 | """ 234 | Trains the model for a specified number of epochs. 235 | 236 | Parameters: 237 | model (torch.nn.Module): The model to train. 238 | device (torch.device): Device to use for computation. 
239 | train_loader (DataLoader): DataLoader for the training data. 240 | val_loader (DataLoader): DataLoader for the validation data. 241 | test_loader (DataLoader): DataLoader for the test data. 242 | optimizer (torch.optim.Optimizer): Optimizer for updating model parameters. 243 | epochs (int): Number of epochs to train. 244 | NumClasses (int, optional): Number of classes. Defaults to 1. 245 | LossRoutine (object, optional): Custom loss routine. Defaults to None. 246 | class_weights (list, optional): Weights for each class. Defaults to [1.0, 1.0]. 247 | classification (bool, optional): Whether the task is classification. Defaults to True. 248 | patience (int, optional): Patience for early stopping. Defaults to 5. 249 | checkpointPath (str, optional): Path to save the model checkpoint. Defaults to 'Checkpoints/CurrentChck'. 250 | usewandb (bool, optional): Whether to use Weights & Biases for logging. Defaults to True. 251 | convertdirectly (bool, optional): Whether to convert input directly to device. Defaults to False. 252 | scaler (object, optional): Scaler for input data. Defaults to None. 253 | Yscaled (bool, optional): Whether the labels are scaled. Defaults to False. 254 | closewandb (bool, optional): Whether to close Weights & Biases after training. Defaults to True. 255 | wandbpostfix (str, optional): Postfix for Weights & Biases logging. Defaults to ''. 256 | 257 | Returns: 258 | None 259 | """ 260 | initEpochs = 0 261 | if LossRoutine: 262 | initEpochs = LossRoutine.startingepochs 263 | es = EarlyStopping(initEpochs=initEpochs, patience=patience, path=checkpointPath, verbose=True) 264 | model.train() 265 | for epoch in range(epochs): 266 | if LossRoutine: LossRoutine.PreTrainStep(epoch)  # LossRoutine is documented as optional, so guard every use 267 | train_prf, train_CompLosses = train_step(epoch, model, device, train_loader, optimizer, class_weights, classification=classification, convertdirectly=convertdirectly, scaler=scaler, Yscaled=Yscaled, LossRoutine=LossRoutine, NumClasses=NumClasses) 268 | if classification: 269 | print('\nEpoch: {}, Train set Accuracy: {:.2f}, Train set AUC macro: {:.4f}, Train set AUC weighted: {:.4f}\n'.format(epoch, train_prf['accuracy'], train_prf['roc_macro'], train_prf['roc_weighted'])) 270 | else: 271 | print('\nEpoch: {}, Train set MSE: {:.4f}, Train set MAE: {:.4f}, Train set R2: {:.4f}\n'.format(epoch, train_prf['mse'], train_prf['mae'], train_prf['r2'])) 272 | val_prf, val_CompLosses = test(model, device, val_loader, class_weights=class_weights, classification=classification, convertdirectly=convertdirectly, scaler=scaler, NumClasses=NumClasses) 273 | if classification: 274 | print('\nEpoch: {}, Val set Accuracy: {:.2f}, Val set AUC macro: {:.4f}, Val set AUC weighted: {:.4f}\n'.format(epoch, val_prf['accuracy'], val_prf['roc_macro'], val_prf['roc_weighted'])) 275 | es(-(val_prf['roc_macro'] + val_prf['pr_auc']), model)  # negated: EarlyStopping treats its input as a loss to minimize 276 | if LossRoutine: LossRoutine.saveLosses((train_prf['loss'], train_CompLosses), (val_prf['loss'], val_CompLosses), epoch) 277 | else: 278 | print('\nEpoch: {}, Val set MSE: {:.4f}, Val set MAE: {:.4f}, Val set R2: {:.4f}\n'.format(epoch, val_prf['mse'], val_prf['mae'], val_prf['r2'])) 279 | es(val_prf['mse'], model) 280 | if LossRoutine: LossRoutine.saveLosses((train_prf['mse'], train_CompLosses), (val_prf['mse'], val_CompLosses), epoch) 281 | if usewandb: 282 | wandblog({'train': train_prf, 'val': val_prf}, epoch, classification=classification, postfix=wandbpostfix) 283 | wandbLossLogs({'train': train_CompLosses, 'val': val_CompLosses}, epoch) 284 | if es.early_stop and (LossRoutine is None or LossRoutine.stopatES): 285 | print("Early stopping at
epoch " + str(epoch)) 286 | break 287 | elif es.early_stop: 288 | model, optimizer = LossRoutine.ResetModel(es, checkpointPath, device) 289 | model.load_state_dict(torch.load(checkpointPath)) 290 | val_prf_final, val_CompLosses_final = test(model, device, val_loader, class_weights=class_weights, classification=classification, convertdirectly=convertdirectly, scaler=scaler, NumClasses=NumClasses) 291 | test_prf, test_CompLosses = test(model, device, test_loader, class_weights=class_weights, classification=classification, convertdirectly=convertdirectly, scaler=scaler, NumClasses=NumClasses) 292 | if usewandb: 293 | wandblog({'test': test_prf, 'val_final': val_prf_final}, epoch, classification=classification, postfix=wandbpostfix) 294 | wandbLossLogs({'test': test_CompLosses, 'val_final': val_CompLosses_final}, epoch) 295 | wandb.save(checkpointPath) 296 | if closewandb: 297 | wandb.finish() 298 | if classification: 299 | print('\nTest set Accuracy: {:.4f}, Test set f1: {:.4f}, Test set AUC macro: {:.4f}, Test set AUC weighted: {:.4f}, Test set Confmat: {}\n'.format(test_prf['accuracy'], test_prf['fscore_macro'], test_prf['roc_macro'], test_prf['roc_weighted'], str(test_prf['confusionMatrix']))) 300 | else: 301 | print('\nTest set MSE: {:.4f}, Test set MAE: {:.4f}, Test set R2: {:.4f}\n'.format(test_prf['mse'], test_prf['mae'], test_prf['r2'])) 302 | 303 | def get_pr_auc(y_true, y_pred): 304 | """ 305 | Computes the Precision-Recall AUC and the maximum of the minimum precision and recall. 306 | 307 | Parameters: 308 | y_true (array-like): True binary labels. 309 | y_pred (array-like): Target scores. 310 | 311 | Returns: 312 | list: [pr_auc, minrp] where pr_auc is the Precision-Recall AUC and minrp is the maximum of the minimum precision and recall. 313 | """ 314 | precision, recall, thresholds = precision_recall_curve(y_true, y_pred) 315 | pr_auc = auc(recall, precision) 316 | minrp = np.minimum(precision, recall).max() 317 | return [pr_auc, minrp] 318 | 319 | def Evaluate(Labels, Preds, PredScores, class_weights): 320 | """ 321 | Evaluates the model's performance on classification tasks. 322 | 323 | Parameters: 324 | Labels (array-like): True labels. 325 | Preds (array-like): Predicted labels. 326 | PredScores (array-like): Predicted scores. 327 | class_weights (list): Weights for each class. 328 | 329 | Returns: 330 | dict: Dictionary containing various evaluation metrics. 
331 | """ 332 | avg = 'binary' 333 | NumClasses = 1 334 | if len(class_weights) > 2: 335 | avg = 'weighted' 336 | NumClasses = len(class_weights) 337 | if NumClasses > 1: 338 | PredScores = nn.Softmax(-1)(PredScores) 339 | percision, recall, fscore, support = precision_recall_fscore_support(Labels, Preds, average=avg) 340 | _, _, fscore_weighted, _ = precision_recall_fscore_support(Labels, Preds, average='weighted') 341 | _, _, fscore_macro, _ = precision_recall_fscore_support(Labels, Preds, average='macro') 342 | accuracy = accuracy_score(Labels, Preds) * 100 343 | confmat = confusion_matrix(Labels, Preds) 344 | loss = getLoss(PredScores.float(), Labels, class_weights=class_weights, classification=True, NumClasses=NumClasses) 345 | roc_macro, roc_weighted = roc_auc_score(Labels, PredScores, average='macro', multi_class='ovr'), roc_auc_score(Labels, PredScores, average='weighted', multi_class='ovr') 346 | if NumClasses == 1: 347 | pr_auc, minrp = get_pr_auc(Labels, PredScores) 348 | else: 349 | pr_auc, minrp = 0, 0 350 | prf_test = {'percision': percision, 'recall': recall, 'fscore': fscore, 'fscore_weighted': fscore_weighted, 351 | 'fscore_macro': fscore_macro, 'accuracy': accuracy, 'confusionMatrix': confmat, 'roc_macro': roc_macro, 352 | 'roc_weighted': roc_weighted, 'loss': loss, 'minrp': minrp, 'pr_auc': pr_auc} 353 | return prf_test 354 | 355 | """ 356 | From https://en.wikipedia.org/wiki/Coefficient_of_determination 357 | """ 358 | def r2_loss(output, target): 359 | """ 360 | Computes the R2 loss (coefficient of determination). 361 | 362 | Parameters: 363 | output (torch.Tensor): Model output. 364 | target (torch.Tensor): Ground truth labels. 365 | 366 | Returns: 367 | torch.Tensor: Computed R2 loss. 368 | """ 369 | target_mean = torch.mean(target) 370 | ss_tot = torch.sum((target - target_mean) ** 2) 371 | ss_res = torch.sum((target - output) ** 2) 372 | r2 = 1 - ss_res / ss_tot 373 | return r2 374 | 375 | def EvaluateReg(Labels, Preds, PredScores, class_weights, scaler=None, Yscaled=False): 376 | """ 377 | Evaluates the model's performance on regression tasks. 378 | 379 | Parameters: 380 | Labels (array-like): True labels. 381 | Preds (array-like): Predicted labels. 382 | PredScores (array-like): Predicted scores. 383 | class_weights (list): Weights for each class. 384 | scaler (object, optional): Scaler for input data. Defaults to None. 385 | Yscaled (bool, optional): Whether the labels are scaled. Defaults to False. 386 | 387 | Returns: 388 | dict: Dictionary containing various evaluation metrics. 389 | """ 390 | if scaler is not None: 391 | Preds = scaler.inverse_transform(Preds.reshape(-1, 1)) 392 | Preds = Preds.squeeze(-1) 393 | if Yscaled: 394 | Labels = scaler.inverse_transform(Labels.reshape(-1, 1)) 395 | Labels = Labels.squeeze(-1) 396 | mae = mean_absolute_error(Labels, Preds) 397 | mse = mean_squared_error(Labels, Preds) 398 | r2 = r2_score(Preds, Labels) 399 | prf_test = {'Preds': Preds, 'Labels': Labels, 'r2': r2, 'mae': mae, 'mse': mse} 400 | return prf_test 401 | 402 | def test(model, device, test_loader, class_weights=[1.0,1.0], classification=True, convertdirectly=False, scaler=None, Yscaled=False, NumClasses=1): 403 | """ 404 | Evaluates the model on the test dataset. 405 | 406 | Parameters: 407 | model (torch.nn.Module): The model to evaluate. 408 | device (torch.device): Device to use for computation. 409 | test_loader (DataLoader): DataLoader for the test data. 410 | class_weights (list, optional): Weights for each class. Defaults to [1.0, 1.0]. 
411 | classification (bool, optional): Whether the task is classification. Defaults to True. 412 | convertdirectly (bool, optional): Whether to convert input directly to device. Defaults to False. 413 | scaler (object, optional): Scaler for input data. Defaults to None. 414 | Yscaled (bool, optional): Whether the labels are scaled. Defaults to False. 415 | NumClasses (int, optional): Number of classes. Defaults to 1. 416 | 417 | Returns: 418 | dict: Test performance metrics. 419 | list: Component losses. 420 | """ 421 | model.eval() 422 | correct = 0 423 | corrects = torch.tensor([], dtype=torch.int64).to(device) 424 | preds = torch.tensor([], dtype=torch.int64).to(device) 425 | predScores = torch.tensor([], dtype=torch.float).to(device) 426 | FirstTime = True 427 | total_loss = 0.0 428 | total_num = 0 429 | CompoutsAll = [] 430 | with torch.no_grad(): 431 | for (batch_in, labels) in test_loader: 432 | labels = labels.to(device) 433 | if convertdirectly: 434 | batch_in = batch_in.to(device) 435 | else: 436 | for i in range(len(batch_in)): 437 | if isinstance(batch_in[i], list): 438 | for j in range(len(batch_in[i])): 439 | batch_in[i][j] = batch_in[i][j].to(device) 440 | else: 441 | batch_in[i] = batch_in[i].to(device) 442 | output, compouts = model(batch_in) 443 | if classification: 444 | if len(output.shape) > 1: 445 | _, pred = torch.max(output.data, 1) 446 | else: 447 | pred = output.data > 0.5 448 | else: 449 | pred = output.data 450 | if FirstTime: 451 | predScores = output; corrects = labels; preds = pred 452 | FirstTime = False 453 | CompoutsAll = compouts 454 | else: 455 | predScores = torch.cat((predScores, output)) 456 | corrects = torch.cat((corrects, labels)) 457 | preds = torch.cat((preds, pred)) 458 | for i in range(len(CompoutsAll)): 459 | CompoutsAll[i] = torch.cat((CompoutsAll[i], compouts[i])) 460 | CompoutLosses = [] 461 | for compout in CompoutsAll: 462 | CompLoss = getLoss(compout, corrects, class_weights=class_weights, classification=classification, NumClasses=NumClasses) 463 | CompoutLosses.append(CompLoss) 464 | if classification: 465 | prf_test = Evaluate(corrects.cpu(), preds.cpu(), predScores.cpu(), torch.as_tensor(class_weights).cpu())  # as_tensor also handles the plain-list default 466 | else: 467 | prf_test = EvaluateReg(corrects.cpu(), preds.cpu(), predScores.cpu(), torch.as_tensor(class_weights).cpu(), scaler=scaler, Yscaled=Yscaled) 468 | return prf_test, CompoutLosses -------------------------------------------------------------------------------- /utils/WaveletUtils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pywt 4 | import pandas as pd 5 | import math 6 | 7 | def imputeVals(X, imputation='mean'): 8 | """ 9 | Imputes missing values in the input tensor X based on the specified imputation method. 10 | 11 | Parameters: 12 | X (numpy.ndarray): Input tensor with missing values. 13 | imputation (str): Imputation method ('mean', 'forward', 'zero', 'backward'). 14 | 15 | Returns: 16 | numpy.ndarray: Tensor with imputed values.
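Example (an illustrative sketch):

    import numpy as np
    X = np.array([[[1.0], [np.nan], [3.0]]])    # shape (batch=1, time=3, feats=1)
    imputeVals(X, imputation='forward')[0, :, 0]
    # -> array([1., 1., 3.]); the NaN is filled with the previous observation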
17 | """ 18 | Xs = [] 19 | for i in range(X.shape[-1]): 20 | df = pd.DataFrame(X[:, :, i].transpose()) 21 | if imputation == 'mean': 22 | df = df.fillna(df.mean()) 23 | elif imputation == 'forward': 24 | df = df.ffill() 25 | elif imputation == 'zero': 26 | df = df.fillna(0) 27 | elif imputation == 'backward': 28 | df = df.bfill() 29 | Xs.append(df.to_numpy().transpose()) 30 | return np.stack(Xs, -1) 31 | 32 | def Regularize(X, times, imputation='mean'): 33 | """ 34 | Regularizes the input tensor X based on the provided times and imputation method. 35 | 36 | Parameters: 37 | X (list): List of input tensors. 38 | times (list): List of time intervals. 39 | imputation (str): Imputation method ('mean', 'forward', 'zero', 'backward'). 40 | 41 | Returns: 42 | torch.Tensor: Regularized tensor. 43 | torch.Tensor: All times tensor. 44 | """ 45 | size = X[0].shape[0] 46 | series = [] 47 | for x, t in zip(X, times): 48 | A = pd.DataFrame(x.T, index=t) 49 | series.append(A) 50 | df = pd.concat(series, axis=1) 51 | AllTimes = torch.tensor(df.index.values).float() 52 | X = df.to_numpy() # shape: Len, size*feats 53 | X = X.reshape([X.shape[0], -1, size]) # shape: Len, feats, size 54 | X = X.transpose([2, 0, 1]) 55 | X = imputeVals(X, imputation=imputation) 56 | return torch.tensor(X), AllTimes 57 | 58 | def getdeltaTimes(times): 59 | """ 60 | Computes the delta times from the provided times. 61 | 62 | Parameters: 63 | times (torch.Tensor): Input times tensor. 64 | 65 | Returns: 66 | torch.Tensor: Delta times tensor. 67 | """ 68 | Times = times.clone() 69 | for i in reversed(range(1, len(times))): 70 | Times[i] = times[i] - times[i-1] 71 | return Times 72 | 73 | def getRNNFreqGroups_mr(data, times, device=torch.device("cuda:0"), maxlevels=4, waveletType='haar', imputation='mean', fulldata=None, regularize=True, return_times=False): 74 | """ 75 | Computes the frequency groups for RNN using multi-resolution wavelet decomposition. 76 | 77 | Parameters: 78 | data (list): List of input tensors. 79 | times (list): List of time intervals. 80 | device (torch.device): Device to use for computation. 81 | maxlevels (int): Maximum levels for wavelet decomposition. 82 | waveletType (str): Type of wavelet to use. 83 | imputation (str): Imputation method ('mean', 'forward', 'zero', 'backward'). 84 | fulldata (list, optional): Full data tensor. 85 | regularize (bool): Whether to regularize the data. 86 | return_times (bool): Whether to return times. 87 | 88 | Returns: 89 | list: List of frequency groups. 90 | list (optional): List of times if return_times is True. 
91 | """ 92 | WL = pywt.Wavelet(waveletType) 93 | MLs = [] 94 | Outs = [[] for _ in range(maxlevels + 1)] 95 | Ts = [[] for _ in range(maxlevels + 1)] 96 | for d in data: 97 | ML = pywt.dwt_max_level(d.shape[1], WL) 98 | MLs.append(ML) 99 | dL = max(MLs) - maxlevels 100 | MaxT = max([max(t) for t in times]) 101 | for i, d in enumerate(data): 102 | out = pywt.wavedec(d, WL, level=MLs[i] - dL, axis=1, mode='periodization') 103 | for j, o in enumerate(out): 104 | Outs[j].append(o) 105 | TSubSamp = math.ceil(times[i].shape[0] / o.shape[1]) 106 | Ts[j].append(times[i][::TSubSamp]) 107 | if fulldata is None: 108 | Outs.append([d.cpu().numpy() for d in data]) # Convert tensors to CPU tensors and then to NumPy arrays 109 | Ts.append(times) 110 | if regularize: 111 | Times = [] 112 | Outs_ls = [] 113 | for x, t in zip(Outs, Ts): 114 | o, time = Regularize(x, t, imputation) 115 | time /= MaxT 116 | time = getdeltaTimes(time) 117 | Outs_ls.append(o) 118 | Times.append(time) 119 | Outs = Outs_ls 120 | else: 121 | Outs = [[torch.tensor(x) for x in x_arr] for x_arr in Outs] 122 | if fulldata is not None: 123 | Outs.append(fulldata) 124 | if return_times: 125 | return Outs, Times 126 | return Outs 127 | 128 | def getRNNFreqGroups(data, device=torch.device("cuda:0"), maxlevels=4, waveletType='haar'): 129 | """ 130 | Computes the frequency groups for RNN using wavelet decomposition. 131 | 132 | Parameters: 133 | data (torch.Tensor): Input tensor. 134 | device (torch.device): Device to use for computation. 135 | maxlevels (int): Maximum levels for wavelet decomposition. 136 | waveletType (str): Type of wavelet to use. 137 | 138 | Returns: 139 | list: List of frequency groups. 140 | """ 141 | WL = pywt.Wavelet(waveletType) 142 | ML = pywt.dwt_max_level(data.shape[1], WL) 143 | out = pywt.wavedec(data, WL, level=min(maxlevels, ML), axis=1) 144 | out.append(data) 145 | out = [torch.tensor(o) for o in out] # Convert to tensor 146 | return out -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Information-Fusion-Lab-Umass/MultiWave/ca3003e9d72603e0d25c6a1df3c0632d4df09ffd/utils/__init__.py -------------------------------------------------------------------------------- /utils/pytorchtools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | class EarlyStopping: 5 | """Early stops the training if validation loss doesn't improve after a given patience.""" 6 | def __init__(self, initEpochs=0, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print): 7 | """ 8 | Args: 9 | patience (int): How long to wait after last time validation loss improved. 10 | Default: 7 11 | verbose (bool): If True, prints a message for each validation loss improvement. 12 | Default: False 13 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 14 | Default: 0 15 | path (str): Path for the checkpoint to be saved to. 16 | Default: 'checkpoint.pt' 17 | trace_func (function): trace print function. 
18 | Default: print 19 | """ 20 | self.epochs = 0 21 | self.initEpochs = initEpochs 22 | self.patience = patience 23 | self.verbose = verbose 24 | self.counter = 0 25 | self.best_score = None 26 | self.early_stop = False 27 | self.val_loss_min = np.Inf 28 | self.delta = delta 29 | self.path = path 30 | self.trace_func = trace_func 31 | def __call__(self, val_loss, model): 32 | 33 | score = -val_loss 34 | if self.epochs >= self.initEpochs: 35 | if self.best_score is None: 36 | self.best_score = score 37 | self.save_checkpoint(val_loss, model) 38 | elif score <= self.best_score + self.delta: 39 | self.counter += 1 40 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') 41 | if self.counter >= self.patience: 42 | self.early_stop = True 43 | else: 44 | self.best_score = score 45 | self.save_checkpoint(val_loss, model) 46 | self.counter = 0 47 | self.epochs += 1 48 | def restart(self): 49 | self.best_score = None 50 | self.early_stop = False 51 | self.val_loss_min = np.Inf 52 | self.epochs = 0 53 | self.counter = 0 54 | def save_checkpoint(self, val_loss, model): 55 | '''Saves model when validation loss decrease.''' 56 | if self.verbose: 57 | self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 58 | torch.save(model.state_dict(), self.path) 59 | self.val_loss_min = val_loss 60 | --------------------------------------------------------------------------------
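Usage note (an illustrative sketch, not part of the repository files): a minimal invocation of the training entry point, using only the command-line flags defined in main.py above; all values shown are the parser defaults, and the fold pickle is assumed to already exist under datasets/WESAD/.

    python main.py --fold 1 --Comp BiLSTM --Fusion LinearFusion --Routine OnlyLastLoss --WaveletType db1 --LR 0.0001

Pass --WandB False to skip Weights & Biases logging.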