├── .gitignore
├── README.md
└── icldata.py

/.gitignore:
--------------------------------------------------------------------------------
features/*
labels/*
other/*
cache/*

--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# ICLabel Dataset

----
## What is ICLabel?
ICLabel is a project aimed at advancing automated electroencephalographic (EEG) independent component (IC) classification. It comprises three interlinked parts: the [ICLabel classifier](https://github.com/lucapton/ICLabel), the [ICLabel website](https://iclabel.ucsd.edu/tutorial), and this dataset. The website crowdsources labels for the dataset, and those labels are in turn used to train the classifier.

See the accompanying publication [Coming Soon].


----
## What is this dataset?
The ICLabel dataset contains an unlabeled training dataset, several collections of labels for a small subset of the training dataset, and a test dataset of 130 ICs where each IC was labeled by 6 experts. In total it comprises features from hundreds of thousands of unique EEG ICs (millions if you count similar ICs from different processing stages of the same datasets). Roughly 8000 of those have labels, though the number actually usable is typically closer to 6000, depending on which features are being used. The features included are:
* Scalp topography images (32x32 pixels, flattened to 740 elements after removing white-space)
* Power spectral densities (1-100 Hz)
* Autocorrelation functions (1 second)
* Equivalent current dipole fits (1 and 2 dipoles)
* Hand-crafted features (some new and some from previously published classifiers)

The original time series data are not available. All that can be provided is included in this repository, as-is. I realize having the original time series would make this dataset much more versatile, but unfortunately that's not possible.

----
## Usage
1. Instantiate the class, passing any desired options.
2. Download and load the dataset.

Example:

    icl = ICLabelDataset()
    icl.download_trainset_features()
    icldata = icl.load_semi_supervised()

--------------------------------------------------------------------------------

/icldata.py:
--------------------------------------------------------------------------------
from time import (time, gmtime, strftime)
import h5py
import os
from shutil import rmtree
from os.path import isdir, isfile, join, basename
import cPickle as pkl
import sqlite3
from collections import OrderedDict
from copy import copy

import numpy as np
from sklearn.decomposition import PCA
import joblib
from matplotlib import pyplot as plt
import webbrowser as wb
import requests
import tqdm


class ICLabelDataset:

    """

    This class provides an easy interface to downloading, loading, organizing, and processing the ICLabel dataset.
    The ICLabel dataset is intended for training and validating electroencephalographic (EEG) independent component
    (IC) classifiers.

    It contains an unlabeled training dataset, several collections of labels for a small subset of the training
    dataset, and a test dataset of 130 ICs where each IC was labeled by 6 experts.

    Features included:
    * Scalp topography images (32x32 pixels, flattened to 740 elements after removing white-space)
    * Power spectral densities (1-100 Hz)
    * Autocorrelation functions (1 second)
    * Equivalent current dipole fits (1 and 2 dipoles)
    * Hand-crafted features (some new and some from previously published classifiers)

    :Example:

        icl = ICLabelDataset()
        icldata = icl.load_semi_supervised()

    """

    def __init__(self,
                 features='all',
                 label_type='all',
                 datapath='',
                 n_test_datasets=50,
                 n_val_ics=200,
                 transform='none',
                 unique=True,
                 do_pca=False,
                 combine_output=False,
                 seed=np.random.randint(0, int(1e5))):
        """
        Initialize an ICLabelDataset object.
        :param features: The types of features to return.
        :param label_type: Which ICLabels to use.
        :param datapath: Where the dataset and cache are stored.
        :param n_test_datasets: How many unlabeled datasets to include in the test set.
        :param n_val_ics: How many labeled components to transfer to the validation set.
        :param transform: The inverse log-ratio transform to use for labels and their covariances.
        :param unique: Whether to keep only one IC from each set of duplicate scalp topographies. Non-unique is
            not implemented.
        :param do_pca: Whether to reduce the scalp topography features with whitened PCA.
        :param combine_output: Determines whether output features are dictionaries or an array of combined features.
        :param seed: The seed for the pseudorandom shuffle of data points.
        :return: Initialized ICLabelDataset object.
        """
        # data parameters
        self.datapath = datapath
        self.features = features
        self.n_test_datasets = n_test_datasets
        self.n_val_ics = n_val_ics
        self.transform = transform
        self.unique = unique
        if not self.unique:
            raise NotImplementedError
        self.do_pca = do_pca
        self.combine_output = combine_output
        self.label_type = label_type
        assert label_type in ('all', 'luca', 'database')
        self.seed = seed
        self.psd_mean = None
        self.psd_mean_var = None
        self.psd_mean_kurt = None
        self.psd_limits = None
        self.psd_var_limits = None
        self.psd_kurt_limits = None
        self.pscorr_mean = None
        self.pscorr_std = None
        self.pscorr_limits = None
        self.psd_freqs = 100

        # training feature-sets
        self.train_feature_indices = OrderedDict([
            ('ids', np.arange(2)),
            ('topo', np.arange(2, 742)),
            ('handcrafted', np.arange(742, 760)),  # one lost due to removal in load_data
            ('dipole', np.arange(760, 780)),
            ('psd', np.arange(780, 880)),
            ('psd_var', np.arange(880, 980)),
            ('psd_kurt', np.arange(980, 1080)),
            ('autocorr', np.arange(1080, 1180)),
        ])
        self.test_feature_indices = OrderedDict([
            ('ids', np.arange(3)),
            ('topo', np.arange(3, 743)),
            ('handcrafted', np.arange(743, 761)),  # one lost due to removal in load_data
            ('dipole', np.arange(761, 781)),
            ('psd', np.arange(781, 881)),
            ('psd_var', np.arange(881, 981)),
            ('psd_kurt', np.arange(981, 1081)),
            ('autocorr', np.arange(1081, 1181)),
        ])

        # reorganize features
        if self.features == 'all' or 'all' in self.features:
            self.features = self.train_feature_indices.keys()
        if isinstance(self.features, str):
            self.features = [self.features]
        if 'ids' not in self.features:
            self.features = ['ids'] + self.features

        # visualization parameters
        self.topo_ind = np.array([
            43,
            44,
128 | 45, 129 | 46, 130 | 47, 131 | 48, 132 | 49, 133 | 50, 134 | 51, 135 | 52, 136 | 72, 137 | 73, 138 | 74, 139 | 75, 140 | 76, 141 | 77, 142 | 78, 143 | 79, 144 | 80, 145 | 81, 146 | 82, 147 | 83, 148 | 84, 149 | 85, 150 | 86, 151 | 87, 152 | 103, 153 | 104, 154 | 105, 155 | 106, 156 | 107, 157 | 108, 158 | 109, 159 | 110, 160 | 111, 161 | 112, 162 | 113, 163 | 114, 164 | 115, 165 | 116, 166 | 117, 167 | 118, 168 | 119, 169 | 120, 170 | 134, 171 | 135, 172 | 136, 173 | 137, 174 | 138, 175 | 139, 176 | 140, 177 | 141, 178 | 142, 179 | 143, 180 | 144, 181 | 145, 182 | 146, 183 | 147, 184 | 148, 185 | 149, 186 | 150, 187 | 151, 188 | 152, 189 | 153, 190 | 165, 191 | 166, 192 | 167, 193 | 168, 194 | 169, 195 | 170, 196 | 171, 197 | 172, 198 | 173, 199 | 174, 200 | 175, 201 | 176, 202 | 177, 203 | 178, 204 | 179, 205 | 180, 206 | 181, 207 | 182, 208 | 183, 209 | 184, 210 | 185, 211 | 186, 212 | 196, 213 | 197, 214 | 198, 215 | 199, 216 | 200, 217 | 201, 218 | 202, 219 | 203, 220 | 204, 221 | 205, 222 | 206, 223 | 207, 224 | 208, 225 | 209, 226 | 210, 227 | 211, 228 | 212, 229 | 213, 230 | 214, 231 | 215, 232 | 216, 233 | 217, 234 | 218, 235 | 219, 236 | 227, 237 | 228, 238 | 229, 239 | 230, 240 | 231, 241 | 232, 242 | 233, 243 | 234, 244 | 235, 245 | 236, 246 | 237, 247 | 238, 248 | 239, 249 | 240, 250 | 241, 251 | 242, 252 | 243, 253 | 244, 254 | 245, 255 | 246, 256 | 247, 257 | 248, 258 | 249, 259 | 250, 260 | 251, 261 | 252, 262 | 258, 263 | 259, 264 | 260, 265 | 261, 266 | 262, 267 | 263, 268 | 264, 269 | 265, 270 | 266, 271 | 267, 272 | 268, 273 | 269, 274 | 270, 275 | 271, 276 | 272, 277 | 273, 278 | 274, 279 | 275, 280 | 276, 281 | 277, 282 | 278, 283 | 279, 284 | 280, 285 | 281, 286 | 282, 287 | 283, 288 | 284, 289 | 285, 290 | 290, 291 | 291, 292 | 292, 293 | 293, 294 | 294, 295 | 295, 296 | 296, 297 | 297, 298 | 298, 299 | 299, 300 | 300, 301 | 301, 302 | 302, 303 | 303, 304 | 304, 305 | 305, 306 | 306, 307 | 307, 308 | 308, 309 | 309, 310 | 310, 311 | 311, 312 | 312, 313 | 313, 314 | 314, 315 | 315, 316 | 316, 317 | 317, 318 | 322, 319 | 323, 320 | 324, 321 | 325, 322 | 326, 323 | 327, 324 | 328, 325 | 329, 326 | 330, 327 | 331, 328 | 332, 329 | 333, 330 | 334, 331 | 335, 332 | 336, 333 | 337, 334 | 338, 335 | 339, 336 | 340, 337 | 341, 338 | 342, 339 | 343, 340 | 344, 341 | 345, 342 | 346, 343 | 347, 344 | 348, 345 | 349, 346 | 353, 347 | 354, 348 | 355, 349 | 356, 350 | 357, 351 | 358, 352 | 359, 353 | 360, 354 | 361, 355 | 362, 356 | 363, 357 | 364, 358 | 365, 359 | 366, 360 | 367, 361 | 368, 362 | 369, 363 | 370, 364 | 371, 365 | 372, 366 | 373, 367 | 374, 368 | 375, 369 | 376, 370 | 377, 371 | 378, 372 | 379, 373 | 380, 374 | 381, 375 | 382, 376 | 385, 377 | 386, 378 | 387, 379 | 388, 380 | 389, 381 | 390, 382 | 391, 383 | 392, 384 | 393, 385 | 394, 386 | 395, 387 | 396, 388 | 397, 389 | 398, 390 | 399, 391 | 400, 392 | 401, 393 | 402, 394 | 403, 395 | 404, 396 | 405, 397 | 406, 398 | 407, 399 | 408, 400 | 409, 401 | 410, 402 | 411, 403 | 412, 404 | 413, 405 | 414, 406 | 417, 407 | 418, 408 | 419, 409 | 420, 410 | 421, 411 | 422, 412 | 423, 413 | 424, 414 | 425, 415 | 426, 416 | 427, 417 | 428, 418 | 429, 419 | 430, 420 | 431, 421 | 432, 422 | 433, 423 | 434, 424 | 435, 425 | 436, 426 | 437, 427 | 438, 428 | 439, 429 | 440, 430 | 441, 431 | 442, 432 | 443, 433 | 444, 434 | 445, 435 | 446, 436 | 449, 437 | 450, 438 | 451, 439 | 452, 440 | 453, 441 | 454, 442 | 455, 443 | 456, 444 | 457, 445 | 458, 446 | 459, 447 | 460, 448 | 461, 449 | 462, 450 | 463, 451 | 464, 452 | 465, 453 
| 466, 454 | 467, 455 | 468, 456 | 469, 457 | 470, 458 | 471, 459 | 472, 460 | 473, 461 | 474, 462 | 475, 463 | 476, 464 | 477, 465 | 478, 466 | 481, 467 | 482, 468 | 483, 469 | 484, 470 | 485, 471 | 486, 472 | 487, 473 | 488, 474 | 489, 475 | 490, 476 | 491, 477 | 492, 478 | 493, 479 | 494, 480 | 495, 481 | 496, 482 | 497, 483 | 498, 484 | 499, 485 | 500, 486 | 501, 487 | 502, 488 | 503, 489 | 504, 490 | 505, 491 | 506, 492 | 507, 493 | 508, 494 | 509, 495 | 510, 496 | 513, 497 | 514, 498 | 515, 499 | 516, 500 | 517, 501 | 518, 502 | 519, 503 | 520, 504 | 521, 505 | 522, 506 | 523, 507 | 524, 508 | 525, 509 | 526, 510 | 527, 511 | 528, 512 | 529, 513 | 530, 514 | 531, 515 | 532, 516 | 533, 517 | 534, 518 | 535, 519 | 536, 520 | 537, 521 | 538, 522 | 539, 523 | 540, 524 | 541, 525 | 542, 526 | 545, 527 | 546, 528 | 547, 529 | 548, 530 | 549, 531 | 550, 532 | 551, 533 | 552, 534 | 553, 535 | 554, 536 | 555, 537 | 556, 538 | 557, 539 | 558, 540 | 559, 541 | 560, 542 | 561, 543 | 562, 544 | 563, 545 | 564, 546 | 565, 547 | 566, 548 | 567, 549 | 568, 550 | 569, 551 | 570, 552 | 571, 553 | 572, 554 | 573, 555 | 574, 556 | 577, 557 | 578, 558 | 579, 559 | 580, 560 | 581, 561 | 582, 562 | 583, 563 | 584, 564 | 585, 565 | 586, 566 | 587, 567 | 588, 568 | 589, 569 | 590, 570 | 591, 571 | 592, 572 | 593, 573 | 594, 574 | 595, 575 | 596, 576 | 597, 577 | 598, 578 | 599, 579 | 600, 580 | 601, 581 | 602, 582 | 603, 583 | 604, 584 | 605, 585 | 606, 586 | 609, 587 | 610, 588 | 611, 589 | 612, 590 | 613, 591 | 614, 592 | 615, 593 | 616, 594 | 617, 595 | 618, 596 | 619, 597 | 620, 598 | 621, 599 | 622, 600 | 623, 601 | 624, 602 | 625, 603 | 626, 604 | 627, 605 | 628, 606 | 629, 607 | 630, 608 | 631, 609 | 632, 610 | 633, 611 | 634, 612 | 635, 613 | 636, 614 | 637, 615 | 638, 616 | 641, 617 | 642, 618 | 643, 619 | 644, 620 | 645, 621 | 646, 622 | 647, 623 | 648, 624 | 649, 625 | 650, 626 | 651, 627 | 652, 628 | 653, 629 | 654, 630 | 655, 631 | 656, 632 | 657, 633 | 658, 634 | 659, 635 | 660, 636 | 661, 637 | 662, 638 | 663, 639 | 664, 640 | 665, 641 | 666, 642 | 667, 643 | 668, 644 | 669, 645 | 670, 646 | 674, 647 | 675, 648 | 676, 649 | 677, 650 | 678, 651 | 679, 652 | 680, 653 | 681, 654 | 682, 655 | 683, 656 | 684, 657 | 685, 658 | 686, 659 | 687, 660 | 688, 661 | 689, 662 | 690, 663 | 691, 664 | 692, 665 | 693, 666 | 694, 667 | 695, 668 | 696, 669 | 697, 670 | 698, 671 | 699, 672 | 700, 673 | 701, 674 | 706, 675 | 707, 676 | 708, 677 | 709, 678 | 710, 679 | 711, 680 | 712, 681 | 713, 682 | 714, 683 | 715, 684 | 716, 685 | 717, 686 | 718, 687 | 719, 688 | 720, 689 | 721, 690 | 722, 691 | 723, 692 | 724, 693 | 725, 694 | 726, 695 | 727, 696 | 728, 697 | 729, 698 | 730, 699 | 731, 700 | 732, 701 | 733, 702 | 738, 703 | 739, 704 | 740, 705 | 741, 706 | 742, 707 | 743, 708 | 744, 709 | 745, 710 | 746, 711 | 747, 712 | 748, 713 | 749, 714 | 750, 715 | 751, 716 | 752, 717 | 753, 718 | 754, 719 | 755, 720 | 756, 721 | 757, 722 | 758, 723 | 759, 724 | 760, 725 | 761, 726 | 762, 727 | 763, 728 | 764, 729 | 765, 730 | 771, 731 | 772, 732 | 773, 733 | 774, 734 | 775, 735 | 776, 736 | 777, 737 | 778, 738 | 779, 739 | 780, 740 | 781, 741 | 782, 742 | 783, 743 | 784, 744 | 785, 745 | 786, 746 | 787, 747 | 788, 748 | 789, 749 | 790, 750 | 791, 751 | 792, 752 | 793, 753 | 794, 754 | 795, 755 | 796, 756 | 804, 757 | 805, 758 | 806, 759 | 807, 760 | 808, 761 | 809, 762 | 810, 763 | 811, 764 | 812, 765 | 813, 766 | 814, 767 | 815, 768 | 816, 769 | 817, 770 | 818, 771 | 819, 772 | 820, 773 | 821, 774 | 822, 775 | 823, 776 | 
            824,
            825,
            826,
            827,
            837,
            838,
            839,
            840,
            841,
            842,
            843,
            844,
            845,
            846,
            847,
            848,
            849,
            850,
            851,
            852,
            853,
            854,
            855,
            856,
            857,
            858,
            870,
            871,
            872,
            873,
            874,
            875,
            876,
            877,
            878,
            879,
            880,
            881,
            882,
            883,
            884,
            885,
            886,
            887,
            888,
            889,
            903,
            904,
            905,
            906,
            907,
            908,
            909,
            910,
            911,
            912,
            913,
            914,
            915,
            916,
            917,
            918,
            919,
            920,
            936,
            937,
            938,
            939,
            940,
            941,
            942,
            943,
            944,
            945,
            946,
            947,
            948,
            949,
            950,
            951,
            971,
            972,
            973,
            974,
            975,
            976,
            977,
            978,
            979,
            980,
        ])
        self.psd_ind = np.arange(1, 101)
        self.max_grid_plot = 144
        self.base_url_image = 'https://labeling.ucsd.edu/images/'

        # data url
        self.base_url_download = 'https://labeling.ucsd.edu/download/'
        self.feature_train_zip_url = self.base_url_download + 'features.zip'
        self.feature_train_zip_parts_url = self.base_url_download + 'features{:02d}.zip'
        self.num_feature_train_files = 25
        self.feature_train_urls = [
            self.base_url_download + 'features_0D1D2D.mat',
            self.base_url_download + 'features_PSD_med_var_kurt.mat',
            self.base_url_download + 'features_AutoCorr.mat',
            self.base_url_download + 'features_ICAChanlocs.mat',
            self.base_url_download + 'features_MI.mat',
        ]
        self.label_train_urls = [
            self.base_url_download + 'ICLabels_expert.pkl',
            self.base_url_download + 'ICLabels_onlyluca.pkl',
        ]
        self.feature_test_url = self.base_url_download + 'features_testset_full.mat'
        self.label_test_url = self.base_url_download + 'ICLabels_test.pkl'
        self.db_url = self.base_url_download + 'anonymized_database.sqlite'
        self.cls_url = self.base_url_download + 'other_classifiers.mat'

    # util

    @staticmethod
    def __load_matlab_cellstr(f, var_name=''):
        var = []
        if var_name:
            for column in f[var_name]:
                row_data = []
                for row_number in range(len(column)):
                    row_data.append(''.join(map(unichr, f[column[row_number]][:])))
                var.append(row_data)
        return [str(x)[3:-2] for x in var]

    @staticmethod
    def __match_indices(*indices):
        """ Match sets of multidimensional ids/indices when there is a 1-1 relationship """

        # find matching indices
        index = np.concatenate(indices)  # array of values
        _, duplicates, counts = np.unique(index, return_inverse=True, return_counts=True, axis=0)
        duplicates = np.split(duplicates, np.cumsum([x.shape[0] for x in indices[:-1]]), 0)  # list of vectors of ints
        sufficient_counts = np.where(counts == len(indices))[0]  # vector of ints
        matching_indices = [np.where(np.in1d(x, sufficient_counts))[0] for x in duplicates]  # list of vectors of ints
        indices = [y[x] for x, y in zip(matching_indices, indices)]  # list of arrays of values

        # organize to match first index array
        try:
            sort_inds = [np.lexsort(np.fliplr(x).T) for x in indices]
        except ValueError:
            sort_inds = [np.argsort(x) for x in indices]
        out = np.array([x[y[sort_inds[0]]] for
                        x, y in zip(matching_indices, sort_inds)])

        return out

    # data access

    def load_data(self):
        """
        Load the ICL dataset in an unprocessed form.
        Follows the settings provided during initialization.
        :return: Dictionary of unprocessed but matched feature-sets and labels.
        """
        start = time()

        # organize info
        if self.transform in (None, 'none'):
            if self.label_type == 'all':
                file_name = 'ICLabels_expert.pkl'
            elif self.label_type == 'luca':
                file_name = 'ICLabels_onlyluca.pkl'
        processed_file_name = 'processed_dataset'
        if self.unique:
            processed_file_name += '_unique'
        if self.label_type == 'all':
            processed_file_name += '_all'
            self.check_for_download('train_labels')
        elif self.label_type == 'luca':
            processed_file_name += '_luca'
            self.check_for_download('train_labels')
        elif self.label_type == 'database':
            processed_file_name += '_database'
            self.check_for_download('database')
        processed_file_name += '.pkl'

        # load processed data file if it exists
        if isfile(join(self.datapath, 'cache', processed_file_name)):
            dataset = joblib.load(join(self.datapath, 'cache', processed_file_name))

        # if not, create it
        else:
            # load features
            features = []
            feature_labels = []
            print('Loading full dataset...')

            self.check_for_download('train_features')
            # topo maps, old psd, dipole, and handcrafted
            with h5py.File(join(self.datapath, 'features', 'features_0D1D2D.mat'), 'r') as f:
                print('Loading 0D1D2D features...')
                features.append(np.asarray(f['features']).T)
                feature_labels.append(self.__load_matlab_cellstr(f, 'labels'))
            # new psd
            with h5py.File(join(self.datapath, 'features', 'features_PSD_med_var_kurt.mat'), 'r') as f:
                print('Loading PSD features...')
                features.append(list())
                for element in f['features_out'][0]:
                    data = np.array(f[element]).T
                    # if no data, skip
                    if data.ndim == 1 or data.dtype != np.float64:
                        continue
                    nyquist = (data.shape[1] - 2) / 3
                    nfreq = 100
                    # if more than nfreqs, remove extra
                    if nyquist > nfreq:
                        data = data[:, np.concatenate((range(2 + nfreq),
                                                       range(2 + nyquist, 2 + nyquist + nfreq),
                                                       range(2 + 2*nyquist, 2 + 2*nyquist + nfreq)))]
                    # if less than nfreqs, repeat last frequency value
                    elif nyquist < nfreq:
                        data = data[:, np.concatenate((range(2 + nyquist),
                                                       np.repeat(1 + nyquist, nfreq - nyquist),
                                                       range(2 + nyquist, 2 + 2*nyquist),
                                                       np.repeat(1 + 2*nyquist, nfreq - nyquist),
                                                       range(2 + 2*nyquist, 2 + 3*nyquist),
                                                       np.repeat(1 + 3*nyquist, nfreq - nyquist))
                                                      ).astype(int)]

                    features[-1].append(data)
                features[-1] = np.concatenate(features[-1], axis=0)
                feature_labels.append(['ID_set', 'ID_ic'] + ['psd_median']*nfreq
                                      + ['psd_var']*nfreq + ['psd_kurt']*nfreq)
            # autocorrelation
            with h5py.File(join(self.datapath, 'features', 'features_AutoCorr.mat'), 'r') as f:
                print('Loading AutoCorr features...')
                features.append(list())
                for element in f['features_out'][0]:
                    data = np.array(f[element]).T
                    if data.size > 2 and data.shape[1] == 102 and not len(data.dtype):
                        features[-1].append(data)
                features[-1] = np.concatenate(features[-1], axis=0)
                feature_labels.append(self.__load_matlab_cellstr(f,
                                                                 'feature_labels')[:2] + ['Autocorr'] * 100)

            # find topomap duplicates
            print('Finding topo duplicates...')
            _, duplicate_order = np.unique(features[0][:, 2:742].astype(np.float32), return_inverse=True, axis=0)
            do_sortind = np.argsort(duplicate_order)
            do_sorted = duplicate_order[do_sortind]
            do_indices = np.where(np.diff(np.concatenate(([-1], do_sorted))))[0]
            group2indices = [do_sortind[do_indices[x]:do_indices[x + 1]] for x in range(0, duplicate_order.max())]
            del _

            # load labels
            if self.label_type == 'database':
                # load data from database
                conn = sqlite3.connect(join(self.datapath, 'labels', 'database.sqlite'))
                c = conn.cursor()
                dblabels = c.execute('SELECT * FROM labels '
                                     'INNER JOIN images ON labels.image_id = images.id '
                                     'WHERE user_id IN '
                                     '(SELECT user_id FROM labels '
                                     'GROUP BY user_id '
                                     'HAVING COUNT(*) >= 30)'
                                     ).fetchall()
                conn.close()
                # reformat as list of ndarrays
                dblabels = [(x[1], np.array(x[15:17]), np.array(x[3:11])) for x in dblabels]
                dblabels = [np.stack(x) for x in zip(*dblabels)]
                # organize labels by image
                udb = np.unique(dblabels[1], return_inverse=True, axis=0)
                dblabels = [(dblabels[0][y], dblabels[1][y][0], dblabels[2][y])
                            for y in (udb[1] == x for x in range(len(udb[0])))]
                label_index = np.stack((x[1] for x in dblabels))

            elif self.label_type == 'luca':
                # load data from database
                conn = sqlite3.connect(join(self.datapath, 'labels', 'database.sqlite'))
                c = conn.cursor()
                dblabelsluca = c.execute('SELECT * FROM labels '
                                         'INNER JOIN images ON labels.image_id = images.id '
                                         'WHERE user_id = 1').fetchall()
                conn.close()
                # remove low-confidence labels
                dblabelsluca = [x for x in dblabelsluca if x[10] == 0]
                # reformat as ndarray
                labels = np.array([x[3:10] for x in dblabelsluca]).astype(np.float32)
                labels /= labels.sum(1, keepdims=True)
                labels = [labels]
                label_index = np.array([x[15:17] for x in dblabelsluca])
                transforms = ['none']

            else:
                # load labels from files
                with open(join(self.datapath, 'labels', file_name), 'rb') as f:
                    print('Loading labels...')
                    data = pkl.load(f)
                if 'transform' in data.keys():
                    transforms = data['transform']
                else:
                    transforms = ['none']
                labels = data['labels']
                if isinstance(labels, np.ndarray):
                    labels = [labels]
                if 'labels_cov' in data.keys():
                    label_cov = data['labels_cov']
                label_index = np.stack((data['instance_set_numbers'], data['instance_ic_numbers'])).T
                del data

            # match components and labels
            print('Matching components and labels...')
            temp = self.__match_indices(label_index.astype(np.int), features[0][:, :2].astype(np.int))
            label2component = dict(zip(*temp))
            del temp
            # match feature-sets
            print('Matching features...')
            feature_inds = self.__match_indices(*[x[:, :2].astype(np.int) for x in features])

            # check which labels are not kept
            print('Rearranging components and labels...')
            kept_labels = [x for x, y in label2component.iteritems() if y in feature_inds[0]]
            dropped_labels = [x for x, y in label2component.iteritems() if y not in feature_inds[0]]

            # for each label, pick a new component that is kept (if any)
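            # Several copies of a recording can share one scalp map, so if the exact component a label
            # pointed to was dropped during feature matching, the label is reassigned to a kept duplicate
            # (preferring the copy from the dataset with the most data points), unless that index would
            # collide with another label's component.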
1094 | ind_n_data_points = [x for x, y in enumerate(feature_labels[0]) if y == 'number of data points'][0] 1095 | for ind in dropped_labels: 1096 | group = duplicate_order[label2component[ind]] 1097 | candidate_components = np.intersect1d(group2indices[group], feature_inds[0]) 1098 | # if more than one choice, pick the one from the dataset with the most samples unless one from this 1099 | # group has already been found 1100 | if len(candidate_components) >= 1: 1101 | if len(candidate_components) == 1: 1102 | new_index = features[0][candidate_components, :2] 1103 | else: 1104 | new_index = features[0][candidate_components[features[0][candidate_components, 1105 | ind_n_data_points].argmax()], :2] 1106 | if not (new_index == label_index[dropped_labels]).all(1).any() \ 1107 | and not any([(x == label_index[kept_labels]).all(1).any() 1108 | for x in features[0][candidate_components, :2]]): 1109 | label_index[ind] = new_index 1110 | del label2component, kept_labels, dropped_labels, duplicate_order 1111 | 1112 | # feature labels (change with features) 1113 | psd_lims = np.where(np.char.startswith(feature_labels[0], 'psd'))[0][[0, -1]] 1114 | feature_labels = np.concatenate((feature_labels[0][:psd_lims[0]], 1115 | feature_labels[0][psd_lims[1] + 1:], 1116 | feature_labels[1][2:], 1117 | feature_labels[2][2:])) 1118 | 1119 | # combine features, keeping only components with all features 1120 | print('Combining feature-sets...') 1121 | 1122 | def index_features(data, new_index): 1123 | return np.concatenate((data[0][feature_inds[0][new_index], :psd_lims[0]].astype(np.float32), 1124 | data[0][feature_inds[0][new_index], psd_lims[1] + 1:].astype(np.float32), 1125 | data[1][feature_inds[1][new_index], 2:].astype(np.float32), 1126 | data[2][feature_inds[2][new_index], 2:].astype(np.float32)), 1127 | axis=1) 1128 | 1129 | # rematch with labels 1130 | print('Rematching components and labels...') 1131 | ind_labeled_labels, ind_labeled_features = self.__match_indices( 1132 | label_index.astype(np.int),features[0][feature_inds[0], :2].astype(np.int)) 1133 | del label_index 1134 | 1135 | # find topomap duplicates 1136 | _, duplicate_order = np.unique(features[0][feature_inds[0], 2:742].astype(np.float32), return_inverse=True, 1137 | axis=0) 1138 | do_sortind = np.argsort(duplicate_order) 1139 | do_sorted = duplicate_order[do_sortind] 1140 | do_indices = np.where(np.diff(np.concatenate(([-1], do_sorted))))[0] 1141 | group2indices = [do_sortind[do_indices[x]:do_indices[x + 1]] for x in range(0, duplicate_order.max())] 1142 | 1143 | # aggregate data 1144 | dataset = dict() 1145 | try: 1146 | dataset['transform'] = transforms 1147 | except UnboundLocalError: 1148 | pass 1149 | if self.label_type == 'database': 1150 | dataset['labeled_labels'] = [dblabels[x] for x in np.where(ind_labeled_labels)[0]] 1151 | else: 1152 | dataset['labeled_labels'] = [x[ind_labeled_labels, :] for x in labels] 1153 | if 'label_cov' in locals(): 1154 | dataset['labeled_label_covariances'] = [x[ind_labeled_labels, :].astype(np.float32) 1155 | for x in label_cov] 1156 | dataset['labeled_features'] = index_features(features, ind_labeled_features) 1157 | 1158 | # find equivalent datasets with most samples 1159 | unlabeled_groups = [x for it, x in enumerate(group2indices) 1160 | if not np.intersect1d(x, ind_labeled_features).size] 1161 | ndata = features[0][feature_inds[0]][:, ind_n_data_points] 1162 | ind_unique_unlabled = [x[ndata[x].argmax()] for x in unlabeled_groups] 1163 | dataset['unlabeled_features'] = index_features(features, 
                                                             ind_unique_unlabled)

            # clean up the workspace
            del features, group2indices
            try:
                del labels
            except NameError:
                del dblabels
            if 'label_cov' in locals():
                del label_cov

            # remove inf columns
            print('Cleaning data of infs...')
            inf_col = [ind for ind, x in enumerate(feature_labels) if x == 'SASICA snr'][0]
            feature_labels = np.delete(feature_labels, inf_col)
            dataset['unlabeled_features'] = np.delete(dataset['unlabeled_features'], inf_col, axis=1)
            dataset['labeled_features'] = np.delete(dataset['labeled_features'], inf_col, axis=1)

            # remove rows containing NaNs or infs
            print('Cleaning data of nans...')
            # unlabeled
            unlabeled_not_nan_inf_index = np.logical_not(
                np.logical_or(np.isnan(dataset['unlabeled_features']).any(axis=1),
                              np.isinf(dataset['unlabeled_features']).any(axis=1)))
            dataset['unlabeled_features'] = \
                dataset['unlabeled_features'][unlabeled_not_nan_inf_index, :]
            # labeled
            labeled_not_nan_inf_index = np.logical_not(np.logical_or(np.isnan(dataset['labeled_features']).any(axis=1),
                                                                     np.isinf(dataset['labeled_features']).any(axis=1)))
            dataset['labeled_features'] = dataset['labeled_features'][labeled_not_nan_inf_index, :]
            if self.label_type == 'database':
                dataset['labeled_labels'] = [dataset['labeled_labels'][x]
                                             for x in np.where(labeled_not_nan_inf_index)[0]]
            else:
                dataset['labeled_labels'] = [x[labeled_not_nan_inf_index, :] for x in dataset['labeled_labels']]
                if 'labeled_label_covariances' in dataset.keys():
                    dataset['labeled_label_covariances'] = [x[labeled_not_nan_inf_index, :, :]
                                                            for x in dataset['labeled_label_covariances']]
            if not self.unique:
                dataset['unlabeled_duplicates'] = dataset['unlabeled_duplicates'][unlabeled_not_nan_inf_index]
                dataset['labeled_duplicates'] = dataset['labeled_duplicates'][labeled_not_nan_inf_index]

            # save feature labels (names, e.g. psd)
            dataset['feature_labels'] = feature_labels

            # save the results
            print('Saving aggregated dataset...')
            joblib.dump(dataset, join(self.datapath, 'cache', processed_file_name), 0)

        # print time
        total = time() - start
        print('Time to load: ' + strftime("%H:%M:%S", gmtime(total)) +
              ':' + np.mod(total, 1).astype(str)[2:5] + '\t(HH:MM:SS:sss)')

        return dataset

    def load_semi_supervised(self):
        """
        Load the ICL dataset where only a fraction of data points are labeled.
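
        A minimal usage sketch (assuming the training features and labels have already been downloaded):

            icl = ICLabelDataset()
            train_u, train_l, test, val, names = icl.load_semi_supervised()
            x_u, y_u = train_u  # y_u is None since this split is unlabeled
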
        Follows the settings provided during initialization.
        :return: (train set unlabeled, train set labeled, sample test set (unlabeled), validation set (labeled),
            output labels)
        """

        rng = np.random.RandomState(seed=self.seed)
        start = time()

        # get data
        icl = self.load_data()

        # copy full dataset
        icl['unlabeled_features'] = \
            OrderedDict([(key, icl['unlabeled_features'][:, ind]) for key, ind
                         in self.train_feature_indices.iteritems() if key in self.features])
        icl['labeled_features'] = \
            OrderedDict([(key, icl['labeled_features'][:, ind]) for key, ind
                         in self.train_feature_indices.iteritems() if key in self.features])

        # set ids to int
        icl['unlabeled_features']['ids'] = icl['unlabeled_features']['ids'].astype(int)
        icl['labeled_features']['ids'] = icl['labeled_features']['ids'].astype(int)

        # decide how to split into train / validation / test
        # validation set of random labeled components for overfitting / convergence estimation
        try:
            valid_ind = rng.choice(icl['labeled_features']['ids'].shape[0], size=self.n_val_ics, replace=False)
        except ValueError:
            # fewer labeled components than requested; fall back to sampling with replacement
            valid_ind = rng.choice(icl['labeled_features']['ids'].shape[0], size=self.n_val_ics, replace=True)
        # random unlabeled datasets for manual analysis
        test_datasets = rng.choice(np.unique(icl['unlabeled_features']['ids'][:, 0]),
                                   size=self.n_test_datasets, replace=False)
        test_ind = np.where(np.array([x == icl['unlabeled_features']['ids'][:, 0] for x in test_datasets]).any(0))[0]

        # normalize topo features
        if 'topo' in self.features:
            print('Normalizing topo features...')
            icl['unlabeled_features']['topo'], pca = self.normalize_topo_features(icl['unlabeled_features']['topo'])
            icl['labeled_features']['topo'] = self.normalize_topo_features(icl['labeled_features']['topo'], pca)[0]

        # normalize psd features
        if 'psd' in self.features:
            print('Normalizing psd features...')
            icl['unlabeled_features']['psd'] = self.normalize_psd_features(icl['unlabeled_features']['psd'])
            icl['labeled_features']['psd'] = self.normalize_psd_features(icl['labeled_features']['psd'])

        # normalize psd_var features
        if 'psd_var' in self.features:
            print('Normalizing psd_var features...')
            icl['unlabeled_features']['psd_var'] = self.normalize_psd_features(icl['unlabeled_features']['psd_var'])
            icl['labeled_features']['psd_var'] = self.normalize_psd_features(icl['labeled_features']['psd_var'])

        # normalize psd_kurt features
        if 'psd_kurt' in self.features:
            print('Normalizing psd_kurt features...')
            icl['unlabeled_features']['psd_kurt'] = self.normalize_psd_features(icl['unlabeled_features']['psd_kurt'])
            icl['labeled_features']['psd_kurt'] = self.normalize_psd_features(icl['labeled_features']['psd_kurt'])

        # normalize autocorr features
        if 'autocorr' in self.features:
            print('Normalizing autocorr features...')
            icl['unlabeled_features']['autocorr'] = self.normalize_autocorr_features(
                icl['unlabeled_features']['autocorr'])
            icl['labeled_features']['autocorr'] = self.normalize_autocorr_features(icl['labeled_features']['autocorr'])

        # normalize dipole features
        if 'dipole' in self.features:
            print('Normalizing dipole features...')
            icl['unlabeled_features']['dipole'] = \
                self.normalize_dipole_features(icl['unlabeled_features']['dipole'])
            icl['labeled_features']['dipole'] = self.normalize_dipole_features(icl['labeled_features']['dipole'])

        # normalize handcrafted features
        if 'handcrafted' in self.features:
            print('Normalizing hand-crafted features...')
            icl['unlabeled_features']['handcrafted'] = \
                self.normalize_handcrafted_features(icl['unlabeled_features']['handcrafted'],
                                                    icl['unlabeled_features']['ids'][:, 1])
            icl['labeled_features']['handcrafted'] = self.normalize_handcrafted_features(
                icl['labeled_features']['handcrafted'], icl['labeled_features']['ids'][:, 1])

        # normalize mi features
        if 'mi' in self.features:
            print('Normalizing mi features...')
            icl['unlabeled_features']['mi'] = self.normalize_mi_features(icl['unlabeled_features']['mi'])
            icl['labeled_features']['mi'] = self.normalize_mi_features(icl['labeled_features']['mi'])

        # recast labels
        if self.label_type == 'database':
            pass
        else:
            icl['labeled_labels'] = [x.astype(np.float32) for x in icl['labeled_labels']]
            if 'labeled_label_covariances' in icl.keys():
                icl['labeled_label_covariances'] = [x.astype(np.float32) for x in icl['labeled_label_covariances']]

        # separate data into train, validation, and test sets
        print('Splitting and shuffling data...')
        # unlabeled training set
        ind = rng.permutation(np.setdiff1d(range(icl['unlabeled_features']['ids'].shape[0]), test_ind))
        x_u = OrderedDict([(key, val[ind]) for key, val in icl['unlabeled_features'].iteritems()])
        y_u = None
        # labeled training set
        ind = rng.permutation(np.setdiff1d(range(icl['labeled_features']['ids'].shape[0]), valid_ind))
        x_l = OrderedDict([(key, val[ind]) for key, val in icl['labeled_features'].iteritems()])
        if self.label_type == 'database':
            y_l = [icl['labeled_labels'][x] for x in ind]
        else:
            y_l = [x[ind] for x in icl['labeled_labels']]
            if 'labeled_label_covariances' in icl.keys():
                c_l = [x[ind] for x in icl['labeled_label_covariances']]
        # validation set
        rng.shuffle(valid_ind)
        x_v = OrderedDict([(key, val[valid_ind]) for key, val in icl['labeled_features'].iteritems()])
        if self.label_type == 'database':
            y_v = [icl['labeled_labels'][x] for x in valid_ind]
        else:
            y_v = [x[valid_ind] for x in icl['labeled_labels']]
            if 'labeled_label_covariances' in icl.keys():
                c_v = [x[valid_ind] for x in icl['labeled_label_covariances']]
        # unlabeled test set
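        # (n_test_datasets whole datasets were held out above; their ICs have no labels and are
        # intended for manual analysis)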
        rng.shuffle(test_ind)
        x_t = OrderedDict([(key, val[test_ind]) for key, val in icl['unlabeled_features'].iteritems()])
        y_t = None

        train_u = (x_u, y_u)
        if 'labeled_label_covariances' in icl.keys():
            train_l = (x_l, y_l, c_l)
        else:
            train_l = (x_l, y_l)
        test = (x_t, y_t)
        if 'labeled_label_covariances' in icl.keys():
            val = (x_v, y_v, c_v)
        else:
            val = (x_v, y_v)

        # print time
        total = time() - start
        print('Time to load: ' + strftime("%H:%M:%S", gmtime(total)) +
              ':' + np.mod(total, 1).astype(str)[2:5] + '\t(HH:MM:SS:sss)')

        return train_u, train_l, test, val, \
            ('train_unlabeled', 'train_labeled', 'test', 'validation', 'labels')

    def load_test_data(self, process_features=True):
        """
        Load the ICL test dataset used in the publication.
        Follows the settings provided during initialization.
        :param process_features: Whether to preprocess/normalize features.
        :return: (features, labels, channel_features)
        """

        # check for files and download if missing
        self.check_for_download(('test_labels', 'test_features'))

        # load features
        with h5py.File(join(self.datapath, 'features', 'features_testset_full.mat'), 'r') as f:
            features = np.asarray(f['features']).T
            feature_labels = self.__load_matlab_cellstr(f, 'feature_label')
            channel_features = []
            for dataset in f['channel_features'].value.flatten():
                # expand
                dataset = f[dataset].value.flatten()
                # expand and format
                id = f[dataset[0]].value.flatten()
                chans = [''.join(map(unichr, f[x].value.flatten())) for x in f[dataset[1]].value.flatten()]
                icamat = f[dataset[2]].value.T
                # append
                channel_features.append([id, chans, icamat[:, :3], icamat[:, 3:]])

        # load labels
        with open(join(self.datapath, 'labels', 'ICLabels_test.pkl'), 'rb') as f:
            labels = pkl.load(f)

        # match features and labels
        _, _, ind = np.intersect1d(labels['instance_id'], labels['instance_number'], return_indices=True)
        label_id = np.stack((labels['instance_study_numbers'][ind],
                             labels['instance_set_numbers'][ind],
                             labels['instance_ic_numbers'][ind]), axis=1)
        feature_id = features[:, :3].astype(int)
        match = self.__match_indices(label_id, feature_id)
        features = features[match[1, :][match[0, :]], :]

        # remove inf columns
        print('Cleaning data of infs...')
        inf_col = [ind for ind, x in enumerate(feature_labels) if x == 'SASICA snr'][0]
        feature_labels = np.delete(feature_labels, inf_col)
        features = np.delete(features, inf_col, axis=1)

        # convert to ordered dict
        features = \
            OrderedDict([(key, features[:, ind]) for key, ind
                         in self.test_feature_indices.iteritems() if key in self.features])

        # process features
        if process_features:

            # normalize topo features
            if 'topo' in self.features:
                print('Normalizing topo features...')
                features['topo'] = self.normalize_topo_features(features['topo'])[0]

            # normalize psd features
            if 'psd' in self.features:
                print('Normalizing psd features...')
                features['psd'] = self.normalize_psd_features(features['psd'])

            # normalize psd_var features
            if 'psd_var' in self.features:
                print('Normalizing psd_var features...')
                features['psd_var'] = self.normalize_psd_features(features['psd_var'])

            # normalize psd_kurt features
            if 'psd_kurt' in self.features:
                print('Normalizing psd_kurt features...')
                features['psd_kurt'] = self.normalize_psd_features(features['psd_kurt'])

            # normalize autocorr features
            if 'autocorr' in self.features:
                print('Normalizing autocorr features...')
                features['autocorr'] = self.normalize_autocorr_features(features['autocorr'])

            # normalize dipole features
            if 'dipole' in self.features:
                print('Normalizing dipole features...')
                features['dipole'] = self.normalize_dipole_features(features['dipole'])

            # normalize handcrafted features
            if 'handcrafted' in self.features:
                print('Normalizing hand-crafted features...')
                features['handcrafted'] = self.normalize_handcrafted_features(features['handcrafted'],
                                                                              features['ids'][:, 1])

        return features, labels, channel_features

    def load_channel_features(self):
        # load features
        with h5py.File(join(self.datapath, 'features', 'features_ICAChanlocs.mat'), 'r') as f:
            ids, chans, xyz, icamats = [], [], [], []
            for dataset in f['features_out'].value.flatten():
                # expand
                dataset = f[dataset].value.flatten()
                if np.array_equal(dataset, np.zeros(2)):
                    continue
                # expand and format
                ids.append(f[dataset[0]].value.flatten())
                chans.append([''.join(map(unichr, f[x].value.flatten())) for x in f[dataset[1]].value.flatten()])
                icamat = f[dataset[2]].value.T
                xyz.append(icamat[:, :3])
                icamats.append(icamat[:, 3:])

        return ids, chans, xyz, icamats

    def load_classifications(self, n_cls, ids=None):
        """
        Load classifications of the ICLabel training set by several published and publicly available IC classifiers.
        Classifiers included are MARA, ADJUST, FASTER, IC_MARC, and EyeCatch. MARA and FASTER are only included in
        the 2-class case. ADJUST is also included in the 3-class case. IC_MARC and EyeCatch are included in all
        cases. Note that EyeCatch only has two classes (Eye and Not-Eye) but does not follow the pattern of label
        conflation used for the other classifiers as it has no Brain IC class.
        :param n_cls: How many IC classes to consider. Must be 2, 3, or 5.
        :param ids: If only a subset of ICs are desired, the relevant IC IDs may be passed here as an (n by 2) ndarray.
        :return: Dictionary of classifications separated by classifier.
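
        A minimal sketch of how the output might be used (which classifiers appear depends on n_cls, as noted above):

            icl = ICLabelDataset()
            cls = icl.load_classifications(n_cls=3)
            adjust_probs = cls['adjust']  # columns: (brain, eye, other)
        """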
1484 | """ 1485 | # check inputs 1486 | assert n_cls in (2, 3, 5), 'n_cls must be 2, 3, or 5' 1487 | 1488 | # load raw classifications 1489 | raw = self._load_classifications(ids) 1490 | 1491 | # format and limit to number of desired classes 1492 | # 2: brain, other 1493 | # 3: brain, eye, other 1494 | # 5: brain, muscle, eye, heart, other 1495 | # exception for eye_catch which is always [eye] where eye >= 0.93 is the threshold for detection 1496 | classifications = {} 1497 | for cls, lab in raw.iteritems(): 1498 | if cls == 'adjust': 1499 | if n_cls == 2: 1500 | non_brain = raw[cls].max(1, keepdims=True) 1501 | classifications[cls] = np.concatenate((1 - non_brain, non_brain), 1) 1502 | elif n_cls == 3: 1503 | brain = 1 - raw[cls].max(1, keepdims=True) 1504 | eye = raw[cls][:, :-1].max(1, keepdims=True) 1505 | other = raw[cls][:, -1:] 1506 | classifications[cls] = np.concatenate((brain, eye, other), 1) 1507 | elif cls == 'mara': 1508 | if n_cls == 2: 1509 | classifications[cls] = np.concatenate((1 - raw[cls], raw[cls]), 1) 1510 | elif cls == 'faster': 1511 | if n_cls == 2: 1512 | classifications[cls] = np.concatenate((1 - raw[cls], raw[cls]), 1) 1513 | elif cls == 'ic_marc': # ['blink', 'neural', 'heart', 'lat. eye', 'muscle', 'mixed'] 1514 | brain = raw[cls][:, 1:2] 1515 | if n_cls == 2: 1516 | classifications[cls] = np.concatenate((brain, 1 - brain), 1) 1517 | elif n_cls == 3: 1518 | eye = raw[cls][:, [0, 3]].sum(1, keepdims=True) 1519 | other = raw[cls][:, [2, 4, 5]].sum(1, keepdims=True) 1520 | classifications[cls] = np.concatenate((brain, eye, other), 1) 1521 | elif n_cls == 5: 1522 | muscle = raw[cls][:, 4:5] 1523 | eye = raw[cls][:, [0, 3]].sum(1, keepdims=True) 1524 | heart = raw[cls][:, 2:3] 1525 | other = raw[cls][:, 5:] 1526 | classifications[cls] = np.concatenate((brain, muscle, eye, heart, other), 1) 1527 | elif cls == 'eye_catch': 1528 | classifications[cls] = raw[cls] 1529 | else: 1530 | raise UserWarning('Unknown classifier: {}'.format(cls)) 1531 | 1532 | # return 1533 | return classifications 1534 | 1535 | def _load_classifications(self, ids=None): 1536 | 1537 | # check for files and download if missing 1538 | self.check_for_download('classifications') 1539 | 1540 | # load classifications 1541 | classifications = {} 1542 | with h5py.File(join(self.datapath, 'other', 'other_classifiers.mat'), 'r') as f: 1543 | print('Loading classifications...') 1544 | for cls, lab in f.iteritems(): 1545 | classifications[cls] = lab[:].T 1546 | 1547 | # match to given ids 1548 | if ids is not None: 1549 | for cls, lab in classifications.iteritems(): 1550 | _, ind_id, ind_lab = np.intersect1d((ids * [100, 1]).sum(1), (lab[:, :2].astype(int) * [100, 1]).sum(1), 1551 | return_indices=True) 1552 | classifications[cls] = np.empty((ids.shape[0], lab.shape[1] - 2)) 1553 | classifications[cls][:] = np.nan 1554 | classifications[cls][ind_id] = lab[ind_lab, 2:] 1555 | 1556 | return classifications 1557 | 1558 | def generate_cache(self, refresh=False): 1559 | """ 1560 | Generate all possible training set cache files to speed up later requests. 1561 | :param refresh: If true, deletes previous cache files. Otherwise only missing cache files will be generated. 
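
        A typical call (assuming the label files and database have been downloaded or can be fetched):

            icl = ICLabelDataset()
            icl.generate_cache(refresh=True)  # rebuild every cached training-set variant from scratch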
        """

        if refresh:
            rmtree(join(self.datapath, 'cache'))
            os.mkdir(join(self.datapath, 'cache'))

        urexpert = copy(self.label_type)
        for label_type in ('luca', 'all', 'database'):
            self.label_type = label_type
            self.load_data()
        self.label_type = urexpert

    @staticmethod
    def _download(url, filename, attempts=3):
        chunk_size = 256 * 1024

        for _ in range(attempts):
            try:
                # open connection to server
                with requests.get(url, stream=True) as r:
                    r.raise_for_status()
                    total_size_in_bytes = int(r.headers.get('content-length', 0))
                    # check if file is already downloaded
                    if os.path.exists(filename) and os.stat(filename).st_size == total_size_in_bytes:
                        print("File already downloaded. Skipping.")
                        return
                    # set up progress bar
                    with tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
                        # open file for writing
                        with open(filename, 'wb') as f:
                            for chunk in r.iter_content(chunk_size=chunk_size):
                                progress_bar.update(len(chunk))
                                f.write(chunk)

                # check that file downloaded completely
                if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
                    raise requests.exceptions.RequestException("Incomplete download.")
                else:
                    break
            except requests.exceptions.RequestException:
                pass
        else:
            # all attempts failed
            raise requests.exceptions.RequestException("Download failed.")

    def download_trainset_cllabels(self):
        """
        Download labels for the ICLabel training set.
        """
        print('Downloading individual ICLabel training set CL label files...')
        folder = 'labels'
        if not isdir(join(self.datapath, folder)):
            os.mkdir(join(self.datapath, folder))
        for it, url in enumerate(self.label_train_urls):
            print('Downloading label file {} of {}...'.format(it + 1, len(self.label_train_urls)))
            self._download(url, join(self.datapath, folder, basename(url)))

    def download_trainset_features(self):
        """
        Download features for the ICLabel training set.
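        The features arrive as 25 zip parts totaling roughly 25 GB; they are concatenated into a single
        archive, extracted, and then deleted, so expect to need about twice that much free disk space.
        """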
1622 | """ 1623 | folder = 'features' 1624 | base_filename = join(self.datapath, folder, 'features') 1625 | n_files = 25 1626 | 1627 | # check if files have already been downloaded 1628 | for it, url in enumerate(self.feature_train_urls): 1629 | if not isfile(join(self.datapath, folder, basename(url))): 1630 | break 1631 | else: 1632 | print('Feature files already downloaded.') 1633 | return 1634 | print('Caution: this download is approximately 25GB and requires twice that space on your drive for unzipping!') 1635 | 1636 | print('Downloading zipped ICLabel training set features...') 1637 | if not isdir(join(self.datapath, folder)): 1638 | os.mkdir(join(self.datapath, folder)) 1639 | for it in range(n_files): 1640 | print('Downloading file part {} of {}...'.format(it + 1, n_files)) 1641 | zip_name = base_filename + '{:02d}.zip'.format(it) 1642 | self._download(self.feature_train_zip_parts_url.format(it), zip_name) 1643 | 1644 | print('Combining file parts...') 1645 | with open(base_filename + '.zip', 'wb') as f: 1646 | for it in range(n_files): 1647 | with open(base_filename + '{:02d}.zip'.format(it), 'rb') as f_part: 1648 | f.write(f_part.read()) 1649 | for it in range(n_files): 1650 | os.remove(base_filename + '{:02d}.zip'.format(it)) 1651 | 1652 | print('Extracting zipped ICLabel training set features...') 1653 | from zipfile import ZipFile 1654 | with ZipFile(base_filename + '.zip') as myzip: 1655 | myzip.extractall(path=join(self.datapath, folder)) 1656 | print('Deleting zip archive...') 1657 | os.remove(base_filename + '.zip') 1658 | 1659 | def download_testset_cllabels(self): 1660 | """ 1661 | Download labels for the ICLabel test set. 1662 | """ 1663 | print('Downloading ICLabel test set CL label files...') 1664 | folder = 'labels' 1665 | if not isdir(join(self.datapath, folder)): 1666 | os.mkdir(join(self.datapath, folder)) 1667 | self._download(self.label_test_url, join(self.datapath, folder, 'ICLabels_test.pkl')) 1668 | 1669 | def download_testset_features(self): 1670 | """ 1671 | Download features for the ICLabel test set. 1672 | """ 1673 | print('Downloading ICLabel test set features...') 1674 | folder = 'features' 1675 | if not isdir(join(self.datapath, folder)): 1676 | os.mkdir(join(self.datapath, folder)) 1677 | self._download(self.feature_test_url, join(self.datapath, folder, 'features_testset_full.mat')) 1678 | 1679 | def download_database(self): 1680 | """ 1681 | Download anonymized ICLabel website database. 1682 | """ 1683 | print('Downloading anonymized ICLabel website database...') 1684 | folder = 'labels' 1685 | if not isdir(join(self.datapath, folder)): 1686 | os.mkdir(join(self.datapath, folder)) 1687 | self._download(self.db_url, join(self.datapath, folder, 'database.sqlite')) 1688 | 1689 | def download_icclassifications(self): 1690 | """ 1691 | Download precalculated classification for several publicly available IC classifiers. 1692 | """ 1693 | print('Downloading classifications for some publicly available classifiers...') 1694 | folder = 'other' 1695 | if not isdir(join(self.datapath, folder)): 1696 | os.mkdir(join(self.datapath, folder)) 1697 | self._download(self.cls_url, join(self.datapath, folder, 'other_classifiers.mat')) 1698 | 1699 | def check_for_download(self, data_type): 1700 | """ 1701 | Check if something has been downloaded and, if not, get it. 1702 | :param data_type: What data to check for. Can be: train_labels, train_features, test_labels, test_features, 1703 | database, and/or 'classifications'. 
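
        For example, load_test_data() fetches any missing test-set files with:

            self.check_for_download(('test_labels', 'test_features'))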
        """

        if not isinstance(data_type, (tuple, list)):
            data_type = [data_type]

        for val in data_type:
            if val == 'train_labels':
                for it, url in enumerate(self.label_train_urls):
                    if not isfile(join(self.datapath, 'labels', basename(url))):
                        self.download_trainset_cllabels()
            elif val == 'train_features':
                for it, url in enumerate(self.feature_train_urls):
                    assert isfile(join(self.datapath, 'features', basename(url))), \
                        'Missing training feature file "' + basename(url) + '" and possibly others. ' \
                        'It is a large download which you may accomplish through calling the method ' \
                        '"download_trainset_features()".'
            elif val == 'test_labels':
                if not isfile(join(self.datapath, 'labels', 'ICLabels_test.pkl')):
                    self.download_testset_cllabels()
            elif val == 'test_features':
                if not isfile(join(self.datapath, 'features', 'features_testset_full.mat')):
                    self.download_testset_features()
            elif val == 'database':
                if not isfile(join(self.datapath, 'labels', 'database.sqlite')):
                    self.download_database()
            elif val == 'classifications':
                if not isfile(join(self.datapath, 'other', 'other_classifiers.mat')):
                    self.download_icclassifications()

    # data normalization

    @staticmethod
    def _clip_and_rescale(vec, min, max):
        return (np.clip(vec, min, max) - min) * 2. / (max - min) - 1

    @staticmethod
    def _unscale(vec, min, max):
        return (vec + 1) * (max - min) / 2 + min

    @staticmethod
    def normalize_dipole_features(data):
        """
        Normalize dipole features.
        :param data: dipole features
        :return: normalized dipole features
        """

        # indices
        ind_dipole_pos = np.array([1, 2, 3, 8, 9, 10, 14, 15, 16])
        ind_dipole1_mom = np.array([4, 5, 6])
        ind_dipole2_mom = np.array([11, 12, 13, 17, 18, 19])
        ind_rv = np.array([0, 7])

        # normalize dipole positions
        data[:, ind_dipole_pos] /= 100
        # clip dipole position
        max_dist = 1.5
        data[:, ind_dipole_pos] = np.clip(data[:, ind_dipole_pos], -max_dist, max_dist) / max_dist
        # normalize single dipole moments
        data[:, ind_dipole1_mom] /= np.abs(data[:, ind_dipole1_mom]).max(1, keepdims=True)
        # normalize double dipole moments
        data[:, ind_dipole2_mom] /= np.abs(data[:, ind_dipole2_mom]).max(1, keepdims=True)
        # center residual variance
        data[:, ind_rv] = data[:, ind_rv] * 2 - 1
        return data.astype(np.float32)

    def normalize_topo_features(self, data, pca=None):
        """
        Normalize scalp topography features.
        :param data: scalp topography features
        :param pca: A fitted PCA object to reuse for the test set if do_pca was set to True in __init__.
        :return: (normalized scalp topography features, PCA object or None)
        """
        # apply pca
        if self.do_pca:
            if pca is None:
                pca = PCA(whiten=True)
                data = pca.fit_transform(data)
            else:
                data = pca.transform(data)

            # clip extreme values
            data = np.clip(data, -2, 2)

        else:
            # normalize to norm 1
            data /= np.linalg.norm(data, axis=1, keepdims=True)

        return data.astype(np.float32), pca

    def normalize_psd_features(self, data):
        """
        Normalize power spectral density features.
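        Dips left at the 50 and 60 Hz line-noise bins by notch filters are detected and filled with the
        mean of the two neighboring bins, after which each spectrum is divided by its maximum absolute value.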
        :param data: power spectral density features
        :return: normalized power spectral density features
        """

        # undo notch filter
        for linenoise_ind in (49, 59):
            notch_ind = (
                data[:, [linenoise_ind - 1, linenoise_ind + 1]] - data[:, linenoise_ind, np.newaxis] > 5).all(1)
            data[notch_ind, linenoise_ind] = data[notch_ind][:, [linenoise_ind - 1, linenoise_ind + 1]].mean(1)

        # divide by max abs
        data /= np.amax(np.abs(data), axis=1, keepdims=True)

        return data.astype(np.float32)

    @staticmethod
    def normalize_autocorr_features(data):
        """
        Normalize autocorrelation function features.
        :param data: autocorrelation function features
        :return: normalized autocorrelation function features
        """
        # clip to a max of 1
        data[data > 1] = 1
        return data.astype(np.float32)

    def normalize_handcrafted_features(self, data, ic_nums):
        """
        Normalize hand-crafted features.
        :param data: hand-crafted features
        :param ic_nums: IC indices when sorted by power within their respective datasets. The 2nd ID number can
            be used for this in the training dataset.
        :return: normalized hand-crafted features
        """
        # autocorrelation
        data[:, 0] = self._clip_and_rescale(data[:, 0], -0.5, 1.)
        # SASICA focal topo
        data[:, 1] = self._clip_and_rescale(data[:, 1], 1.5, 12.)
        # SASICA snr REMOVED
        # SASICA ic variance
        data[:, 2] = self._clip_and_rescale(np.log(data[:, 2]), -6., 7.)
        # ADJUST diff_var
        data[:, 3] = self._clip_and_rescale(data[:, 3], -0.05, 0.06)
        # ADJUST Temporal Kurtosis
        data[:, 4] = self._clip_and_rescale(np.tanh(data[:, 4]), -0.5, 1.)
        # ADJUST Spatial Eye Difference
        data[:, 5] = self._clip_and_rescale(data[:, 5], 0., 0.4)
        # ADJUST spatial average difference
        data[:, 6] = self._clip_and_rescale(data[:, 6], -0.2, 0.25)
        # ADJUST General Discontinuity Spatial Feature
        # ADJUST maxvar/meanvar
        data[:, 8] = self._clip_and_rescale(data[:, 8], 1., 20.)
        # FASTER Median gradient value
        data[:, 9] = self._clip_and_rescale(data[:, 9], -0.2, 0.2)
        # FASTER Kurtosis of spatial map
        data[:, 10] = self._clip_and_rescale(data[:, 10], -50., 100.)
        # FASTER Hurst exponent
        data[:, 11] = self._clip_and_rescale(data[:, 11], -0.2, 0.2)
        # number of channels
        # number of ICs
        # ic number relative to number of channels
        ic_rel = self._clip_and_rescale(ic_nums * 1. / data[:, 13], 0., 1.)
        # topoplot plot radius
        data[:, 12] = self._clip_and_rescale(data[:, 14], 0.5, 1)
        # epoched?
        # sampling rate
        # number of data points

        return np.hstack((data[:, :13], ic_rel.reshape(-1, 1))).astype(np.float32)

    # plotting functions

    @staticmethod
    def _plot_grid(data, function):
        nax = data.shape[0]
        a = np.ceil(np.sqrt(nax)).astype(np.int)
        b = np.ceil(1. * nax / a).astype(np.int)
        f, axarr = plt.subplots(a, b, sharex='col', sharey='row')
        axarr = axarr.flatten()
        for x in range(nax):
            function(data[x], axis=axarr[x])
            axarr[x].set_title(str(x))

    def pad_topo(self, data):
        """
        Reshape scalp topography image features and pad with zeros to make 32x32 pixel images.
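        The 740 retained pixels are scattered back to their positions in the full 32x32 grid (stored in
        self.topo_ind) and all other pixels are left as zeros. A quick sketch, assuming `topos` holds
        unpadded topography vectors:

            img = icl.pad_topo(topos[0])  # (32, 32) image ready for plotting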
        :param data: Scalp topography features as provided by load_data() and load_semi_supervised().
        :return: Padded scalp topography images.
        """
        if data.ndim == 1:
            ntopo = 1
        else:
            ntopo = data.shape[0]
        topos = np.zeros((ntopo, 32 * 32))
        topos[:, self.topo_ind] = data
        topos = topos.reshape(-1, 32, 32).transpose(0, 2, 1)
        return np.squeeze(topos)

    def plot_topo(self, data, axis=plt):
        """
        Plot an IC scalp topography.
        :param data: Scalp topography vector (unpadded).
        :param axis: Optional matplotlib axis in which to plot.
        """
        topo = self.pad_topo(data)
        topo = np.flipud(topo)
        maxabs = np.abs(data).max()
        axis.matshow(topo, cmap='jet', aspect='equal', vmin=-maxabs, vmax=maxabs)

    def plot_topo_grid(self, data):
        """
        Plot a grid of IC scalp topographies.
        :param data: Matrix of scalp topography vectors (unpadded).
        """
        if data.ndim == 1:
            self.plot_topo(data)
        else:
            nax = data.shape[0]
            if nax == 740:
                data = data.T
                nax = data.shape[0]
            if nax > self.max_grid_plot:
                print('Too many plots requested.')
                return

            self._plot_grid(data, self.plot_topo)

    def plot_psd(self, data, axis=plt):
        """
        Plot an IC power spectral density.
        :param data: Power spectral density vector.
        :param axis: Optional matplotlib axis in which to plot.
        """
        if self.psd_limits is not None:
            data = self._unscale(data, *self.psd_limits)
        if self.psd_mean is not None:
            data = data + self.psd_mean
        axis.plot(self.psd_ind[:data.flatten().shape[0]], data.flatten())

    def plot_psd_grid(self, data):
        """
        Plot a grid of IC power spectral densities.
        :param data: Matrix of power spectral density vectors.
        """
        if data.ndim == 1:
            self.plot_psd(data)
        else:
            nax = data.shape[0]
            if nax > self.max_grid_plot:
                print('Too many plots requested.')
                return

            self._plot_grid(data, self.plot_psd)

    @staticmethod
    def plot_autocorr(data, axis=plt):
        """
        Plot an IC autocorrelation function.
        :param data: Autocorrelation function vector.
        :param axis: Optional matplotlib axis in which to plot.
        """
        axis.plot(np.linspace(0, 1, 101)[1:], data.flatten())

    def plot_autocorr_grid(self, data):
        """
        Plot a grid of IC autocorrelation functions.
        :param data: Matrix of autocorrelation function vectors.
        """
        if data.ndim == 1:
            self.plot_autocorr(data)
        else:
            nax = data.shape[0]
            if nax > self.max_grid_plot:
                print('Too many plots requested.')
                return

            self._plot_grid(data, self.plot_autocorr)

    def web_image(self, component_id):
        """
        Open the component properties image from the ICLabel website (iclabel.ucsd.edu) for an IC. Not all ICs have
        images available.
        :param component_id: ID for the component, which should be 2 numbers if from the training set or 3 numbers
            if from the test set.
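
        The ID numbers below are purely illustrative:

            icl.web_image((123, 45))    # training-set IC: (set number, IC number)
            icl.web_image((1, 23, 45))  # test-set IC: (study, set, IC)
        """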
1982 | """ 1983 | if len(component_id) == 2: 1984 | wb.open_new_tab(self.base_url_image + '{0:0>6}_{1:0>3}.png'.format(*component_id)) 1985 | elif len(component_id) == 3: 1986 | wb.open_new_tab(self.base_url_image + '{0:0>2}_{1:0>2}_{2:0>3}.png'.format(*component_id)) 1987 | else: 1988 | raise ValueError('component_id must have 2 or 3 elements.') 1989 | --------------------------------------------------------------------------------