├── README.md ├── conf ├── criterion │ ├── baseline_criterion.yaml │ ├── feature_extract_criterion.yaml │ ├── finetune_criterion.yaml │ ├── pretrain_masked_criterion.py │ └── pretrain_masked_criterion.yaml ├── data │ ├── all_electrodes_single_subject.yaml │ ├── finetuning.yaml │ ├── finetuning_template.yaml │ ├── masked_morelet_dataset.yaml │ ├── masked_spec.yaml │ ├── masked_superlet_dataset.yaml │ ├── masked_tf_dataset.yaml │ ├── morelet_unmasked.yaml │ ├── movie_finetuning_superlet.yaml │ ├── onset_finetuning.yaml │ ├── onset_finetuning_superlet.yaml │ ├── pretrain_wavs_from_disk.yaml │ ├── pretraining.yaml │ ├── pretraining_template.yaml │ ├── single_electrode_all_subjects.yaml │ ├── single_subject_single_trial.yaml │ ├── single_subject_trial_electrode.yaml │ ├── speech_finetuning.yaml │ ├── speech_finetuning_superlet.yaml │ ├── superlet_unmasked.yaml │ ├── tf_unmasked.yaml │ ├── timestamp_data.yaml │ ├── timestamp_data_superlet.yaml │ └── uniform_division.yaml ├── data_prep │ ├── all_electrodes_wavs_to_disk.yaml │ ├── cwt_to_disk.yaml │ ├── remove_subjects.yaml │ ├── single_subject_all_electrode_superlet.yaml │ ├── superlet_to_disk.yaml │ ├── throw_out_zeros.yaml │ ├── wav_to_disk.yaml │ ├── write_pretrain_split.yaml │ └── write_train_split.yaml ├── exp │ ├── feature_extract.yaml │ ├── finetune.yaml │ ├── make_aligned_data_caches.yaml │ ├── seeg_wav2vec.yaml │ └── spec2vec.yaml ├── model │ ├── debug_finetune_model.yaml │ ├── deep_linear_wav_baseline.yaml │ ├── feature_extract_deep_model.yaml │ ├── feature_extract_hidden_model.yaml │ ├── feature_extract_model.yaml │ ├── finetune_model.yaml │ ├── hidden_linear_wav_baseline.yaml │ ├── linear_wav_baseline.yaml │ ├── masked_tf_model.yaml │ ├── masked_tf_model_base.yaml │ ├── masked_tf_model_large.yaml │ ├── masked_tf_model_small.yaml │ ├── morelet2vec_model.yaml │ ├── seeg_wav2vec.yaml │ ├── spec_pooled_model.yaml │ ├── superlet2vec_model.yaml │ └── superlet_finetune_model.yaml ├── plot │ ├── plot_all_electrode_tf.yaml │ ├── plot_connectivity.yaml │ ├── plot_embed.yaml │ ├── plot_sentence_dynamics.yaml │ ├── plot_superlet_embed.yaml │ └── write_region_embedding.yaml ├── preprocessor │ ├── spec_pooled.yaml │ ├── stft.yaml │ ├── stft_pooled_preprocessor.yaml │ ├── stft_pretrained.yaml │ ├── superlet.yaml │ ├── superlet_pooled_preprocessor.yaml │ ├── superlet_pretrained.yaml │ └── wav_preprocessor.yaml ├── task │ ├── baseline_task.yaml │ ├── baseline_wav_task.yaml │ ├── feature_extract.yaml │ ├── finetune_task.yaml │ ├── fixed_mask_pretrain.yaml │ └── variable_mask_pretrain.yaml └── test │ ├── debug_test.yaml │ ├── effective_dim.yaml │ ├── effective_dim_raw_spec.yaml │ ├── fewshot_test.yaml │ ├── held_out_subjects.yaml │ ├── process_linear_results.yaml │ └── single_electrode_test.yaml ├── criterions ├── __init__.py ├── base_criterion.py ├── baseline_criterion.py ├── feature_extract_criterion.py ├── finetune_criterion.py ├── pretrain_masked_criterion.py └── seeg_wav2vec_criterion.py ├── data ├── __init__.py ├── corrupted_elec.json ├── create_data_dirs.py ├── edf2h5.py ├── electrode_selection.py ├── electrode_subject_data.py ├── h5_data.py ├── h5_data_reader.py ├── make_aligned_data_caches.py ├── modify_manifest.py ├── pretrain_split_trials.json ├── speech_nonspeech_subject_data.py ├── subject_data.py ├── test_split_trials.json ├── throw_out_zeros.py ├── timestamped_subject_data.py ├── trial_data.py ├── trial_data_reader.py ├── utils.py ├── write_data_to_disk.py ├── write_preprocessed_inputs.py └── write_pretrain_data_wavs.py ├── datasets ├── 
__init__.py ├── base_tf_dataset.py ├── finetuning_datasets.py ├── masked_tf_dataset.py ├── pretraining_multi_elec_wavs_in_mem.py ├── pretraining_wavs_in_mem.py ├── raw_wav_file_dataset.py ├── single_subject_all_electrode.py ├── tf_unmasked.py └── utils.py ├── linear_results └── onset_finetuning │ └── linear_results.json ├── models ├── __init__.py ├── base_model.py ├── deep_linear_wav_baseline.py ├── feature_extract_deep_model.py ├── feature_extract_hidden.py ├── feature_extract_model.py ├── finetune_model.py ├── hidden_linear_wav_model.py ├── linear_spec_baseline.py ├── linear_wav_baseline.py ├── masked_tf_model.py ├── seeg_wav2vec.py ├── spec_prediction_head.py └── transformer_encoder_input.py ├── notebooks ├── demo.ipynb └── example_wav_1.npy ├── preprocessors ├── __init__.py ├── morelet_preprocessor.py ├── spec_pooled.py ├── spec_pretrained.py ├── stft.py ├── superlet.py ├── superlet_preprocessor.py └── wav_preprocessor.py ├── pretrain ├── __init__.py └── spec2vec │ └── spec2vec.py ├── pretrain_data └── manifests │ └── manifest.tsv ├── requirements.txt ├── run_tests.py ├── run_train.py ├── runner.py ├── schedulers ├── __init__.py ├── base_scheduler.py ├── ramp_up.py └── reduce_on_plateau.py ├── tasks ├── __init__.py ├── base_task.py ├── baseline_wav_task.py ├── batch_utils.py ├── feature_extract_task.py ├── finetune_task.py ├── seeg_wav_task.py ├── spec_pretrain.py └── utils.py ├── testing ├── collect_dataset_stats.py ├── effective_dimensionality.py ├── make_finetuning_datasets_stats.py ├── process_linear_results.py ├── run_fewshot_training_tests.py ├── run_single_electrode_tests.py ├── select_fewshot_learning_electrode.py └── utils.py └── util ├── mask_utils.py └── tensorboard_utils.py /README.md: -------------------------------------------------------------------------------- 1 | # BrainBERT 2 | 3 | BrainBERT is a modeling approach for learning self-supervised representations of intracranial electrode data. See the [paper](https://arxiv.org/abs/2302.14367) for details. 4 | 5 | We provide the training pipeline below. 6 | 7 | The trained weights have been released (see below) and pre-training data can be found at [braintreebank.dev](https://braintreebank.dev) 8 | 9 | ## Installation 10 | Requirements: 11 | - pytorch >= 1.12.1 12 | - [pytorch gradual warmup scheduler](https://github.com/ildoonet/pytorch-gradual-warmup-lr) 13 | 14 | ``` 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ### Input 19 | The input is expected to be intracranial electrode data that has been Laplacian re-referenced.
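For reference, Laplacian re-referencing subtracts from each contact the mean signal of its neighboring contacts on the same probe. A minimal sketch of the idea, assuming `signals` is a (channels × samples) array and `neighbors` maps each channel index to its adjacent contacts (both names are illustrative, not part of this repo's API):
```
import numpy as np

def laplacian_rereference(signals, neighbors):
    """Subtract the mean of each channel's neighboring contacts.

    signals: (n_channels, n_samples) float array of raw voltage traces
    neighbors: dict mapping channel index -> list of adjacent channel indices
    """
    out = signals.copy()
    for ch, nbrs in neighbors.items():
        # fancy indexing selects the neighbor rows; average them and subtract
        out[ch] = signals[ch] - signals[nbrs].mean(axis=0)
    return out
```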
20 | 21 | ## Using BrainBERT embeddings 22 | - pretrained weights are available [here](https://drive.google.com/file/d/14ZBOafR7RJ4A6TsurOXjFVMXiVH6Kd_Q/view?usp=sharing) 23 | - see `notebooks/demo.ipynb` for an example input and example embedding 24 | 25 | ## Upstream 26 | ### BrainBERT pre-training data 27 | The data directory should be structured as: 28 | ``` 29 | /pretrain_data 30 | |_manifests 31 | |_manifest.tsv <-- each line contains the path to an example and its length 32 | |_<subject> 33 | |_<trial> 34 | |_<example>.npy 35 | ``` 36 | If using data from the Brain Treebank, it can be written with this command: 37 | ``` 38 | python3 -m data.write_pretrain_data_wavs +data=pretraining_template.yaml \ 39 | +data_prep=write_pretrain_split ++data.duration=5 \ 40 | ++data_prep.pretrain_split=/storage/czw/BrainBERT/data/pretrain_split_trials.json \ 41 | ++data_prep.out_dir=pretrain_data \ 42 | ++data.raw_brain_data_dir=/path/to/braintreebank_data/ 43 | ``` 44 | This command expects the Brain Treebank data to have the following structure: 45 | ``` 46 | /braintreebank_data 47 | |_electrode_labels 48 | |_subject_metadata 49 | |_localization 50 | |_all_subject_data 51 | |_sub_*_trial*.h5 52 | ``` 53 | 54 | ### BrainBERT pre-training 55 | ``` 56 | python3 run_train.py +exp=spec2vec ++exp.runner.device=cuda ++exp.runner.multi_gpu=True \ 57 | ++exp.runner.num_workers=64 +data=masked_spec +model=masked_tf_model_large \ 58 | +data.data=/path/to/data ++data.val_split=0.01 +task=fixed_mask_pretrain.yaml \ 59 | +criterion=pretrain_masked_criterion +preprocessor=stft ++data.test_split=0.01 \ 60 | ++task.freq_mask_p=0.05 ++task.time_mask_p=0.05 ++exp.runner.total_steps=500000 61 | ``` 62 | Example parameters: 63 | ``` 64 | /path/to/data = /storage/user123/self_supervised_seeg/pretrain_data/manifests 65 | ``` 66 | -------------------------------------------------------------------------------- /conf/criterion/baseline_criterion.yaml: -------------------------------------------------------------------------------- 1 | name: baseline_criterion 2 | -------------------------------------------------------------------------------- /conf/criterion/feature_extract_criterion.yaml: -------------------------------------------------------------------------------- 1 | name: feature_extract_criterion 2 | -------------------------------------------------------------------------------- /conf/criterion/finetune_criterion.yaml: -------------------------------------------------------------------------------- 1 | name: finetune_criterion 2 | -------------------------------------------------------------------------------- /conf/criterion/pretrain_masked_criterion.py: -------------------------------------------------------------------------------- 1 | name: pretrain_masked_criterion 2 | -------------------------------------------------------------------------------- /conf/criterion/pretrain_masked_criterion.yaml: -------------------------------------------------------------------------------- 1 | name: pretrain_masked_criterion 2 | alpha: 2 #how much more important are the non-zero portions?
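# note: criterions/pretrain_masked_criterion.py uses this weight as
#   loss = l1 + alpha * l1_content, where l1_content is the reconstruction error
# over masked bins with |target| > 1, so alpha upweights high-magnitude bins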
3 | -------------------------------------------------------------------------------- /conf/data/all_electrodes_single_subject.yaml: -------------------------------------------------------------------------------- 1 | name: single_subject_all_electrode 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/datasets/neuroscience/ecog 5 | #raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data/ 6 | #subject: subject4 7 | subject: subject3 8 | #brain_runs: ['trial000'] 9 | brain_runs: ['trial000'] 10 | onsets_only: True 11 | #electrodes: ['F2Ia2'] 12 | electrodes: ['F2Ia2', 'F2Ia3', 'F2Ia4', 'F2Ia5', 'F2Ia6', 'F2Ia7', 'F2Ia8', 'F2Ia9', 'F2Ia10', 'F2Ia11', 'F2Ia12', 'F2Ia13', 'F2Ia14', 'F2Ia15', 'F3c2', 'F3c3', 'F3c4', 'F3c5', 'F3c6', 'F3c7', 'F3c8', 'F3c9', 'F3b2', 'F3b3', 'F3b4', 'F3b5', 'F3b6', 'F3b7', 'F3aOF2', 'F3aOF3', 'F3aOF4', 'F3aOF5', 'F3aOF6', 'F3aOF7', 'F3d2', 'F3d3', 'F3d4', 'F3d5', 'F3d6', 'F3d7', 'F3d8', 'F3d9', 'T1aIc2', 'T1aIc3', 'T1aIc4', 'T1aIc5', 'P2a2', 'P2a3', 'P2a4', 'P2a5', 'P2a6', 'P2a7', 'P2a8', 'P2a9', 'T1b2', 'T1b3', 'T1b4', 'T1b5', 'T1cIe2', 'T1cIe3', 'T1cIe4', 'T1cIe5', 'T1cIe6', 'T1cIe7', 'T1cIe8', 'T1cIe9', 'T1cIe10', 'T1cIe11', 'P2b2', 'P2b3', 'P2b4', 'P2b5', 'P2b6', 'P2b7', 'P2b8', 'P2b9', 'P2b10', 'P2b11', 'P2b12', 'P2b13', 'P2b14', 'P2b15', 'O1aIb2', 'O1aIb3', 'O1aIb4', 'O1aIb5', 'O1aIb6', 'O1aIb7', 'O1aIb8', 'O1aIb9', 'O1aIb10', 'O1aIb11', 'O1aIb12', 'O1aIb13', 'O1aIb14', 'O1aIb15', 'O1bId2', 'O1bId3', 'O1bId4', 'O1bId5', 'O1bId6', 'O1bId7', 'O1bId8', 'O1bId9', 'O1bId10', 'O1bId11', 'O1bId12', 'O1bId13', 'O1bId14', 'O1bId15'] 13 | rereference: laplacian 14 | normalization: False 15 | despike: False 16 | delta: -1.5 17 | duration: 3.0 18 | words: [] 19 | cached_transcript_aligns: /storage/czw/LanguageEcog/semantics/saved_aligns 20 | #cached_transcript_aligns: None 21 | val_split: 0.1 22 | preprocessor: stft 23 | -------------------------------------------------------------------------------- /conf/data/finetuning.yaml: -------------------------------------------------------------------------------- 1 | name: movie_finetuning 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/datasets/neuroscience/ecog 5 | #raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data/ 6 | #subject: subject4 7 | subject: subject3 8 | #brain_runs: ['trial000'] 9 | brain_runs: ['trial000', 'trial001'] 10 | electrodes: ['T1b2'] 11 | #electrodes: ['T1Id6'] 12 | rereference: laplacian 13 | normalization: False 14 | despike: False 15 | delta: -1.5 16 | duration: 3.0 17 | words: [] 18 | cached_transcript_aligns: /storage/czw/LanguageEcog/semantics/saved_aligns 19 | #cached_transcript_aligns: None 20 | val_split: 0.1 21 | preprocessor: stft 22 | -------------------------------------------------------------------------------- /conf/data/finetuning_template.yaml: -------------------------------------------------------------------------------- 1 | name: onset_finetuning 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/datasets/neuroscience/ecog 5 | subject: ??? 6 | electrodes: ??? 7 | brain_runs: ??? 
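# ??? is OmegaConf/Hydra's mandatory-value marker: these fields have no default
# and must be overridden at launch, e.g. ++data.subject=subject3 (cf. the README commands)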
8 | rereference: laplacian 9 | normalization: False 10 | despike: False 11 | delta: -2.5 12 | duration: 5.0 13 | interval_duration: 1.0 14 | words: [] 15 | cached_transcript_aligns: /storage/czw/BrainBERT/semantics/saved_aligns 16 | val_split: 0.1 17 | test_split: 0.1 18 | preprocessor: stft 19 | -------------------------------------------------------------------------------- /conf/data/masked_morelet_dataset.yaml: -------------------------------------------------------------------------------- 1 | name: masked_tf_dataset 2 | val_split: 0.1 3 | data: /storage/czw/self_supervised_seeg/all_day_data/manifests 4 | cached_features: /storage/czw/self_supervised_seeg/all_day_data_morelet/manifests 5 | preprocessor: morelet 6 | -------------------------------------------------------------------------------- /conf/data/masked_spec.yaml: -------------------------------------------------------------------------------- 1 | name: masked_tf_dataset 2 | val_split: 0.1 3 | data: ??? 4 | -------------------------------------------------------------------------------- /conf/data/masked_superlet_dataset.yaml: -------------------------------------------------------------------------------- 1 | name: masked_tf_dataset 2 | val_split: 0.01 3 | data: /storage/czw/self_supervised_seeg/pretrain_data/manifests 4 | cached_features: /storage/czw/self_supervised_seeg/pretrain_superlet/manifests 5 | preprocessor: superlet 6 | -------------------------------------------------------------------------------- /conf/data/masked_tf_dataset.yaml: -------------------------------------------------------------------------------- 1 | name: masked_tf_dataset 2 | val_split: 0.1 3 | data: ??? 4 | preprocessor: stft 5 | -------------------------------------------------------------------------------- /conf/data/morelet_unmasked.yaml: -------------------------------------------------------------------------------- 1 | preprocessor: morelet 2 | name: tf_unmasked #just the tf representations, no masking 3 | val_split: 0.1 4 | data: ??? 
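# as with masked_spec, data should point at a manifests directory,
# e.g. .../pretrain_data/manifests (see "Example parameters" in the README)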
5 | -------------------------------------------------------------------------------- /conf/data/movie_finetuning_superlet.yaml: -------------------------------------------------------------------------------- 1 | name: movie_finetuning 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/datasets/neuroscience/ecog 5 | #raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data/ 6 | subject: subject4 7 | electrodes: ['T1Id6'] 8 | brain_runs: ['trial000', 'trial001'] 9 | rereference: laplacian 10 | normalization: False 11 | despike: False 12 | delta: -0.5 13 | duration: 3.0 14 | words: [] 15 | cached_transcript_aligns: /storage/czw/LanguageEcog/semantics/saved_aligns 16 | #cached_transcript_aligns: None 17 | val_split: 0.1 18 | preprocessor: superlet 19 | cache_input_features: /storage/czw/self_supervised_seeg/cached_input_features/movie_finetuning_superlet_features 20 | -------------------------------------------------------------------------------- /conf/data/onset_finetuning.yaml: -------------------------------------------------------------------------------- 1 | name: onset_finetuning 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/datasets/neuroscience/ecog 5 | #raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data/ 6 | #subject: subject4 7 | subject: subject3 8 | #electrodes: ['T1Id6'] 9 | electrodes: ['T1b2'] 10 | brain_runs: ['trial000'] 11 | rereference: laplacian 12 | normalization: False 13 | despike: False 14 | delta: -2.5 15 | duration: 5.0 16 | interval_duration: 1.0 17 | words: [] 18 | cached_data_arrays: /storage/czw/self_supervised_seeg/cached_input_features/onset_finetuning/ 19 | cached_transcript_aligns: /storage/czw/LanguageEcog/semantics/saved_aligns 20 | #cached_transcript_aligns: None 21 | val_split: 0.1 22 | test_split: 0.1 23 | preprocessor: stft 24 | -------------------------------------------------------------------------------- /conf/data/onset_finetuning_superlet.yaml: -------------------------------------------------------------------------------- 1 | name: onset_finetuning 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/datasets/neuroscience/ecog 5 | #raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data/ 6 | subject: subject4 7 | electrodes: ['T1Id6'] 8 | brain_runs: ['trial000'] 9 | rereference: laplacian 10 | normalization: False 11 | despike: False 12 | delta: -0.5 13 | duration: 3.0 14 | words: [] 15 | cached_transcript_aligns: /storage/czw/LanguageEcog/semantics/saved_aligns 16 | #cached_transcript_aligns: None 17 | val_split: 0.1 18 | preprocessor: superlet 19 | cache_input_features: /storage/czw/self_supervised_seeg/cached_input_features/onset_finetuning_superlet_features 20 | -------------------------------------------------------------------------------- /conf/data/pretrain_wavs_from_disk.yaml: -------------------------------------------------------------------------------- 1 | name: raw_wav_file_dataset 2 | val_split: 0.1 3 | data: ??? 
4 | -------------------------------------------------------------------------------- /conf/data/pretraining.yaml: -------------------------------------------------------------------------------- 1 | name: pretraining_wavs_in_mem 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/datasets/neuroscience/ecog 5 | #raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data/ 6 | #subject: subject4 7 | subject: subject3 8 | #electrodes: ['T1Id6'] 9 | electrodes: ['T1b2'] 10 | brain_runs: ['trial000'] 11 | rereference: laplacian 12 | normalization: False 13 | despike: False 14 | duration: 3.0 15 | -------------------------------------------------------------------------------- /conf/data/pretraining_template.yaml: -------------------------------------------------------------------------------- 1 | name: pretraining_wavs_multi_elec_in_mem 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/datasets/neuroscience/ecog 5 | subject: ??? 6 | electrodes: ??? 7 | brain_runs: ??? 8 | rereference: laplacian 9 | normalization: False 10 | despike: False 11 | duration: 3.0 12 | -------------------------------------------------------------------------------- /conf/data/single_electrode_all_subjects.yaml: -------------------------------------------------------------------------------- 1 | name: single_subject_all_electrode 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/datasets/neuroscience/ecog 5 | #raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data/ 6 | #subject: subject4 7 | subject: subject3 8 | #brain_runs: ['trial000'] 9 | brain_runs: ['trial000'] 10 | #electrodes: ['F2Ia2'] 11 | rereference: laplacian 12 | normalization: False 13 | despike: False 14 | delta: -1.5 15 | duration: 3.0 16 | words: [] 17 | cached_transcript_aligns: /storage/czw/LanguageEcog/semantics/saved_aligns 18 | #cached_transcript_aligns: None 19 | val_split: 0.1 20 | preprocessor: stft 21 | -------------------------------------------------------------------------------- /conf/data/single_subject_single_trial.yaml: -------------------------------------------------------------------------------- 1 | name: finetuning_sentence_position 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/datasets/neuroscience/ecog 5 | #raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data/ 6 | #subject: subject4 7 | subject: subject3 8 | #brain_runs: ['trial000'] 9 | brain_runs: ['trial000'] 10 | #electrodes: ['F2Ia2'] 11 | rereference: laplacian 12 | normalization: False 13 | despike: False 14 | delta: -1.0 15 | duration: 2.0 16 | words: [] 17 | cached_transcript_aligns: /storage/czw/LanguageEcog/semantics/saved_aligns 18 | #cached_transcript_aligns: None 19 | val_split: 0.1 20 | preprocessor: stft 21 | -------------------------------------------------------------------------------- /conf/data/single_subject_trial_electrode.yaml: -------------------------------------------------------------------------------- 1 | name: pretraining_wavs_in_mem 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/datasets/neuroscience/ecog 5 | #raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data/ 6 | #subject: subject4 7 | electrodes: ['T1b2'] 8 | subject: subject3 9 | #brain_runs: ['trial000'] 10 | brain_runs: ['trial000'] 11 | #electrodes: ['F2Ia2'] 12 | rereference: laplacian 13 | normalization: False 14 | despike: False 15 | delta: -0.75 16 | duration: 1.5 17 | words: [] 
18 | cached_transcript_aligns: /storage/czw/LanguageEcog/semantics/saved_aligns 19 | #cached_transcript_aligns: None 20 | val_split: 0.1 21 | preprocessor: stft 22 | -------------------------------------------------------------------------------- /conf/data/speech_finetuning.yaml: -------------------------------------------------------------------------------- 1 | name: speech_finetuning 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/datasets/neuroscience/ecog 5 | #raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data/ 6 | #subject: subject4 7 | subject: subject3 8 | electrodes: ['T1b2'] 9 | #electrodes: ['T1Id6'] 10 | brain_runs: ['trial000'] 11 | #brain_runs: ['trial000', 'trial001'] 12 | rereference: laplacian 13 | normalization: False 14 | despike: False 15 | delta: -1.5 16 | duration: 3.0 17 | words: [] 18 | cached_transcript_aligns: /storage/czw/LanguageEcog/semantics/saved_aligns 19 | #cached_transcript_aligns: None 20 | val_split: 0.1 21 | preprocessor: stft 22 | -------------------------------------------------------------------------------- /conf/data/speech_finetuning_superlet.yaml: -------------------------------------------------------------------------------- 1 | name: speech_finetuning 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/datasets/neuroscience/ecog 5 | #raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data/ 6 | subject: subject4 7 | electrodes: ['T1Id6'] 8 | brain_runs: ['trial000', 'trial001'] 9 | rereference: laplacian 10 | normalization: False 11 | despike: False 12 | delta: -0.5 13 | duration: 3.0 14 | words: [] 15 | cached_transcript_aligns: /storage/czw/LanguageEcog/semantics/saved_aligns 16 | #cached_transcript_aligns: None 17 | val_split: 0.1 18 | preprocessor: superlet 19 | cache_input_features: /storage/czw/self_supervised_seeg/cached_input_features/speech_finetuning_superlet_features 20 | -------------------------------------------------------------------------------- /conf/data/superlet_unmasked.yaml: -------------------------------------------------------------------------------- 1 | preprocessor: superlet 2 | name: tf_unmasked #just the tf representations, no masking 3 | #val_split: 0.1 4 | data: ??? 5 | -------------------------------------------------------------------------------- /conf/data/tf_unmasked.yaml: -------------------------------------------------------------------------------- 1 | name: tf_unmasked #just the tf representations, no masking 2 | val_split: 0.1 3 | data: ??? 
4 | -------------------------------------------------------------------------------- /conf/data/timestamp_data.yaml: -------------------------------------------------------------------------------- 1 | name: timestamp_finetuning 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data 5 | #raw_brain_data_dir: /storage/datasets/neuroscience/ecog 6 | #raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data/ 7 | subject: subject4 8 | #subject: subject3 9 | electrodes: ['T1Id6'] 10 | #electrodes: ['T1b2'] 11 | brain_runs: ['trial001'] 12 | rereference: None 13 | normalization: False 14 | despike: False 15 | delta: -0.5 16 | duration: 3.0 17 | words: [] 18 | cached_transcript_aligns: /storage/czw/LanguageEcog/semantics/saved_aligns 19 | #cached_transcript_aligns: None 20 | val_split: 0.1 21 | -------------------------------------------------------------------------------- /conf/data/timestamp_data_superlet.yaml: -------------------------------------------------------------------------------- 1 | name: timestamp_finetuning 2 | high_gamma: False 3 | samp_frequency: 2048 4 | raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data 5 | subject: subject4 6 | electrodes: ['T1Id6'] 7 | brain_runs: ['trial001'] 8 | rereference: None 9 | normalization: False 10 | despike: False 11 | delta: -0.5 12 | duration: 3.0 13 | words: [] 14 | cached_transcript_aligns: /storage/czw/LanguageEcog/semantics/saved_aligns 15 | #cached_transcript_aligns: None 16 | val_split: 0.1 17 | preprocessor: superlet 18 | cache_input_features: /storage/czw/self_supervised_seeg/cached_input_features/timestamp_finetuning_superlet_features 19 | -------------------------------------------------------------------------------- /conf/data/uniform_division.yaml: -------------------------------------------------------------------------------- 1 | #the data from a movie, chopped into 5s 2 | name: uniform_finetuning 3 | high_gamma: False 4 | samp_frequency: 2048 5 | raw_brain_data_dir: /storage/datasets/neuroscience/ecog 6 | #raw_brain_data_dir: /storage/czw/LanguageEcog/semantics/all_day_data/ 7 | #subject: subject4 8 | subject: subject3 9 | #electrodes: ['T1Id6'] 10 | electrodes: ['T1b2'] 11 | brain_runs: ['trial000'] 12 | rereference: laplacian 13 | normalization: False 14 | despike: False 15 | delta: -2.5 16 | duration: 5.0 17 | interval_duration: 1.0 18 | words: [] 19 | cached_data_arrays: /storage/czw/self_supervised_seeg/cached_input_features/onset_finetuning/ 20 | cached_transcript_aligns: /storage/czw/LanguageEcog/semantics/saved_aligns 21 | #cached_transcript_aligns: None 22 | val_split: 0.1 23 | test_split: 0.1 24 | preprocessor: stft 25 | 26 | -------------------------------------------------------------------------------- /conf/data_prep/all_electrodes_wavs_to_disk.yaml: -------------------------------------------------------------------------------- 1 | out_dir: all_electrode_data 2 | brain_runs: ['trial000'] 3 | -------------------------------------------------------------------------------- /conf/data_prep/cwt_to_disk.yaml: -------------------------------------------------------------------------------- 1 | out_dir: all_day_data_morelet 2 | brain_runs: ['trial000', 'trial001', 'trial002'] 3 | -------------------------------------------------------------------------------- /conf/data_prep/remove_subjects.yaml: -------------------------------------------------------------------------------- 1 | subjects_to_remove: ["subject1"] 2 | out_dir: 
??? 3 | -------------------------------------------------------------------------------- /conf/data_prep/single_subject_all_electrode_superlet.yaml: -------------------------------------------------------------------------------- 1 | out_dir: all_electrode_data_superlet 2 | brain_runs: ['trial000'] 3 | -------------------------------------------------------------------------------- /conf/data_prep/superlet_to_disk.yaml: -------------------------------------------------------------------------------- 1 | out_dir: pretrain_superlet 2 | n_workers: 64 3 | #brain_runs: ['trial000', 'trial001', 'trial002'] 4 | -------------------------------------------------------------------------------- /conf/data_prep/throw_out_zeros.yaml: -------------------------------------------------------------------------------- 1 | n_workers: 32 2 | -------------------------------------------------------------------------------- /conf/data_prep/wav_to_disk.yaml: -------------------------------------------------------------------------------- 1 | out_dir: all_day_data 2 | brain_runs: ['trial000', 'trial001', 'trial002'] 3 | -------------------------------------------------------------------------------- /conf/data_prep/write_pretrain_split.yaml: -------------------------------------------------------------------------------- 1 | out_dir: pretrain_data 2 | pretrain_split: /storage/czw/self_supervised_seeg/data/pretrain_split_trials.json 3 | -------------------------------------------------------------------------------- /conf/data_prep/write_train_split.yaml: -------------------------------------------------------------------------------- 1 | out_dir: all_electrode_data 2 | brain_runs: ['trial000'] 3 | -------------------------------------------------------------------------------- /conf/exp/feature_extract.yaml: -------------------------------------------------------------------------------- 1 | runner: 2 | lr: 1e-3 3 | optim: AdamW 4 | train_batch_size: 64 5 | valid_batch_size: 128 6 | shuffle: False 7 | multi_gpu: False 8 | device: cuda 9 | total_steps: 1000 10 | #world_size: 4 11 | num_workers: 0 12 | log_step: 100 13 | checkpoint_step: 100 14 | grad_clip: 1.0 15 | output_tb: False 16 | #start_from_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-05-13/02-28-15/checkpoint_last.pth 17 | scheduler: 18 | name: reduce_on_plateau 19 | #name: ramp_up 20 | total_steps: ${exp.runner.total_steps} 21 | -------------------------------------------------------------------------------- /conf/exp/finetune.yaml: -------------------------------------------------------------------------------- 1 | runner: 2 | lr: 1e-3 3 | optim: AdamW_finetune 4 | train_batch_size: 64 5 | valid_batch_size: 128 6 | shuffle: False 7 | multi_gpu: True 8 | device: cuda 9 | total_steps: 1000 10 | #world_size: 4 11 | num_workers: 32 12 | log_step: 100 13 | checkpoint_step: 100 14 | grad_clip: 1.0 15 | output_tb: False 16 | #start_from_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-05-13/02-28-15/checkpoint_last.pth 17 | scheduler: 18 | name: reduce_on_plateau 19 | #name: ramp_up 20 | total_steps: ${exp.runner.total_steps} 21 | -------------------------------------------------------------------------------- /conf/exp/make_aligned_data_caches.yaml: -------------------------------------------------------------------------------- 1 | test_split_path: "/storage/czw/self_supervised_seeg/data/test_split_trials.json" 2 | -------------------------------------------------------------------------------- /conf/exp/seeg_wav2vec.yaml: 
-------------------------------------------------------------------------------- 1 | runner: 2 | lr: 1e-4 3 | optim: Adam 4 | train_batch_size: 16 5 | valid_batch_size: 128 6 | shuffle: False 7 | device: cuda 8 | total_steps: 400000 9 | #total_steps: 40000 10 | #world_size: 4 11 | num_workers: 16 12 | log_step: 500 13 | checkpoint_step: 1000 14 | grad_clip: 5.0 15 | multi_gpu: False 16 | #start_from_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-05-13/02-28-15/checkpoint_last.pth 17 | scheduler: 18 | #name: reduce_on_plateau 19 | name: ramp_up 20 | total_steps: ${exp.runner.total_steps} 21 | warmup: 0.07 22 | task: 23 | name: seeg_wav_task 24 | criterion: 25 | name: seeg_wav2vec_criterion 26 | 27 | -------------------------------------------------------------------------------- /conf/exp/spec2vec.yaml: -------------------------------------------------------------------------------- 1 | runner: 2 | lr: 1e-4 3 | optim: LAMB 4 | train_batch_size: 256 5 | valid_batch_size: 256 6 | shuffle: False 7 | multi_gpu: True 8 | device: cuda 9 | total_steps: 400000 10 | #world_size: 4 11 | num_workers: 16 12 | log_step: 1000 13 | checkpoint_step: 1000 14 | grad_clip: 1.0 15 | #start_from_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-05-13/02-28-15/checkpoint_last.pth 16 | scheduler: 17 | #name: reduce_on_plateau 18 | name: ramp_up 19 | total_steps: ${exp.runner.total_steps} 20 | warmup: 0.025 21 | -------------------------------------------------------------------------------- /conf/model/debug_finetune_model.yaml: -------------------------------------------------------------------------------- 1 | name: finetune_model 2 | hidden_dim: 768 3 | layer_dim_feedforward: 3072 4 | layer_activation: gelu 5 | nhead: 12 6 | encoder_num_layers: 3 7 | input_dim: 25 8 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-05-28/01-20-42/checkpoint_last.pth #trained weights 9 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-01/03-45-13/checkpoint_last.pth #random weights 10 | upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-07-19/17-25-17/checkpoint_last.pth #trained on all subject3 electrodes 11 | frozen_upstream: True 12 | -------------------------------------------------------------------------------- /conf/model/deep_linear_wav_baseline.yaml: -------------------------------------------------------------------------------- 1 | name: deep_linear_wav_baseline 2 | 3 | -------------------------------------------------------------------------------- /conf/model/feature_extract_deep_model.yaml: -------------------------------------------------------------------------------- 1 | name: feature_extract_deep_model 2 | input_dim: 768 3 | -------------------------------------------------------------------------------- /conf/model/feature_extract_hidden_model.yaml: -------------------------------------------------------------------------------- 1 | name: feature_extract_hidden_model 2 | input_dim: 768 3 | frozen_upstream: True 4 | -------------------------------------------------------------------------------- /conf/model/feature_extract_model.yaml: -------------------------------------------------------------------------------- 1 | name: feature_extract_model 2 | input_dim: 768 3 | -------------------------------------------------------------------------------- /conf/model/finetune_model.yaml: -------------------------------------------------------------------------------- 1 | name: finetune_model 2 | hidden_dim: 768 3 | #layer_dim_feedforward: 3072 4 | #layer_activation: 
gelu 5 | #nhead: 12 6 | #encoder_num_layers: 3 7 | input_dim: 40 8 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-17/02-03-52/checkpoint_last.pth #trained weights 9 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-05-28/01-20-42/checkpoint_last.pth #trained weights 10 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-14/13-14-18/checkpoint_last.pth #random weights 11 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-17/18-00-23/checkpoint_last.pth #trained_weights 12 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-07-19/17-25-17/checkpoint_last.pth #trained on all subject3 electrodes 13 | upstream_ckpt: /storage/czw/self_supervised_seeg/pretrained_weights/stft_large_pretrained.pth #trained on all subject3 electrodes 14 | frozen_upstream: False 15 | -------------------------------------------------------------------------------- /conf/model/hidden_linear_wav_baseline.yaml: -------------------------------------------------------------------------------- 1 | name: hidden_linear_wav_baseline 2 | 3 | -------------------------------------------------------------------------------- /conf/model/linear_wav_baseline.yaml: -------------------------------------------------------------------------------- 1 | name: linear_wav_baseline 2 | 3 | -------------------------------------------------------------------------------- /conf/model/masked_tf_model.yaml: -------------------------------------------------------------------------------- 1 | name: masked_tf_model 2 | hidden_dim: 768 3 | layer_dim_feedforward: 3072 4 | layer_activation: gelu 5 | nhead: 12 6 | encoder_num_layers: 3 7 | input_dim: 40 8 | -------------------------------------------------------------------------------- /conf/model/masked_tf_model_base.yaml: -------------------------------------------------------------------------------- 1 | name: masked_tf_model 2 | hidden_dim: 768 3 | layer_dim_feedforward: 3072 4 | layer_activation: gelu 5 | nhead: 12 6 | encoder_num_layers: 3 7 | input_dim: 40 8 | -------------------------------------------------------------------------------- /conf/model/masked_tf_model_large.yaml: -------------------------------------------------------------------------------- 1 | name: masked_tf_model 2 | hidden_dim: 768 3 | layer_dim_feedforward: 3072 4 | layer_activation: gelu 5 | nhead: 12 6 | encoder_num_layers: 6 7 | input_dim: 40 8 | -------------------------------------------------------------------------------- /conf/model/masked_tf_model_small.yaml: -------------------------------------------------------------------------------- 1 | name: masked_tf_model 2 | hidden_dim: 384 3 | layer_dim_feedforward: 1200 4 | layer_activation: gelu 5 | nhead: 12 6 | encoder_num_layers: 3 7 | input_dim: 40 8 | -------------------------------------------------------------------------------- /conf/model/morelet2vec_model.yaml: -------------------------------------------------------------------------------- 1 | name: debug_model 2 | hidden_dim: 768 3 | layer_dim_feedforward: 3072 4 | layer_activation: gelu 5 | nhead: 12 6 | encoder_num_layers: 3 7 | input_dim: 48 8 | -------------------------------------------------------------------------------- /conf/model/seeg_wav2vec.yaml: -------------------------------------------------------------------------------- 1 | name: seeg_wav2vec 2 | hidden_dim: 768 3 | layer_dim_feedforward: 3072 4 | layer_activation: gelu 5 | nhead: 12 6 | encoder_num_layers: 3 7 | input_dim: 45 8 | 
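The masked_tf_model_* configs above differ only in width and depth (small: 384-dim, 3 layers; base: 768-dim, 3 layers; large: 768-dim, 6 layers), and input_dim=40 matches the STFT preprocessor's freq_channel_cutoff. A minimal sketch of how fields like these typically map onto a PyTorch encoder stack (illustrative only, not the exact module in models/masked_tf_model.py):
```
import torch
from torch import nn

class MaskedTFEncoderSketch(nn.Module):
    """Field names mirror conf/model/masked_tf_model_*.yaml; the class itself is hypothetical."""
    def __init__(self, input_dim=40, hidden_dim=768, nhead=12,
                 layer_dim_feedforward=3072, layer_activation="gelu",
                 encoder_num_layers=6):
        super().__init__()
        # project the 40 time-frequency channels up to the model width
        self.in_proj = nn.Linear(input_dim, hidden_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=nhead,
            dim_feedforward=layer_dim_feedforward,
            activation=layer_activation, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=encoder_num_layers)

    def forward(self, x):
        # x: (batch, time, input_dim) spectrogram frames
        return self.encoder(self.in_proj(x))

# "large" = 6 encoder layers; "base" = 3; "small" = 384-dim with 1200-dim feedforward
model = MaskedTFEncoderSketch()
emb = model(torch.randn(2, 120, 40))  # -> (2, 120, 768)
```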
-------------------------------------------------------------------------------- /conf/model/spec_pooled_model.yaml: -------------------------------------------------------------------------------- 1 | name: feature_extract_model 2 | input_dim: 40 3 | 4 | -------------------------------------------------------------------------------- /conf/model/superlet2vec_model.yaml: -------------------------------------------------------------------------------- 1 | name: masked_tf_model 2 | hidden_dim: 768 3 | layer_dim_feedforward: 3072 4 | layer_activation: gelu 5 | nhead: 12 6 | encoder_num_layers: 6 7 | input_dim: 40 8 | -------------------------------------------------------------------------------- /conf/model/superlet_finetune_model.yaml: -------------------------------------------------------------------------------- 1 | name: debug_finetune_model 2 | hidden_dim: 768 3 | layer_dim_feedforward: 3072 4 | layer_activation: gelu 5 | nhead: 12 6 | encoder_num_layers: 3 7 | input_dim: 50 8 | upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-30/19-26-27/checkpoint_last.pth #trained_weights 9 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-23/01-17-59/checkpoint_last.pth #trained_weights 10 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-25/01-33-50/checkpoint_last.pth #random weights 11 | #TODO check whether losses improve between trained and random 12 | frozen_upstream: True 13 | -------------------------------------------------------------------------------- /conf/plot/plot_all_electrode_tf.yaml: -------------------------------------------------------------------------------- 1 | output_dir: /storage/czw/self_supervised_seeg/outputs/plotting_outputs/all_tf 2 | -------------------------------------------------------------------------------- /conf/plot/plot_connectivity.yaml: -------------------------------------------------------------------------------- 1 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-07-19/17-25-17/checkpoint_last.pth #trained on all subject3 electrodes 2 | upstream_ckpt: /storage/czw/self_supervised_seeg/pretrained_weights/stft_large_pretrained.pth #trained on all subject3 electrodes 3 | output_dir: /storage/czw/self_supervised_seeg/outputs/plotting_outputs/connectivity_plots 4 | -------------------------------------------------------------------------------- /conf/plot/plot_embed.yaml: -------------------------------------------------------------------------------- 1 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-09/20-54-06/checkpoint_last.pth #trained weights 2 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-14/13-14-18/checkpoint_last.pth #random weights 3 | upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-17/18-00-23/checkpoint_last.pth #trained weights 4 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-05-28/01-20-42/checkpoint_last.pth #trained weights 5 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-14/14-52-04/checkpoint_last.pth #trained weights 6 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-16/16-08-39/checkpoint_last.pth #trained weights 7 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-17/02-03-52/checkpoint_last.pth #trained weights 8 | dim_reduce: tsne 9 | -------------------------------------------------------------------------------- /conf/plot/plot_sentence_dynamics.yaml: -------------------------------------------------------------------------------- 1 | 
#upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-07-19/17-25-17/checkpoint_last.pth #trained on all subject3 electrodes 2 | upstream_ckpt: /storage/czw/self_supervised_seeg/pretrained_weights/stft_large_pretrained.pth 3 | #upstream_ckpt: /storage/czw/self_supervised_seeg/pretrained_weights/random_large_model.pth 4 | output_dir: /storage/czw/self_supervised_seeg/outputs/plotting_outputs/sentence_dynamic_plots 5 | #output_dir: /storage/czw/self_supervised_seeg/outputs/plotting_outputs/sentence_dynamic_plots_random 6 | -------------------------------------------------------------------------------- /conf/plot/plot_superlet_embed.yaml: -------------------------------------------------------------------------------- 1 | #upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-23/01-17-59/checkpoint_last.pth #trained weights 2 | upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-06-25/01-33-50/checkpoint_last.pth #random weights 3 | dim_reduce: tsne 4 | -------------------------------------------------------------------------------- /conf/plot/write_region_embedding.yaml: -------------------------------------------------------------------------------- 1 | upstream_ckpt: /storage/czw/self_supervised_seeg/pretrained_weights/stft_large_pretrained.pth #trained on all subject3 electrodes 2 | output_dir: /storage/czw/self_supervised_seeg/outputs/plotting_outputs/region_tsne 3 | -------------------------------------------------------------------------------- /conf/preprocessor/spec_pooled.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czlwang/BrainBERT/711983690a4c5502038c70473ef0c4f450e1f7fe/conf/preprocessor/spec_pooled.yaml -------------------------------------------------------------------------------- /conf/preprocessor/stft.yaml: -------------------------------------------------------------------------------- 1 | name: stft 2 | freq_channel_cutoff: 40 3 | nperseg: 400 4 | noverlap: 350 5 | normalizing: zscore 6 | -------------------------------------------------------------------------------- /conf/preprocessor/stft_pooled_preprocessor.yaml: -------------------------------------------------------------------------------- 1 | name: spec_pooled_preprocessor 2 | spec_name: stft 3 | freq_channel_cutoff: 40 4 | nperseg: 400 5 | noverlap: 350 6 | normalizing: zscore 7 | -------------------------------------------------------------------------------- /conf/preprocessor/stft_pretrained.yaml: -------------------------------------------------------------------------------- 1 | #This config can only be used when the upstream is frozen 2 | name: spec_pretrained #TODO check if these are the params that the model was pretrained with 3 | spec_name: stft 4 | freq_channel_cutoff: 40 5 | nperseg: 400 6 | noverlap: 350 7 | normalizing: zscore 8 | upstream_ckpt: /storage/czw/self_supervised_seeg/outputs/2022-07-19/17-25-17/checkpoint_last.pth #trained on all subject3 electrodes 9 | -------------------------------------------------------------------------------- /conf/preprocessor/superlet.yaml: -------------------------------------------------------------------------------- 1 | name: superlet 2 | c1: 1 3 | order_min: 3 4 | order_max: 30 5 | decim: 50 6 | min_f: 0.1 7 | max_f: 200 8 | n_f_steps: 40 #foi = linspace(f_min, f_max, n_f_steps) 9 | -------------------------------------------------------------------------------- /conf/preprocessor/superlet_pooled_preprocessor.yaml: 
-------------------------------------------------------------------------------- 1 | name: spec_pooled_preprocessor 2 | spec_name: superlet 3 | c1: 1 4 | order_min: 3 5 | order_max: 30 6 | decim: 50 7 | min_f: 0.1 8 | max_f: 200 9 | n_f_steps: 40 #foi = linspace(f_min, f_max, n_f_steps) 10 | -------------------------------------------------------------------------------- /conf/preprocessor/superlet_pretrained.yaml: -------------------------------------------------------------------------------- 1 | #This config can only be used when the upstream is frozen 2 | name: spec_pretrained #TODO check if these are the params that the model was pretrained with 3 | spec_name: superlet 4 | c1: 1 5 | order_min: 3 6 | order_max: 30 7 | decim: 50 8 | min_f: 0.1 9 | max_f: 200 10 | n_f_steps: 40 #foi = linspace(f_min, f_max, n_f_steps) 11 | upstream_ckpt: ??? 12 | -------------------------------------------------------------------------------- /conf/preprocessor/wav_preprocessor.yaml: -------------------------------------------------------------------------------- 1 | name: wav_preprocessor 2 | sample_rate: 2048 3 | 4 | -------------------------------------------------------------------------------- /conf/task/baseline_task.yaml: -------------------------------------------------------------------------------- 1 | name: baseline_task 2 | 3 | -------------------------------------------------------------------------------- /conf/task/baseline_wav_task.yaml: -------------------------------------------------------------------------------- 1 | name: baseline_wav_task 2 | -------------------------------------------------------------------------------- /conf/task/feature_extract.yaml: -------------------------------------------------------------------------------- 1 | name: feature_extract_task 2 | 3 | -------------------------------------------------------------------------------- /conf/task/finetune_task.yaml: -------------------------------------------------------------------------------- 1 | name: finetune_task 2 | -------------------------------------------------------------------------------- /conf/task/fixed_mask_pretrain.yaml: -------------------------------------------------------------------------------- 1 | name: spec_pretrain 2 | mask_type: fixed 3 | #data: ??? 4 | time_mask_consecutive_min: 1 5 | time_mask_consecutive_max: 5 6 | freq_mask_consecutive_min: 1 7 | freq_mask_consecutive_max: 2 8 | time_mask_p: 0.10 9 | freq_mask_p: 0.10 10 | -------------------------------------------------------------------------------- /conf/task/variable_mask_pretrain.yaml: -------------------------------------------------------------------------------- 1 | name: spec_pretrain 2 | mask_type: variable 3 | #data: ??? 
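# compare fixed_mask_pretrain.yaml above, where mask spans are given directly as
# consecutive time/freq bins; here the mask is parameterized by a frequency grid
# (min_f..max_f over n_freq_steps) -- see tasks/spec_pretrain.py for the exact scheme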
4 | max_f: 250 5 | min_f: 0.01 6 | mask_p: 0.10 7 | n_freq_steps: 40 8 | -------------------------------------------------------------------------------- /conf/test/debug_test.yaml: -------------------------------------------------------------------------------- 1 | test_split_path: "/storage/czw/BrainBERT/data/test_split_trials.json" #these are the trials (movies) you want to pull data from 2 | test_electrodes_path: "/storage/czw/BrainBERT/linear_results/" #this is the list of electrodes you want to fine-tune/feature extract over 3 | -------------------------------------------------------------------------------- /conf/test/effective_dim.yaml: -------------------------------------------------------------------------------- 1 | upstream_ckpt: /storage/czw/self_supervised_seeg/pretrained_weights/stft_large_pretrained.pth #trained on all subject3 electrodes 2 | raw_spec: False 3 | dim_reduce: pca 4 | n_components: 200 5 | test_split_path: /storage/czw/self_supervised_seeg/data/test_split_trials.json 6 | out_dir: /storage/czw/self_supervised_seeg/effective_dim_results/full_brain 7 | -------------------------------------------------------------------------------- /conf/test/effective_dim_raw_spec.yaml: -------------------------------------------------------------------------------- 1 | upstream_ckpt: /storage/czw/self_supervised_seeg/pretrained_weights/stft_large_pretrained.pth #trained on all subject3 electrodes 2 | raw_spec: True 3 | dim_reduce: pca 4 | n_components: 40 5 | test_split_path: /storage/czw/self_supervised_seeg/data/test_split_trials.json 6 | out_dir: /storage/czw/self_supervised_seeg/effective_dim_results/stft_only 7 | 8 | -------------------------------------------------------------------------------- /conf/test/fewshot_test.yaml: -------------------------------------------------------------------------------- 1 | out_dir: /storage/czw/self_supervised_seeg/fewshot_test/ 2 | test_runs: 3 3 | test_split_path: "/storage/czw/self_supervised_seeg/data/test_split_trials.json" 4 | test_ex_min: 50 5 | test_ex_max: 1050 6 | test_ex_step: 50 7 | -------------------------------------------------------------------------------- /conf/test/held_out_subjects.yaml: -------------------------------------------------------------------------------- 1 | test_split_path: "/storage/czw/self_supervised_seeg/data/test_split_trials.json" 2 | test_electrodes_path: "/storage/czw/self_supervised_seeg/linear_results/" 3 | -------------------------------------------------------------------------------- /conf/test/process_linear_results.yaml: -------------------------------------------------------------------------------- 1 | linear_results_path: /storage/czw/self_supervised_seeg/outputs/2022-08-30/01-14-20/all_test_results 2 | out_dir: /storage/czw/self_supervised_seeg/linear_results/speech_finetuning 3 | topk: 25 4 | -------------------------------------------------------------------------------- /conf/test/single_electrode_test.yaml: -------------------------------------------------------------------------------- 1 | output_dir: ??? 
2 | 3 | -------------------------------------------------------------------------------- /criterions/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from pathlib import Path 4 | 5 | CRITERION_REGISTRY = {} 6 | 7 | __all__ = ["build_criterion"] 8 | 9 | def build_criterion(cfg): 10 | criterion_name = cfg.name 11 | assert criterion_name in CRITERION_REGISTRY 12 | criterion = CRITERION_REGISTRY[criterion_name]() 13 | criterion.build_criterion(cfg) 14 | return criterion 15 | 16 | def register_criterion(name): 17 | def register_criterion_cls(cls): 18 | if name in CRITERION_REGISTRY: 19 | raise ValueError(f'{name} already in registry') 20 | else: 21 | CRITERION_REGISTRY[name] = cls 22 | return cls 23 | return register_criterion_cls 24 | 25 | def import_criterions(): 26 | for file in os.listdir(os.path.dirname(__file__)): 27 | if file.endswith(".py") and not file.startswith("_"): 28 | module_name = str(Path(file).with_suffix("")) 29 | importlib.import_module('criterions.'+module_name) 30 | 31 | import_criterions() 32 | -------------------------------------------------------------------------------- /criterions/base_criterion.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class BaseCriterion(nn.Module): 4 | def __init__(self): 5 | super(BaseCriterion, self).__init__() 6 | pass 7 | 8 | def build_criterion(self, cfg): 9 | raise NotImplementedError 10 | 11 | def forward(self, model, batch, device): 12 | raise NotImplementedError 13 | -------------------------------------------------------------------------------- /criterions/baseline_criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_criterion import BaseCriterion 3 | from torch import nn 4 | from criterions import register_criterion 5 | 6 | @register_criterion("baseline_criterion") 7 | class BaselineCriterion(BaseCriterion): 8 | def __init__(self): 9 | super(BaselineCriterion, self).__init__() 10 | pass 11 | 12 | def build_criterion(self, cfg): 13 | self.cfg = cfg 14 | self.sigmoid = nn.Sigmoid() 15 | self.loss_fn = nn.BCEWithLogitsLoss(reduction="mean") 16 | 17 | def forward(self, model, batch, device, return_predicts=False): 18 | inputs = batch["input"].to(device) #potentially don't move to device if dataparallel 19 | 20 | output = model.forward(inputs) 21 | labels = torch.FloatTensor(batch["labels"]).to(output.device) 22 | 23 | output = output.squeeze(-1) 24 | loss = self.loss_fn(output, labels) 25 | images = {"wav": batch["input"][0], 26 | "wav_label": batch["labels"][0]} 27 | if return_predicts: 28 | predicts = self.sigmoid(output).squeeze().detach().cpu().numpy() 29 | logging_output = {"loss": loss.item(), 30 | "predicts": predicts, 31 | "images": images} 32 | else: 33 | logging_output = {"loss": loss.item(), 34 | "images": images} 35 | return loss, logging_output 36 | 37 | -------------------------------------------------------------------------------- /criterions/feature_extract_criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_criterion import BaseCriterion 3 | from torch import nn 4 | from criterions import register_criterion 5 | 6 | @register_criterion("feature_extract_criterion") 7 | class FeatureExtractCriterion(BaseCriterion): 8 | def __init__(self): 9 | super(FeatureExtractCriterion, self).__init__() 10 | pass 11 | 12 | def 
build_criterion(self, cfg): 13 | self.cfg = cfg 14 | self.sigmoid = nn.Sigmoid() 15 | self.loss_fn = nn.BCEWithLogitsLoss(reduction="mean") 16 | 17 | def forward(self, model, batch, device, return_predicts=False): 18 | #TODO fix the dataset here. 19 | inputs = batch["input"].to(device) #potentially don't move to device if dataparallel 20 | output = model.forward(inputs) 21 | labels = torch.FloatTensor(batch["labels"]).to(output.device) 22 | output = output.squeeze(-1) 23 | loss = self.loss_fn(output, labels) 24 | images = {"wav": batch["wavs"][0], 25 | "wav_label": batch["labels"][0]} 26 | if return_predicts: 27 | predicts = self.sigmoid(output).squeeze().detach().cpu().numpy() 28 | logging_output = {"loss": loss.item(), 29 | "predicts": predicts, 30 | "images": images} 31 | else: 32 | logging_output = {"loss": loss.item(), 33 | "images": images} 34 | return loss, logging_output 35 | 36 | -------------------------------------------------------------------------------- /criterions/finetune_criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_criterion import BaseCriterion 3 | from torch import nn 4 | from criterions import register_criterion 5 | 6 | @register_criterion("finetune_criterion") 7 | class FinetuneCriterion(BaseCriterion): 8 | def __init__(self): 9 | super(FinetuneCriterion, self).__init__() 10 | pass 11 | 12 | def build_criterion(self, cfg): 13 | self.cfg = cfg 14 | self.sigmoid = nn.Sigmoid() 15 | self.loss_fn = nn.BCEWithLogitsLoss(reduction="mean") 16 | 17 | def forward(self, model, batch, device, return_predicts=False): 18 | inputs = batch["input"].to(device) #potentially don't move to device if dataparallel 19 | pad_mask = batch["pad_mask"].to(device) 20 | 21 | output = model.forward(inputs, pad_mask) 22 | labels = torch.FloatTensor(batch["labels"]).to(output.device) 23 | output = output.squeeze(-1) 24 | loss = self.loss_fn(output, labels) 25 | images = {"wav": batch["wavs"][0], 26 | "wav_label": batch["labels"][0]} 27 | if return_predicts: 28 | predicts = self.sigmoid(output).squeeze().detach().cpu().numpy() 29 | logging_output = {"loss": loss.item(), 30 | "predicts": predicts, 31 | "images": images} 32 | else: 33 | logging_output = {"loss": loss.item(), 34 | "images": images} 35 | return loss, logging_output 36 | -------------------------------------------------------------------------------- /criterions/pretrain_masked_criterion.py: -------------------------------------------------------------------------------- 1 | from .base_criterion import BaseCriterion 2 | import torch 3 | from torch import nn 4 | from criterions import register_criterion 5 | 6 | @register_criterion("pretrain_masked_criterion") 7 | class PretrainMaskedCriterion(BaseCriterion): 8 | def __init__(self): 9 | super(PretrainMaskedCriterion, self).__init__() 10 | pass 11 | 12 | def build_criterion(self, cfg): 13 | self.cfg = cfg 14 | 15 | def forward(self, model, batch, device): 16 | #x = batch[:,:10] #potentially don't move to device if dataparallel 17 | pad_mask = batch["attn_mask"].to(device) 18 | masked_input = batch["masked_input"].to(device) #potentially don't move to device if dataparallel 19 | mask = batch["mask_label"].bool().to(device) 20 | output, pos_enc = model.forward(masked_input, pad_mask) 21 | labels = batch["target"].to(device) 22 | true_activity = labels.masked_select(mask) 23 | predicted = output.masked_select(mask) 24 | l1 = torch.mean(torch.abs(true_activity - predicted)) 25 | non_zero_idxs = 
torch.abs(true_activity)>1 26 | non_zero = torch.mean(torch.abs(true_activity[non_zero_idxs] - predicted[non_zero_idxs])) 27 | content_aware_loss = self.cfg.alpha*non_zero 28 | loss = l1 + content_aware_loss 29 | output_log_spec = output[1].detach().cpu() 30 | content_l1 = non_zero 31 | wav = batch["wavs"][1] 32 | images = {"input_spectrogram": masked_input[1].detach().cpu(), 33 | "ground_truth": labels[1].detach().cpu(), 34 | "pred_spectrogram": output_log_spec, 35 | "pos_enc": pos_enc[0].detach().cpu(), 36 | "wav": wav} 37 | logging_output = {"loss": loss.item(), 38 | "images": images, 39 | "l1_loss": l1.item(), 40 | "content_l1": content_l1.item(), 41 | "content_aware_loss": content_aware_loss.item()} 42 | return loss, logging_output 43 | 44 | -------------------------------------------------------------------------------- /criterions/seeg_wav2vec_criterion.py: -------------------------------------------------------------------------------- 1 | from .base_criterion import BaseCriterion 2 | import torch 3 | from torch import nn 4 | from criterions import register_criterion 5 | 6 | @register_criterion("seeg_wav2vec_criterion") 7 | class SeegWav2VecCriterion(BaseCriterion): 8 | def __init__(self): 9 | super(SeegWav2VecCriterion, self).__init__() 10 | pass 11 | 12 | def build_criterion(self, cfg): 13 | self.cfg = cfg 14 | 15 | def forward(self, model, batch, device): 16 | #x = batch[:,:10] #potentially don't move to device if dataparallel 17 | #NOTE: the original body was a debugging stub (print + pdb.set_trace) that returned undefined names 18 | raise NotImplementedError("seeg_wav2vec_criterion.forward is not implemented yet") 19 | 20 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czlwang/BrainBERT/711983690a4c5502038c70473ef0c4f450e1f7fe/data/__init__.py -------------------------------------------------------------------------------- /data/corrupted_elec.json: -------------------------------------------------------------------------------- 1 | {"subject1": ["F3dIe11", "F3dIe12", "F3dIe13", "F3dIe15", "F3dIe16", "T1cIf9", "T2c1", "T2c2", "T2c3", "T3aHb1", "T3aHb2", "T3aHb3", "T3aHb4", "T3aHb5", "T3aHb7", "T3aHb8", "T3aHb11", "T3bOT11", "TRIG4", "F3aOFa5", "F3aOFa6", "T2bHa2", "T2bHa6", "F3aOFa1", "T3bOT7"], "subject2": ["DC4", "DC10", "LT1c1", "LT2aA3", "LT3bHa1", "LT3bHa2", "LT3bHa3", "LT3bHa13", "LT3cHb1", "LT3cHb2", "LT3cHb3", "RT1c6", "RT2aA11#", "RT2aA12#", "RT2b6", "RT2b7", "RT2c3", "RT2c4", "RT2c5", "RT2c6", "RT3aHa1*", "RT3aHa2*", "RT3aHa3*", "RT3bHb1", "RT3bHb2", "RT3bHb3", "RT3bHb9", "RT3bHb13", "RT3bHb14"], "subject3": ["DC4", "DC10", "F2Ia15", "F2Ia16", "F3b5", "F3b6", "F3c7", "O1bId6", "O1bId7", "P2b2", "T1cIe3", "T1cIe4"], "subject8": ["DC4", "DC10", "F2bIc9", "F2bIc10", "F2bIc11", "F2c1", "F2c5", "F2c6", "F3IaOF6", "P1Cc12", "P1Cc13", "P1Cc14", "P1Cc16", "P2Ie1", "T2aHa2", "T2aHa3", "T2bHb3", "T2bHb4"], "subject9": ["DC4", "DC10", "F2aIb11", "F2aIb12", "F3aOFa10", "T2A5", "T2A6", "T2A7", "T2A8", "T3H1", "T3H2"], "subject10": ["DC4", "DC10", "F3a1", "F3a2", "F3a4", "P2c5", "P2c2", "P2c3", "T1c6"], "subject6": ["DC4", "DC10", "F1*0Fa5", "F1*0Fa6", "F1*0Fa7", "F1*0Fa13", "P2bIg1", "P2cCc16", "T2aA5", "T2aA6", "T2bH3"], "subject5": ["DC4", "DC10", "LT3Ha10", "LT1bId8", "LT1bId9"], "subject4": ["DC4", "DC10", "T2A5", "T2A6", "T2A7"], "subject7": ["DC4", "DC10", "RF1bCb8", "RF1bCb9", "RF1bCb10", "RF2I*14", "RF2I*15", "RF2I*16"]} --------------------------------------------------------------------------------
/data/create_data_dirs.py: -------------------------------------------------------------------------------- 1 | #example: 2 | #python3 -m data.create_data_dirs +data=pretraining +hydra.job.chdir=False 3 | import hydra 4 | from omegaconf import DictConfig, OmegaConf 5 | from hydra.core.hydra_config import HydraConfig 6 | from hydra.utils import get_original_cwd, to_absolute_path 7 | import logging 8 | from datasets import build_dataset 9 | from tqdm import tqdm as tqdm 10 | import os 11 | from scipy.io.wavfile import write 12 | from pathlib import Path 13 | import numpy as np 14 | 15 | log = logging.getLogger(__name__) 16 | 17 | def print_time(start): 18 | import time #fix: 'time' was used here without being imported 19 | return ((time.time() - start)/60) 20 | def write_dataset_to_dir(dataset, args, root_out): 21 | subject_name = args.subject 22 | run_name = args.brain_runs[0] 23 | subject_out = os.path.join(root_out, subject_name, run_name) 24 | Path(subject_out).mkdir(parents=True, exist_ok=True) 25 | 26 | trainable = 0 27 | samplerate = args.samp_frequency 28 | for i in tqdm(range(len(dataset))): 29 | file_name = os.path.join(subject_out, f'{i}.wav') 30 | data = dataset[i] 31 | data = data["input"] 32 | write(file_name, samplerate, data.astype(np.float64)) 33 | 34 | @hydra.main(config_path="../conf", version_base=None) 35 | def main(cfg: DictConfig) -> None: 36 | log.info("create data dir") 37 | log.info(OmegaConf.to_yaml(cfg, resolve=True)) 38 | print(HydraConfig.get().job.name) 39 | print(HydraConfig.get().run.dir) 40 | out_dir = HydraConfig.get().run.dir 41 | out_dir = to_absolute_path(out_dir) 42 | log.info(f'output directory {out_dir}') 43 | 44 | dataset = build_dataset(cfg.data) 45 | write_dataset_to_dir(dataset, cfg.data, out_dir) 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /data/edf2h5.py: -------------------------------------------------------------------------------- 1 | # example 2 | # python3 edf2h5.py --in_file /storage/abarbu/full-day-ecog/m00191/m00191-full-part-4.EDF --out_dir m00191_part4--output_timestamp_metadata 3 | from datetime import timedelta 4 | import json 5 | import time 6 | import math 7 | #import glob 8 | import os 9 | import h5py 10 | import numpy as np 11 | #from threading import Thread 12 | import mne 13 | from data.utils import compute_m5_hash #fix: the md5 helper lives in data/utils.py in this tree, not in a 'data_loading' package 14 | from args import parse_edf2h5_args 15 | 16 | def write_timestamp_metadata_file(file_name: str, out_dir=None, omit_prefix_percent=0) -> None: 17 | data = mne.io.read_raw_edf(file_name) 18 | info = data.info 19 | channels = data.ch_names 20 | timestamp = info["meas_date"] 21 | 22 | new_path = os.path.join(out_dir, 'time-stamp.json') 23 | idx = 0 24 | data_arr,_ = data[idx] #[1, n_samples] 25 | start_idx = int(omit_prefix_percent*data_arr.shape[-1]) 26 | sample_rate = 2048 27 | timestamp = timestamp + timedelta(seconds=start_idx/sample_rate) #fix: the offset was computed but never applied to the timestamp 28 | 29 | with open(new_path, 'w') as f: 30 | json.dump({'timestamp': f'{timestamp}'}, f) 31 | 32 | def write_transposed_data_file(file_name: str, orig_channel_n, out_dir=None, omit_prefix_percent=0) -> int: #fix: this function returns channel_count, not None 33 | data = mne.io.read_raw_edf(file_name) 34 | info = data.info 35 | channels = data.ch_names 36 | #raw_data = np.zeros([264,100]) 37 | #channels = [f'C{i}' for i in range(164)] + [f'DC{i}' for i in range(264-163)] 38 | 39 | new_path = os.path.join(out_dir, 'data-sharded.h5') 40 | computed_hash = compute_m5_hash(file_name) 41 | 42 | with h5py.File(new_path, 'a', libver='latest') as hf_new: 43 | new_group = hf_new.create_group('data') 44 | 
hf_new['data'].attrs['orig_data_hash'] = computed_hash 45 | channel_count = 0 46 | orig_channel_n = orig_channel_n.tolist() 47 | for i in orig_channel_n: 48 | idx = i-1 49 | ch_name = channels[idx] 50 | print(idx, ch_name) 51 | data_arr,_ = data[idx] #[1, n_samples] 52 | start_idx = int(omit_prefix_percent*data_arr.shape[-1]) 53 | data_arr = data_arr[0, start_idx:] 54 | data_arr = data_arr.squeeze() 55 | if ch_name in ['DC10', 'DC4']: #NOTE 56 | new_group.create_dataset(ch_name, data=data_arr, compression="gzip") 57 | channel_count += 1 58 | else: 59 | new_group.create_dataset(f'electrode_{idx}', data=data_arr, compression="gzip") 60 | channel_count += 1 61 | 62 | return channel_count 63 | 64 | if __name__=="__main__": 65 | start = time.time() 66 | args = parse_edf2h5_args() 67 | file = args.in_file 68 | subject = "m00191" 69 | trial = "trial000" 70 | 71 | headers_dir_format = '/storage/datasets/neuroscience/ecog/data-by-subject/{}/data/trials/{}/headers' 72 | 73 | def get_string_from_hdf5_reference(file, ref): 74 | return ''.join(chr(i) for i in file[ref[0]][:]) 75 | 76 | headers_dir = headers_dir_format.format(subject, trial) 77 | header_file_name = os.listdir(headers_dir)[0] 78 | header_file = h5py.File(os.path.join(headers_dir, header_file_name),'r') 79 | electrode_labels = [get_string_from_hdf5_reference(header_file, ref) for ref in header_file['channel_labels']] 80 | labels = [(i, e) for i, e in enumerate(electrode_labels)] 81 | 82 | orig_channel_n = np.array(header_file['orig_channel_n']).squeeze() 83 | orig_channel_n = orig_channel_n.astype('int32') 84 | 85 | assert args.omit_prefix <= 1.0 and args.omit_prefix >= 0 86 | omit_prefix_percent = args.omit_prefix #fix: define before branching; this was previously only set inside the metadata branch, leaving the call below with an undefined name 87 | if args.output_timestamp_metadata: 88 | write_timestamp_metadata_file(file, args.out_dir, omit_prefix_percent=omit_prefix_percent) 89 | exit() 90 | 91 | channel_count = write_transposed_data_file(file, orig_channel_n, args.out_dir, omit_prefix_percent=omit_prefix_percent) 92 | assert len(labels)==channel_count 93 | print(labels) 94 | print(file) 95 | total_time = time.time() - start 96 | print(f'that took {total_time/60} minutes') #fix: the value is in minutes, not seconds 97 | -------------------------------------------------------------------------------- /data/electrode_selection.py: -------------------------------------------------------------------------------- 1 | from glob import glob as glob 2 | from .utils import stem_electrode_name 3 | import os 4 | import json 5 | import h5py 6 | 7 | def get_all_laplacian_electrodes(elec_list): 8 | stems = [stem_electrode_name(e) for e in elec_list] 9 | def has_nbrs(stem, stems): 10 | (x,y) = stem 11 | return ((x,y+1) in stems) and ((x,y-1) in stems) 12 | laplacian_stems = [x for x in stems if has_nbrs(x, stems)] 13 | electrodes = [f'{x}{y}' for (x,y) in laplacian_stems] 14 | return electrodes 15 | 16 | def get_all_electrodes(subject, data_root=None): 17 | ''' 18 | returns list of electrodes in this subject and trial 19 | NOTE: the order of these labels is important.
Their position corresponds with a row in data.h5 20 | ''' 21 | electrode_labels_file = glob(os.path.join(data_root, "electrode_labels", subject, "electrode_labels.json")) 22 | assert len(electrode_labels_file)==1 23 | electrode_labels_file = electrode_labels_file[0] 24 | with open(electrode_labels_file, "r") as f: 25 | electrode_labels = json.load(f) 26 | strip_string = lambda x: x.replace("*","").replace("#","").replace("_","") 27 | electrode_labels = [strip_string(e) for e in electrode_labels] 28 | return electrode_labels 29 | 30 | def clean_electrodes(subject, electrodes, data_root=None): 31 | corrupted_electrodes_path = os.path.join(data_root, "corrupted_elec.json") 32 | with open(corrupted_electrodes_path, "r") as f: 33 | corrupted_elecs = json.load(f) 34 | corrupt = corrupted_elecs[subject] 35 | return list(set(electrodes).difference(corrupt)) 36 | 37 | def get_clean_laplacian_electrodes(subject, data_root=None): 38 | electrodes = get_all_electrodes(subject, data_root=data_root) 39 | electrodes = clean_electrodes(subject, electrodes, data_root=data_root) 40 | laplacian_electrodes = get_all_laplacian_electrodes(electrodes) 41 | return laplacian_electrodes 42 | 43 | def main(): 44 | with open("data/pretrain_split_trials.json", "r") as f: 45 | subjects = json.load(f) 46 | all_electrodes = [] 47 | for subject in subjects: 48 | electrodes = get_clean_laplacian_electrodes(subject) 49 | print(subject, len(electrodes)) 50 | all_electrodes += [(subject, e) for e in electrodes] 51 | print(len(all_electrodes)) 52 | 53 | if __name__=="__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /data/electrode_subject_data.py: -------------------------------------------------------------------------------- 1 | from scipy import signal, stats#TODO remove import 2 | import time 3 | import os 4 | import torch 5 | import string 6 | import numpy as np 7 | import h5py 8 | import logging 9 | from pathlib import Path 10 | # import numpy.typing as npt 11 | 12 | from torch.utils import data 13 | from .h5_data import H5Data 14 | from .h5_data_reader import H5DataReader 15 | from typing import Optional, List, Dict, Any, Tuple 16 | import pandas as pd 17 | from types import SimpleNamespace 18 | 19 | log = logging.getLogger(__name__) 20 | 21 | class ElectrodeSubjectData(): 22 | def __init__(self, subject, cfg) -> None: 23 | self.selected_electrodes = cfg.electrodes 24 | self.cfg = cfg 25 | self.neural_data, self.trials = self.get_subj_data(subject) 26 | 27 | def save_cache(self, cached_data_path, seeg_data, subject_data_type): 28 | assert len(self.selected_electrodes)==1 29 | e = self.selected_electrodes[0] 30 | data_dir = self.make_cached_data_array_file_name(e, subject_data_type) 31 | e_path = os.path.join(cached_data_path, data_dir) 32 | Path(e_path).mkdir(exist_ok=True, parents=True) 33 | data_array_path = os.path.join(e_path, "array.npy") 34 | np.save(data_array_path, seeg_data) 35 | 36 | def make_cached_data_array_file_name(self, electrode, subject_data_type): 37 | cfg = self.cfg 38 | return f'cache_{cfg.duration}_d_{cfg.delta}_del_{cfg.rereference}_{cfg.subject}_{electrode}_{subject_data_type}_{cfg.high_gamma}_hg' 39 | 40 | def load_from_cache(self, cached_data_path, subject_data_type): 41 | assert len(self.cfg.electrodes)==1 42 | 43 | data_dir = self.make_cached_data_array_file_name(self.cfg.electrodes[0], subject_data_type) 44 | arr = None 45 | data_array_path = os.path.join(cached_data_path, data_dir, "array.npy") 46 | if 
os.path.exists(data_array_path): 47 | arr = np.load(data_array_path) 48 | 49 | return arr 50 | 51 | 52 | def get_subj_data(self, subject): 53 | ''' 54 | returns: 55 | numpy array of words 56 | numpy array of shape [n_electrodes, n_words, n_samples] which holds the 57 | aligned data across all trials 58 | ''' 59 | 60 | seeg_data, trials = [], [] 61 | run_ids = self.cfg.brain_runs 62 | 63 | cached_data_path = self.cfg.get("cached_data_array", None) 64 | cache_name = "electrode_finetuning" 65 | reload_caches = self.cfg.get("reload_caches", False) 66 | use_cache = cached_data_path is not None and not reload_caches 67 | cache_exists = False 68 | if use_cache: 69 | data_dir = self.make_cached_data_array_file_name(self.cfg.electrodes[0], cache_name) 70 | data_array_path = os.path.join(cached_data_path, data_dir, "array.npy") 71 | cache_exists = os.path.exists(data_array_path) 72 | 73 | for run_id in run_ids: 74 | t = H5Data(subject, run_id, self.cfg) 75 | trials.append(t) 76 | reader = H5DataReader(t, self.cfg) 77 | 78 | log.info("Getting filtered data") 79 | if use_cache and cache_exists: 80 | continue 81 | else: 82 | seeg_trial_data = reader.get_filtered_data() 83 | seeg_data.append(seeg_trial_data) 84 | assert len(run_ids)==1 85 | if use_cache and cache_exists: 86 | seeg_data = self.load_from_cache(cached_data_path, cache_name) 87 | else: 88 | seeg_data = np.concatenate(seeg_data) 89 | cutoff_len = int(seeg_data.shape[-1] / (2048*self.cfg.duration))* 2048 * self.cfg.duration #how many whole duration-second windows fit at 2048 Hz? 90 | cutoff_len = int(cutoff_len) 91 | seeg_data = seeg_data[:,:cutoff_len] 92 | seeg_data = seeg_data.reshape([seeg_data.shape[0],-1, int(2048*self.cfg.duration)]) #NOTE hardcode 93 | 94 | if cached_data_path is not None: 95 | self.save_cache(cached_data_path, seeg_data, cache_name) 96 | return seeg_data, trials 97 | -------------------------------------------------------------------------------- /data/h5_data.py: -------------------------------------------------------------------------------- 1 | from glob import glob as glob 2 | from dateutil import parser 3 | import json 4 | from datetime import datetime 5 | import h5py 6 | import os 7 | from typing import Tuple, Dict, List 8 | from types import SimpleNamespace 9 | 10 | class H5Data(): 11 | def __init__(self, subject: str, run_id: str, cfg) -> None: 12 | ''' 13 | input: 14 | subject=subject id 15 | ''' 16 | 17 | self.subject_id = subject 18 | self.samp_frequency = cfg.samp_frequency 19 | dataset_dir = cfg.raw_brain_data_dir 20 | trial = run_id 21 | 22 | # Path to neural data h5 file 23 | self.neural_data_file = os.path.join(dataset_dir, f'all_subject_data/{subject}_{trial}.h5') 24 | 25 | # Path to brain regions csv file 26 | self.regions_file = os.path.join(dataset_dir, f'localization/{subject}/depth-wm.csv') 27 | 28 | electrode_labels_file = glob(os.path.join(dataset_dir, "electrode_labels", subject, "electrode_labels.json")) 29 | assert len(electrode_labels_file)==1 30 | electrode_labels_file = electrode_labels_file[0] 31 | self.electrode_labels_file = electrode_labels_file 32 | 33 | self.timestamp = self.get_timestamp() 34 | 35 | def get_timestamp(self): 36 | #NOTE: self.timestamp_data (a path to a time-stamp.json) is never set in this class; treat the timestamp file as optional so construction does not crash 37 | if not os.path.exists(getattr(self, "timestamp_data", "")): 38 | return None 39 | with open(self.timestamp_data, 'r') as f: 40 | d = json.load(f) 41 | timestamp = parser.parse(d["timestamp"]) 42 | return timestamp 43 | 44 | def get_brain_region_localization(self) -> List[str]: 45 | ''' 46 | returns list of electrodes in this subject and trial 47 | NOTE: the order of these labels
is important. Their position corresponds with a row in data.h5 48 | ''' 49 | with open(self.electrode_labels_file, "r") as f: 50 | electrode_labels = json.load(f) 51 | strip_string = lambda x: x.replace("*","").replace("#","").replace("_","") 52 | electrode_labels = [strip_string(e) for e in electrode_labels] 53 | return electrode_labels 54 | -------------------------------------------------------------------------------- /data/h5_data_reader.py: -------------------------------------------------------------------------------- 1 | from scipy.signal import hilbert, chirp 2 | from tqdm import tqdm 3 | import os 4 | import h5py 5 | import numpy as np 6 | import pandas as pd 7 | import scipy.stats 8 | # import numpy.typing as npt 9 | 10 | from typing import Optional, List, Tuple 11 | from scipy import signal, stats 12 | from .utils import compute_m5_hash, stem_electrode_name 13 | 14 | class H5DataReader: 15 | def __init__(self, trial_data, cfg) -> None: 16 | ''' 17 | Input: trial_data=ecog and word data to perform processing on 18 | ''' 19 | self.freqs_to_filter = [60, 120, 180, 240, 300, 360] 20 | 21 | self.trial_data = trial_data 22 | self.cfg = cfg 23 | self.selected_electrodes = cfg.electrodes 24 | if self.selected_electrodes==[]: 25 | self.selected_electrodes = trial_data.get_brain_region_localization() 26 | 27 | if self.cfg.rereference=="laplacian": 28 | self.adj_electrodes = self.get_all_adj_electrodes() 29 | else: 30 | assert self.cfg.rereference=="None" 31 | 32 | def get_adj_electrodes(self, name): 33 | labels = self.trial_data.get_brain_region_localization() 34 | all_electrode_stems = [stem_electrode_name(l) for l in labels] 35 | 36 | stem, num = stem_electrode_name(name) 37 | same_wire = [(s,n) for (s,n) in all_electrode_stems if s==stem] 38 | nbrs = [(stem, num+1), (stem, num-1)] 39 | nbrs = [n for n in nbrs if n in all_electrode_stems] 40 | assert len(nbrs)==2 41 | return [e+str(s) for (e,s) in nbrs] 42 | 43 | def get_all_adj_electrodes(self): 44 | nbrs = [self.get_adj_electrodes(n) for n in self.selected_electrodes] #TODO debug 45 | flat_nbrs = [x for y in nbrs for x in y] 46 | return list(set(flat_nbrs)) 47 | 48 | def notch_filter(self, data, freq, Q=30) -> np.ndarray: 49 | samp_frequency = self.trial_data.samp_frequency 50 | w0 = freq / (samp_frequency / 2) 51 | b, a = signal.iirnotch(w0, Q) 52 | y = signal.lfilter(b, a, data, axis = 1) 53 | return y 54 | 55 | def highpass_filter(self, data, freq, Q=30) -> np.ndarray: 56 | samp_frequency = self.trial_data.samp_frequency 57 | sos = signal.butter(Q, freq, 'highpass', fs=samp_frequency, output='sos') 58 | y = signal.sosfilt(sos, data, axis = 1) 59 | return y 60 | 61 | def band_filter(self, data, freqs, Q=30) -> np.ndarray: 62 | samp_frequency = self.trial_data.samp_frequency 63 | sos = signal.butter(Q, freqs, 'bandpass', analog=False, fs=samp_frequency, output='sos') 64 | y = signal.sosfilt(sos, data, axis = 1) 65 | return y 66 | 67 | def car_rereference(self, data_arr): 68 | all_ordered_labels = self.trial_data.get_brain_region_localization() 69 | selected = [(i,e) for i,e in enumerate(all_ordered_labels) if e in self.selected_electrodes] 70 | sel_idxs, sel_labels = zip(*selected) 71 | 72 | all_data = self.select_electrodes(all_ordered_labels) 73 | reref = data_arr - np.mean(all_data, axis=0) 74 | return reref 75 | 76 | def laplacian_rereference(self, data_arr): 77 | all_ordered_labels = self.trial_data.get_brain_region_localization() 78 | selected = [(i,e) for i,e in enumerate(all_ordered_labels) if e in self.selected_electrodes] 
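# (added note) 'selected' pairs each kept electrode with its row index in the ordered
# label list, so the re-referenced rows below can be matched back to their labels.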
79 | sel_idxs, sel_labels = zip(*selected) 80 | 81 | ordered_nbrs = [e for e in all_ordered_labels if e in self.adj_electrodes] 82 | label2idx = {v:k for (k,v) in enumerate(ordered_nbrs)} 83 | 84 | adj_data_arr = self.select_electrodes(self.adj_electrodes) 85 | sel_nbrs = [self.get_adj_electrodes(n) for n in sel_labels] 86 | sel_nbrs_idxs = [[label2idx[l] for l in nbr_list] for nbr_list in sel_nbrs] 87 | sel_nbr_data = [[adj_data_arr[i] for i in idx_list] for idx_list in sel_nbrs_idxs] 88 | sel_nbr_data = np.array(sel_nbr_data) 89 | sel_nbr_data = self.filter_data(sel_nbr_data) 90 | 91 | #sel_nbr_data is [n_electrodes, 2, n_samples] 92 | laplacian = data_arr - np.mean(sel_nbr_data, axis=1) 93 | return laplacian 94 | 95 | def filter_data(self, data_arr): 96 | for f in self.freqs_to_filter: 97 | data_arr = self.notch_filter(data_arr, f) 98 | if self.cfg.high_gamma: 99 | band_data_arr = self.band_filter(data_arr, [70,250], Q=5) 100 | data_arr = band_data_arr 101 | return data_arr 102 | 103 | def de_spike(self, y): 104 | fuzz = 125 #number of samples around the spike to subtract out 105 | scaling_factor = 0.95 #how much of the spike to remove 106 | 107 | zscored = scipy.stats.zscore(y) 108 | mask = (np.abs(zscored)>4) 109 | fuzzed_mask = np.sign(np.convolve(mask, np.ones(fuzz), mode='same')) 110 | de_spiked = y - (fuzzed_mask*y*scaling_factor) 111 | return de_spiked 112 | 113 | def get_ordered_electrodes(self, selected): 114 | labels = self.trial_data.get_brain_region_localization() 115 | for e in selected: 116 | assert e in labels 117 | re_ordered_electrodes = [e for i,e in enumerate(labels) if e in selected] 118 | return re_ordered_electrodes #fix: previously returned the unordered input list 119 | 120 | def get_filtered_data(self) -> np.ndarray: 121 | ''' 122 | filters out freqs from the trial data 123 | ''' 124 | data_arr = self.select_electrodes(self.selected_electrodes) 125 | 126 | data_arr = self.filter_data(data_arr) 127 | if self.cfg.rereference == "CAR": 128 | data_arr = self.car_rereference(data_arr) 129 | if self.cfg.rereference=="laplacian": 130 | data_arr = self.laplacian_rereference(data_arr) 131 | if self.cfg.normalization=="zscore": 132 | data_arr = scipy.stats.zscore(data_arr, axis=1) 133 | if self.cfg.normalization=="standard": 134 | #This should be the same as above in principle 135 | mean = np.mean(data_arr, axis=1, keepdims=True) #fix: keepdims so the per-electrode stats broadcast over samples 136 | std_dev = np.std(data_arr, axis=1, keepdims=True) 137 | eps = 0.0001 138 | data_arr = (data_arr - mean)/(std_dev + eps) #fix: the standardized result was previously computed but discarded 139 | if self.cfg.despike: 140 | for i in range(data_arr.shape[0]): 141 | data_arr[i] = self.de_spike(data_arr[i]) #fix: de_spike is a method on the reader, not on cfg 142 | return data_arr 143 | 144 | def select_electrodes(self, selected) -> np.ndarray: 145 | ''' 146 | Input: 147 | selected = list of electrode labels to extract 148 | Output: 149 | array of shape [n_selected_electrodes, n_samples] 150 | NOTE: rows follow the subject's stored electrode order (the order returned 151 | by get_brain_region_localization), not the order of 'selected' 152 | 153 | ''' 154 | labels = self.trial_data.get_brain_region_localization() 155 | for e in selected: 156 | assert e in labels 157 | 158 | indices = [i for i,e in enumerate(labels) if e in selected] 159 | 160 | assert len(indices) == len(selected) 161 | 162 | electrode_data = [] 163 | with h5py.File(self.trial_data.neural_data_file, 'r') as hf: 164 | raw_data = hf['data'] 165 | for i in indices: 166 | electrode_data.append(raw_data[f'electrode_{i}'][:]) 167 | electrode_data_arr = np.stack(electrode_data) 168 | 169 | return electrode_data_arr 170 | 
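# Usage sketch (assumption: cfg supplies electrodes, rereference, normalization,
# high_gamma, and despike fields, as the configs under conf/data suggest):
#   trial = H5Data("subject1", "trial000", cfg)
#   reader = H5DataReader(trial, cfg)
#   data = reader.get_filtered_data()  # [n_electrodes, n_samples], notch-filtered and re-referenced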
-------------------------------------------------------------------------------- /data/make_aligned_data_caches.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig, OmegaConf 2 | import hydra 3 | import models 4 | import tasks 5 | import logging 6 | import os 7 | from data.electrode_selection import get_clean_laplacian_electrodes 8 | from data.subject_data import SubjectData 9 | from data.speech_nonspeech_subject_data import NonLinguisticSubjectData, SentenceOnsetSubjectData 10 | import json 11 | 12 | log = logging.getLogger(__name__) 13 | 14 | def create_subj_cache(data_cfg, brain_runs, electrodes, cfg): 15 | cache_path = None 16 | if "cache_input_features" in data_cfg: 17 | cache_path = data_cfg.cache_input_features 18 | 19 | subject_test_results = {} 20 | data_cfg.electrodes = electrodes 21 | data_cfg.brain_runs = brain_runs 22 | if cache_path is not None: 23 | #cache_path needs to identify the pretrained model 24 | e_cache_path = os.path.join(cache_path, data_cfg.subject, data_cfg.name) #fix: the original referenced an undefined 'e'; all electrodes are loaded together, so cache at the subject/dataset level (assumption) 25 | log.info(f"logging input features in {e_cache_path}") 26 | data_cfg.cache_input_features = e_cache_path 27 | subj_data = SentenceOnsetSubjectData(data_cfg) 28 | 29 | @hydra.main(config_path="../conf") 30 | def main(cfg: DictConfig) -> None: 31 | log.info(f"Run testing for all electrodes in all test_subjects") 32 | log.info(OmegaConf.to_yaml(cfg, resolve=True)) 33 | log.info(f'Working directory {os.getcwd()}') 34 | 35 | test_split_path = cfg.exp.test_split_path 36 | with open(test_split_path, "r") as f: 37 | test_splits = json.load(f) 38 | 39 | data_cfg = cfg.data 40 | all_test_results = {} 41 | for subj in test_splits: 42 | log.info(f"Subject {subj}") 43 | data_cfg.subject = subj 44 | electrodes = get_clean_laplacian_electrodes(subj, data_root=data_cfg.raw_brain_data_dir) #fix: data_root was missing; assumption: the raw data root comes from the data config, as elsewhere in this repo 45 | create_subj_cache(data_cfg, test_splits[subj], electrodes, cfg) 46 | 47 | if __name__ == "__main__": 48 | main() 49 | 50 | -------------------------------------------------------------------------------- /data/modify_manifest.py: -------------------------------------------------------------------------------- 1 | #to remove subjects from a pretraining manifest 2 | # python3 -m data.modify_manifest +data=pretrain_wavs_from_disk ++data.data=pretrain_data/manifests +data_prep=remove_subjects ++data_prep.out_dir=/storage/czw/self_supervised_seeg/pretrain_data/manifests_no_subject3/ +preprocessor=wav_preprocessor 3 | from multiprocessing import Process, Queue 4 | from omegaconf import DictConfig, OmegaConf 5 | import hydra 6 | import models 7 | import tasks 8 | from datasets import build_dataset 9 | import logging 10 | import os 11 | from pathlib import Path 12 | import numpy as np 13 | from tqdm import tqdm as tqdm 14 | import csv 15 | import yaml 16 | import json 17 | 18 | #take dataset with wavs and write it to cache 19 | log = logging.getLogger(__name__) 20 | 21 | def write_manifest(root_dir, out_dir, cfg): 22 | #map the original file to the cached feature 23 | manifest_path = os.path.join(root_dir, "manifests") 24 | Path(manifest_path).mkdir(parents=True, exist_ok=True) 25 | old_manifest_path = os.path.join(manifest_path, "manifest.tsv") 26 | old_rows = [] 27 | with open(old_manifest_path, "r", newline="") as f: 28 | reader = csv.reader(f, delimiter='\t') 29 | for i, row in tqdm(enumerate(reader)): 30 | old_rows.append(row) 31 | 32 | manifest_path = os.path.join(out_dir, "manifests") 33 | Path(manifest_path).mkdir(parents=True, exist_ok=True) 34 | new_manifest_path = os.path.join(manifest_path, "manifest.tsv") 35
| log.info("Writing manifest") 36 | with open(new_manifest_path, "w", newline="") as f: 37 | writer = csv.writer(f, delimiter='\t') 38 | for i,row in tqdm(enumerate(old_rows)): 39 | if i==0: 40 | writer.writerow(row) 41 | else: 42 | path, size = row 43 | #fix: keep a row only if no removed subject appears in its path (the old per-subject loop wrote duplicate rows) 44 | if not any(s in path for s in cfg.data_prep.subjects_to_remove): 45 | writer.writerow(row) 46 | 47 | @hydra.main(version_base=None, config_path="../conf") 48 | def main(cfg: DictConfig) -> None: 49 | data_cfg = cfg.data 50 | dataset = build_dataset(data_cfg, preprocessor_cfg=cfg.preprocessor) 51 | 52 | assert hasattr(dataset, "files") 53 | 54 | files = dataset.files 55 | #random.shuffle(files) 56 | #files = files[:] 57 | root_dir = dataset.root_dir 58 | write_manifest(root_dir, cfg.data_prep.out_dir, cfg) 59 | 60 | if __name__=="__main__": 61 | main() 62 | 63 | -------------------------------------------------------------------------------- /data/pretrain_split_trials.json: -------------------------------------------------------------------------------- 1 | {"sub_1": ["trial000", "trial002"], "sub_2": ["trial000", "trial001", "trial002", "trial003", "trial004", "trial005"], "sub_3": ["trial001", "trial002"], "sub_8": ["trial000"], "sub_6": ["trial000", "trial001"], "sub_7": ["trial001"], "sub_4": ["trial002", "trial001"], "sub_5": ["trial000"], "sub_9": ["trial000"], "sub_10": ["trial001"]} 2 | -------------------------------------------------------------------------------- /data/speech_nonspeech_subject_data.py: -------------------------------------------------------------------------------- 1 | from scipy import signal, stats #TODO remove import 2 | import psutil 3 | import time 4 | import os 5 | import torch 6 | import string 7 | import numpy as np 8 | import h5py 9 | # import numpy.typing as npt 10 | 11 | from torch.utils import data 12 | from .trial_data import TrialData 13 | from .trial_data_reader import TrialDataReader 14 | from typing import Optional, List, Dict, Any, Tuple 15 | import pandas as pd 16 | from types import SimpleNamespace 17 | 18 | class NonLinguisticSubjectData(): 19 | def __init__(self, cfg) -> None: 20 | self.selected_electrodes = cfg.electrodes 21 | self.cfg = cfg 22 | self.labels, self.neural_data, self.trials = self.get_subj_data(cfg.subject) 23 | 24 | def get_subj_data(self, subject): 25 | words, seeg_data, trials = [], [], [] 26 | cached_transcript_aligns = self.cfg.cached_transcript_aligns 27 | for trial in self.cfg.brain_runs: 28 | if cached_transcript_aligns: #TODO: I want to make this automatic 29 | cached_transcript_aligns = os.path.join(cached_transcript_aligns, subject, trial) 30 | os.makedirs(cached_transcript_aligns, exist_ok=True) 31 | self.cfg.cached_transcript_aligns = cached_transcript_aligns 32 | t = TrialData(subject, trial, self.cfg) 33 | reader = TrialDataReader(t, self.cfg) 34 | 35 | duration = self.cfg.duration 36 | interval_duration = self.cfg.interval_duration 37 | seeg_trial_no_word_data, labels = reader.get_aligned_non_words_matrix(duration=duration, interval_duration=interval_duration) 38 | labels['movie_id'] = t.movie_id 39 | trials.append(t) 40 | words.append(labels) 41 | seeg_data.append(seeg_trial_no_word_data) 42 | 43 | neural_data = np.concatenate(seeg_data, axis=1) 44 | labels_df = pd.concat(words) #NOTE the index will not be unique, but the location will 45 | #TODO: pretty sure we are missing the get_subj_data method here 46 | return labels_df, neural_data, trials 47 | 48 | class SentenceOnsetSubjectData(): 49 | def __init__(self, cfg) -> None: 50 | self.selected_electrodes =
cfg.electrodes 51 | self.cfg = cfg 52 | self.labels, self.neural_data, self.trials = self.get_subj_data(cfg.subject) 53 | 54 | def get_subj_data(self, subject): 55 | words, seeg_data, trials = [], [], [] 56 | cached_transcript_aligns = self.cfg.cached_transcript_aligns 57 | for trial in self.cfg.brain_runs: 58 | if cached_transcript_aligns: #TODO: I want to make this automatic 59 | cached_transcript_aligns = os.path.join(cached_transcript_aligns, subject, trial) 60 | os.makedirs(cached_transcript_aligns, exist_ok=True) 61 | self.cfg.cached_transcript_aligns = cached_transcript_aligns 62 | t = TrialData(subject, trial, self.cfg) 63 | reader = TrialDataReader(t, self.cfg) 64 | 65 | duration = self.cfg.duration 66 | delta = self.cfg.delta 67 | interval_duration = self.cfg.interval_duration 68 | seeg_trial_no_word_data, labels = reader.get_aligned_speech_onset_matrix(duration=duration, interval_duration=interval_duration) 69 | labels['movie_id'] = t.movie_id 70 | trials.append(t) 71 | words.append(labels) 72 | seeg_data.append(seeg_trial_no_word_data) 73 | 74 | neural_data = np.concatenate(seeg_data, axis=1) 75 | labels_df = pd.concat(words) #NOTE the index will not be unique, but the location will 76 | #TODO: pretty sure we are missing the get_subj_data method here 77 | return labels_df, neural_data, trials 78 | -------------------------------------------------------------------------------- /data/subject_data.py: -------------------------------------------------------------------------------- 1 | from scipy import signal, stats#TODO remove import 2 | import time 3 | import os 4 | import torch 5 | import string 6 | import numpy as np 7 | import h5py 8 | # import numpy.typing as npt 9 | 10 | from torch.utils import data 11 | from .trial_data import TrialData 12 | from .trial_data_reader import TrialDataReader 13 | from typing import Optional, List, Dict, Any, Tuple 14 | import pandas as pd 15 | from types import SimpleNamespace 16 | 17 | class SubjectData(): 18 | def __init__(self, cfg) -> None: 19 | self.selected_electrodes = cfg.electrodes 20 | self.selected_words = cfg.words 21 | self.cfg = cfg 22 | self.words, self.neural_data, self.trials = self.get_subj_data(cfg.subject) 23 | 24 | def get_subj_data(self, subject): 25 | words, seeg_data, trials = [], [], [] 26 | cached_transcript_aligns = self.cfg.cached_transcript_aligns 27 | for trial in self.cfg.brain_runs: 28 | if cached_transcript_aligns: #TODO: I want to make this automatic 29 | cached_transcript_aligns = os.path.join(cached_transcript_aligns, subject, trial) 30 | os.makedirs(cached_transcript_aligns, exist_ok=True) 31 | self.cfg.cached_transcript_aligns = cached_transcript_aligns 32 | t = TrialData(subject, trial, self.cfg) 33 | reader = TrialDataReader(t, self.cfg) 34 | 35 | trial_words, seeg_trial_data = reader.get_aligned_predictor_matrix(duration=self.cfg.duration, delta=self.cfg.delta) 36 | assert (range(seeg_trial_data.shape[1]) == trial_words.index).all() 37 | trial_words['movie_id'] = t.movie_id 38 | trials.append(t) 39 | words.append(trial_words) 40 | seeg_data.append(seeg_trial_data) 41 | 42 | neural_data = np.concatenate(seeg_data, axis=1) 43 | #neural_data is [n_electrodes, n_words, n_samples] 44 | words_df = pd.concat(words) #NOTE the index will not be unique, but the location will 45 | return words_df, neural_data, trials 46 | -------------------------------------------------------------------------------- /data/test_split_trials.json: -------------------------------------------------------------------------------- 1 
| {"m00183": ["trial001"], "m00184": ["trial006"], "m00185": ["trial000"], "m00191": ["trial004"], "m00187": ["trial000"], "m00195":["trial000"], "m00192":["trial000"]} 2 | -------------------------------------------------------------------------------- /data/throw_out_zeros.py: -------------------------------------------------------------------------------- 1 | #to write features to disk 2 | # python3 -m data.throw_out_zeros +data=pretrain_wavs_from_disk ++data.data=pretrain_data/manifests +preprocessor=wav_preprocessor +data_prep=throw_out_zeros 3 | from multiprocessing import Process, Queue 4 | from omegaconf import DictConfig, OmegaConf 5 | import hydra 6 | import models 7 | import tasks 8 | from pretrain.runner import Runner 9 | from datasets import build_dataset 10 | import logging 11 | import os 12 | from pathlib import Path 13 | import numpy as np 14 | from tqdm import tqdm as tqdm 15 | import csv 16 | import yaml 17 | import json 18 | 19 | #take dataset with wavs and write it to cache 20 | log = logging.getLogger(__name__) 21 | 22 | def check_features(pid, root_dir, files, cfg, qq): 23 | bad_files = [] 24 | for raw_file_path in tqdm(files): 25 | absolute_path = os.path.join(root_dir, raw_file_path) 26 | inputs = np.load(absolute_path) 27 | #print(inputs.shape) 28 | inputs = inputs.astype('float32') 29 | if (inputs[0]==inputs).all(): 30 | print("BAD FILE", raw_file_path) 31 | bad_files.append(raw_file_path) 32 | bad_files_q = qq.get() 33 | bad_files_q += bad_files 34 | qq.put(bad_files_q) 35 | 36 | def write_manifest(root_dir, all_bad_files): 37 | #map the original file to the cached feature 38 | manifest_path = os.path.join(root_dir, "manifests") 39 | Path(manifest_path).mkdir(parents=True, exist_ok=True) 40 | old_manifest_path = os.path.join(manifest_path, "manifest.tsv") 41 | old_rows = [] 42 | with open(old_manifest_path, "r", newline="") as f: 43 | reader = csv.reader(f, delimiter='\t') 44 | for i, row in tqdm(enumerate(reader)): 45 | old_rows.append(row) 46 | 47 | new_manifest_path = os.path.join(manifest_path, "new_manifest.tsv") 48 | log.info("Writing manifest") 49 | with open(new_manifest_path, "w", newline="") as f: 50 | writer = csv.writer(f, delimiter='\t') 51 | for i,row in tqdm(enumerate(old_rows)): 52 | if i==0: 53 | writer.writerow(row) 54 | else: 55 | path, size = row 56 | if path not in all_bad_files: 57 | writer.writerow(row) 58 | 59 | @hydra.main(version_base=None, config_path="../conf") 60 | def main(cfg: DictConfig) -> None: 61 | log.info("Writing data to disk") 62 | data_cfg = cfg.data 63 | dataset = build_dataset(data_cfg, preprocessor_cfg=cfg.preprocessor) 64 | 65 | assert hasattr(dataset, "files") 66 | 67 | files = dataset.files 68 | random.shuffle(files) 69 | files = files[:] 70 | root_dir = dataset.root_dir 71 | 72 | ps = [] 73 | n=cfg.data_prep.n_workers 74 | step = int(len(files)/n) + 1 75 | ranges = [(i*step, (i+1)*step) for i in range(n)] 76 | qq = Queue() 77 | qq.put([]) 78 | for index in range(n): 79 | start, end = ranges[index] 80 | idx_slice = files[start:end] 81 | log.info(f'Main : create and start process {index} with {start} to {end}') 82 | pid = index 83 | x = Process(target=check_features, args=(pid, root_dir, idx_slice, cfg, qq)) 84 | ps.append(x) 85 | x.start() 86 | 87 | for index, process in enumerate(ps): 88 | log.info(f'Main : before joining process {index}') 89 | process.join() 90 | log.info("Main : process %d done", index) 91 | all_bad_files = qq.get() 92 | 93 | out_path = "/storage/czw/self_supervised_seeg/data/all_zero_files_1.json" 
94 | with open(out_path, "w") as f: 95 | json.dump(all_bad_files, f) 96 | write_manifest(root_dir, all_bad_files) 97 | 98 | if __name__=="__main__": 99 | main() 100 | 101 | -------------------------------------------------------------------------------- /data/timestamped_subject_data.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from scipy import signal, stats#TODO remove import 3 | import psutil 4 | import time 5 | import pytz 6 | import os 7 | import torch 8 | import string 9 | import numpy as np 10 | import h5py 11 | # import numpy.typing as npt 12 | 13 | from torch.utils import data 14 | from .h5_data import H5Data 15 | from .h5_data_reader import H5DataReader 16 | from typing import Optional, List, Dict, Any, Tuple 17 | import pandas as pd 18 | from types import SimpleNamespace 19 | 20 | class TimestampedSubjectData(): 21 | def __init__(self, cfg) -> None: 22 | self.selected_electrodes = cfg.electrodes 23 | self.selected_words = cfg.words 24 | self.cfg = cfg 25 | self.neural_data, self.trials, self.labels = self.get_subj_data(cfg.subject) 26 | 27 | #convert everything to US east coast time 28 | est = pytz.timezone('US/Eastern') 29 | self.labels = [t.astimezone(est) for t in self.labels] 30 | 31 | #Only get the night hours for a single day 32 | days = np.array([x.day for x in self.labels]) 33 | hours = np.array([x.hour for x in self.labels]) 34 | single_day_idxs = days==(days[0]+1) #hand selected day 35 | night_idxs = (hours >= 1) & (hours <= 5) 36 | single_night_idxs = (night_idxs) & (single_day_idxs) 37 | night_samples = np.array(self.labels)[single_night_idxs] 38 | self.labels = night_samples 39 | assert self.neural_data.shape[0]==1 40 | self.neural_data = self.neural_data[:,single_night_idxs] 41 | 42 | def get_subj_data(self, subject): 43 | seeg_data, trials, timestamps = [], [], [] 44 | for trial in self.cfg.brain_runs: 45 | t = H5Data(subject, trial, self.cfg) 46 | reader = H5DataReader(t, self.cfg) 47 | 48 | timestamp = t.get_timestamp() 49 | seeg_trial_data = reader.get_filtered_data() 50 | trials.append(t) 51 | duration = self.cfg.duration 52 | 53 | cutoff_len = int(seeg_trial_data.shape[-1] / (2048*duration))* 2048 * duration #how many samples should we take? 
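# e.g. with duration=3 at 2048 Hz: keep the largest multiple of 3*2048 samples so
# the recording reshapes cleanly into [n_electrodes, n_segments, 3*2048] below.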
54 | cutoff_len = int(cutoff_len) 55 | seeg_trial_data = seeg_trial_data[:,:cutoff_len] 56 | seeg_trial_data = seeg_trial_data.reshape([seeg_trial_data.shape[0],-1, int(2048*duration)]) #NOTE hardcode 57 | trial_timestamps = [timestamp + timedelta(seconds=int(duration*i)) for i in range(seeg_trial_data.shape[1])] 58 | 59 | timestamps += trial_timestamps 60 | seeg_data.append(seeg_trial_data) 61 | 62 | assert len(self.cfg.brain_runs)==1 63 | seeg_data = np.concatenate(seeg_data) 64 | return seeg_data, trials, timestamps 65 | -------------------------------------------------------------------------------- /data/trial_data.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import os 3 | import json 4 | import pandas as pd 5 | from typing import Tuple, Dict, List 6 | from types import SimpleNamespace 7 | import string 8 | from .h5_data import H5Data 9 | 10 | class TrialData(H5Data): 11 | def __init__(self, subject: str, trial, cfg) -> None: 12 | ''' 13 | input: 14 | subject=subject id 15 | trial=trial id 16 | data_dir=path to ecog data 17 | ''' 18 | super().__init__(subject, trial, cfg) 19 | self.trial_id = trial 20 | dataset_dir = cfg.raw_brain_data_dir 21 | 22 | # Path to trigger times csv file 23 | self.trigger_times_file = os.path.join(dataset_dir,f'data-by-subject/{subject}/data/trials/{trial}/trigger-times.csv') 24 | 25 | # Path to trial metadata json file 26 | self.metadata_file = os.path.join(dataset_dir,f'data-by-subject/{subject}/data/trials/{trial}/metadata.json') 27 | 28 | self.movie_id, _ = self.get_metadata() 29 | 30 | # Path to transcript csv file 31 | self.transcript_file = os.path.join(dataset_dir, f'transcripts/{self.movie_id}/manual/word-times-stanford.csv') 32 | 33 | def get_trigger_times(self) -> pd.DataFrame: 34 | ''' 35 | returns the trigger times for this subject and trial 36 | ''' 37 | trigs_df = pd.read_csv(self.trigger_times_file) 38 | return trigs_df 39 | 40 | def get_metadata(self) -> Tuple[str, Dict]: 41 | ''' 42 | returns movie id and meta data dictionary 43 | ''' 44 | with open(self.metadata_file, 'r') as f: 45 | meta_dict = json.load(f) 46 | movie_id = meta_dict['filename'] 47 | return movie_id, meta_dict 48 | 49 | def get_movie_transcript(self) -> pd.DataFrame: 50 | ''' 51 | returns dataframe of every word in the movie 52 | importantly, includes onset times for words 53 | ''' 54 | words_df = pd.read_csv(self.transcript_file).set_index('Unnamed: 0') 55 | words_df = words_df.dropna().reset_index(drop=True) 56 | #words_df['text'] = list(map(str.lower, words_df['text'])) 57 | #words_df['text'] = list(map(lambda s: s.translate(str.maketrans('', '', string.punctuation)), words_df['text'])) 58 | return words_df 59 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import logging 4 | from tqdm import tqdm as tqdm 5 | import numpy as np 6 | import csv 7 | from pathlib import Path 8 | 9 | log = logging.getLogger(__name__) 10 | 11 | def file_as_bytes(file): 12 | with file: 13 | return file.read() 14 | 15 | def compute_m5_hash(file_path): 16 | #from https://stackoverflow.com/a/3431835 17 | return hashlib.md5(file_as_bytes(open(file_path, 'rb'))).hexdigest() 18 | 19 | def stem_electrode_name(name): 20 | #names look like 'O1aIb4', 'O1aIb5', 'O1aIb6', 'O1aIb7' 21 | #names look like 'T1b2 22 | reverse_name = reversed(name) 23 | found_stem_end = False 24 | 
stem, num = [], [] 25 | for c in reversed(name): 26 | if c.isalpha(): 27 | found_stem_end = True 28 | if found_stem_end: 29 | stem.append(c) 30 | else: 31 | num.append(c) 32 | return ''.join(reversed(stem)), int(''.join(reversed(num))) 33 | 34 | def write_manifest(root_out, paths, lengths): 35 | absolute_path = Path(root_out).resolve() 36 | manifest_path = os.path.join(absolute_path, "manifests") 37 | Path(manifest_path).mkdir(exist_ok=True, parents=True) 38 | manifest_path = os.path.join(manifest_path, "manifest.tsv") 39 | with open(manifest_path, "w", newline="") as f: 40 | writer = csv.writer(f, delimiter='\t') 41 | writer.writerow((str(absolute_path),)) 42 | for row in zip(paths, lengths): 43 | writer.writerow(row) 44 | 45 | 46 | -------------------------------------------------------------------------------- /data/write_data_to_disk.py: -------------------------------------------------------------------------------- 1 | #usage: 2 | #to write wavs to disk 3 | #python3 -m data.write_data_to_disk +data=pretraining.yaml +data_prep=wavs_to_disk ++data.prep.out_dir=all_day_data 4 | #python3 -m data.write_preprocessed_inputs +data=tf_unmasked +data_prep=cwt_to_disk ++data.data=all_day_data 5 | from omegaconf import DictConfig, OmegaConf 6 | import hydra 7 | import logging 8 | from pathlib import Path 9 | import os 10 | from datasets import build_dataset 11 | import numpy as np #fix: np, tqdm, and write_manifest are used below but were never imported 12 | from tqdm import tqdm 13 | from .utils import write_manifest 14 | 15 | log = logging.getLogger(__name__) 16 | 17 | def write_trial_data(root_out, trial_id, data_cfg): 18 | subject_id = data_cfg.subject 19 | data_cfg.brain_runs = [trial_id] 20 | dataset = build_dataset(data_cfg) 21 | log.info(f'Writing {trial_id}') 22 | paths, lengths = [], [] 23 | trial_absolute_path = os.path.join(root_out, subject_id, trial_id) 24 | Path(trial_absolute_path).mkdir(exist_ok=True, parents=True) 25 | for i in tqdm(range(len(dataset))): 26 | example = dataset[i]["input"].squeeze() 27 | file_name = f'{i}.npy' 28 | relative_path = os.path.join(subject_id, trial_id, file_name) 29 | save_path = os.path.join(trial_absolute_path, file_name) 30 | np.save(save_path, example) 31 | paths.append(str(relative_path)) 32 | lengths.append(example.shape[0]) 33 | return paths, lengths 34 | 35 | @hydra.main(version_base=None, config_path="../conf") 36 | def main(cfg: DictConfig) -> None: 37 | log.info("Writing data to disk") 38 | log.info(OmegaConf.to_yaml(cfg, resolve=True)) 39 | log.info(f'Working directory {os.getcwd()}') 40 | 41 | #NOTE: we are going to write the data in chunks.
There is a brain_runs argument in data_cfg that we will overwrite in our loop 42 | root_out = cfg.data_prep.out_dir 43 | data_cfg = cfg.data 44 | 45 | Path(root_out).mkdir(exist_ok=True, parents=True) 46 | paths, lengths = [], [] 47 | for trial_id in cfg.data_prep.brain_runs: 48 | p,l = write_trial_data(root_out, trial_id, data_cfg) 49 | paths = paths + p 50 | lengths = lengths + l 51 | write_manifest(root_out, paths, lengths) 52 | 53 | if __name__=="__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /data/write_preprocessed_inputs.py: -------------------------------------------------------------------------------- 1 | #to write features to disk 2 | #python3 -m data.write_preprocessed_inputs +data=tf_unmasked +data_prep=cwt_to_disk ++data.data=all_day_data 3 | from multiprocessing import Process 4 | from omegaconf import DictConfig, OmegaConf 5 | import hydra 6 | import models 7 | import tasks 8 | #fix: dropped 'from pretrain.runner import Runner' (unused here, and pretrain.runner does not exist in this tree) 9 | from datasets import build_dataset 10 | import logging 11 | import os 12 | from pathlib import Path 13 | import numpy as np 14 | from tqdm import tqdm as tqdm 15 | import csv 16 | import yaml 17 | 18 | #take dataset with wavs and write it to cache 19 | log = logging.getLogger(__name__) 20 | 21 | def write_features(pid, root_dir, files, extracter, cfg): 22 | absolute_root = Path(cfg.data_prep.out_dir).resolve() 23 | 24 | dirs_to_create = set([os.path.dirname(x) for x in files]) 25 | for out_dir in dirs_to_create: 26 | cached_absolute_path = os.path.join(absolute_root, out_dir) 27 | Path(cached_absolute_path).mkdir(parents=True, exist_ok=True) 28 | 29 | for raw_file_path in tqdm(files): 30 | cached_absolute_path = os.path.join(absolute_root, raw_file_path) 31 | absolute_path = os.path.join(root_dir, raw_file_path) 32 | inputs = np.load(absolute_path) 33 | feature = extracter(inputs) 34 | np.save(cached_absolute_path, feature) 35 | 36 | def write_manifest(files, cfg): 37 | #map the original file to the cached feature 38 | absolute_root = Path(cfg.data_prep.out_dir).resolve() 39 | paths = [] 40 | for raw_file_path in tqdm(files): 41 | paths.append((raw_file_path, raw_file_path)) 42 | 43 | manifest_path = os.path.join(absolute_root, "manifests") 44 | Path(manifest_path).mkdir(parents=True, exist_ok=True) 45 | cfg_path = os.path.join(manifest_path, "config.yaml") 46 | manifest_path = os.path.join(manifest_path, "manifest.tsv") 47 | log.info("Writing cfg") 48 | with open(cfg_path, 'w') as yamlfile: 49 | OmegaConf.save(config=cfg, f=yamlfile.name) 50 | 51 | log.info("Writing manifest") 52 | with open(manifest_path, "w", newline="") as f: 53 | writer = csv.writer(f, delimiter='\t') 54 | writer.writerow((absolute_root,)) 55 | writer.writerow(('orig_file', 'cached_feature')) 56 | for row in tqdm(paths): 57 | writer.writerow(row) 58 | 59 | @hydra.main(version_base=None, config_path="../conf") 60 | def main(cfg: DictConfig) -> None: 61 | log.info("Writing data to disk") 62 | data_cfg = cfg.data 63 | dataset = build_dataset(data_cfg, preprocessor_cfg=cfg.preprocessor) 64 | 65 | assert hasattr(dataset, "files") 66 | assert hasattr(dataset, "extracter") 67 | assert "preprocessor" in data_cfg 68 | 69 | extracter = dataset.extracter 70 | files = dataset.files 71 | root_dir = dataset.root_dir 72 | 73 | ps = [] 74 | n=cfg.data_prep.n_workers 75 | step = int(len(files)/n) + 1 76 | ranges = [(i*step, (i+1)*step) for i in range(n)] 77 | for index in range(n): 78 | start, end = ranges[index] 79 | idx_slice = files[start:end]
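# Each worker takes a contiguous slice of the file list; step is roughly
# ceil(len(files)/n_workers), so trailing slices may be empty, which is harmless.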
80 | log.info(f'Main : create and start process {index} with {start} to {end}') 81 | #write_features(root_dir, files, extracter, cfg) 82 | pid = index 83 | x = Process(target=write_features, args=(pid, root_dir, idx_slice, extracter, cfg)) 84 | ps.append(x) 85 | x.start() 86 | 87 | for index, process in enumerate(ps): 88 | log.info(f'Main : before joining process {index}') 89 | process.join() 90 | log.info("Main : process %d done", index) 91 | 92 | write_manifest(files, cfg) 93 | 94 | if __name__=="__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /data/write_pretrain_data_wavs.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Process 2 | from omegaconf import DictConfig, OmegaConf 3 | import hydra 4 | import logging 5 | from pathlib import Path 6 | import os 7 | import json 8 | from datasets import build_dataset 9 | from .electrode_selection import get_clean_laplacian_electrodes 10 | from .utils import write_manifest 11 | import csv 12 | import glob 13 | import time 14 | from tqdm import tqdm as tqdm 15 | import numpy as np 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | def write_manifest(manifest_path, root_out, paths, lengths): 20 | absolute_path = Path(root_out).resolve() 21 | manifest_path = os.path.join(manifest_path, "manifests") 22 | Path(manifest_path).mkdir(exist_ok=True, parents=True) 23 | manifest_path = os.path.join(manifest_path, "manifest.tsv") 24 | with open(manifest_path, "w", newline="") as f: 25 | writer = csv.writer(f, delimiter='\t') 26 | writer.writerow((str(absolute_path),)) 27 | for row in zip(paths, lengths): 28 | writer.writerow(row) 29 | 30 | def write_trial_data(root_out, trial_id, data_cfg): 31 | log.info(f'Writing {trial_id}') 32 | subject_id = data_cfg.subject 33 | data_cfg.brain_runs = [trial_id] 34 | electrodes = data_cfg.electrodes 35 | paths, lengths = [], [] 36 | trial_absolute_path = os.path.join(root_out, subject_id, trial_id) 37 | Path(trial_absolute_path).mkdir(exist_ok=True, parents=True) 38 | 39 | global_i = 0 40 | for electrode in tqdm(electrodes):#iterate over electrodes here to save memory 41 | data_cfg_copy = data_cfg.copy() 42 | data_cfg_copy.electrodes = [electrode] 43 | dataset = build_dataset(data_cfg_copy) 44 | for i in range(len(dataset)): 45 | print("index", global_i) 46 | example = dataset[i]["input"].squeeze() 47 | file_name = f'{global_i}.npy' 48 | relative_path = os.path.join(subject_id, trial_id, file_name) 49 | save_path = os.path.join(trial_absolute_path, file_name) 50 | np.save(save_path, example) 51 | paths.append(str(relative_path)) 52 | lengths.append(example.shape[0]) 53 | global_i += 1 54 | manifest_path = os.path.join(root_out, "manifests", subject_id, trial_id) 55 | Path(manifest_path).mkdir(exist_ok=True, parents=True) 56 | write_manifest(manifest_path, root_out, paths, lengths) 57 | return paths, lengths 58 | 59 | def write_absolute_manifests(root_out): 60 | absolute_path = Path(root_out).resolve() 61 | all_tsvs = glob.glob(os.path.join(absolute_path, "manifests/*/*/*/*.tsv")) 62 | header = "" 63 | all_rows = [] 64 | for f in all_tsvs: 65 | with open(f, "r") as fd: 66 | rd = csv.reader(fd, delimiter="\t", quotechar='"') 67 | for i, row in enumerate(rd): 68 | if i==0: 69 | header = row 70 | else: 71 | all_rows.append(row) 72 | 73 | manifest_path = os.path.join(root_out, "manifests", "manifest.tsv") 74 | with open(manifest_path, "w", newline="") as f: 75 | writer = csv.writer(f, 
delimiter='\t') 76 | writer.writerow(header) 77 | for row in all_rows: 78 | writer.writerow(row) 79 | 80 | def single_process(cfg, subject_splits): 81 | root_out = cfg.data_prep.out_dir 82 | data_cfg = cfg.data 83 | 84 | paths, lengths = [], [] 85 | Path(root_out).mkdir(exist_ok=True, parents=True) 86 | 87 | for subject in subject_splits: 88 | print("subject", subject) 89 | for trial in subject_splits[subject]: 90 | print("trial", trial) 91 | data_cfg.brain_runs=[trial] 92 | data_cfg.electrodes = get_clean_laplacian_electrodes(subject, data_root=cfg.data.raw_brain_data_dir) 93 | data_cfg.subject = subject 94 | write_trial_data(root_out, trial, data_cfg) 95 | 96 | @hydra.main(version_base=None, config_path="../conf") 97 | def main(cfg: DictConfig) -> None: 98 | log.info("Writing data to disk") 99 | log.info(OmegaConf.to_yaml(cfg, resolve=True)) 100 | log.info(f'Working directory {os.getcwd()}') 101 | 102 | start_time = time.time() 103 | pretrain_split_path = cfg.data_prep.pretrain_split 104 | with open(pretrain_split_path) as f: 105 | pretrain_split = json.load(f) 106 | 107 | subject_splits = {} 108 | for i,k in enumerate(pretrain_split): 109 | idx = i%2 110 | if idx not in subject_splits: 111 | subject_splits[idx] = {} 112 | subject_splits[idx][k] = pretrain_split[k] 113 | 114 | #subject splits maps process_id to a subset of pretrain split 115 | ps = [] 116 | for i in subject_splits: 117 | x = Process(target=single_process, args=(cfg, subject_splits[i])) 118 | ps.append(x) 119 | x.start() 120 | 121 | for index, process in enumerate(ps): 122 | log.info(f'Main : before joining process {index}') 123 | process.join() 124 | log.info("Main : process %d done", index) 125 | 126 | root_out = cfg.data_prep.out_dir 127 | write_absolute_manifests(root_out) 128 | end_time = time.time() 129 | log.info(f'total time {(end_time - start_time)/60} minutes') 130 | 131 | 132 | if __name__=="__main__": 133 | main() 134 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #from .raw_wav_file_dataset import RawWavFileDataset 2 | #from .debug_dataset import DebugDataset 3 | # 4 | #__all__ = ["RawWavFileDataset", 5 | # "DebugDataset"] 6 | 7 | import importlib 8 | import os 9 | from pathlib import Path 10 | 11 | DATASET_REGISTRY = {} 12 | 13 | __all__ = ["build_dataset"] 14 | 15 | def build_dataset(cfg, *args, **kwargs): 16 | dataset_name = cfg.name 17 | assert dataset_name in DATASET_REGISTRY 18 | dataset = DATASET_REGISTRY[dataset_name](cfg, *args, **kwargs) 19 | return dataset 20 | 21 | def register_dataset(name): 22 | def register_dataset_cls(cls): 23 | if name in DATASET_REGISTRY: 24 | raise ValueError(f'{name} already in registry') 25 | else: 26 | DATASET_REGISTRY[name] = cls 27 | return cls 28 | return register_dataset_cls 29 | 30 | def import_datasets(): 31 | for file in os.listdir(os.path.dirname(__file__)): 32 | if file.endswith(".py") and not file.startswith("_"): 33 | module_name = str(Path(file).with_suffix("")) 34 | importlib.import_module('datasets.'+module_name) 35 | 36 | import_datasets() 37 | -------------------------------------------------------------------------------- /datasets/base_tf_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from torch.utils import data 4 | import os 5 | import numpy as np 6 | from scipy.io import wavfile 7 | from datasets import register_dataset 8 | 
from preprocessors import STFTPreprocessor
  9 | 
 10 | class BaseTFDataset(data.Dataset):
 11 |     #Parent time-frequency dataset
 12 |     def __init__(self, cfg, task_cfg=None):
 13 |         extracter = STFTPreprocessor(cfg) #assumes cfg carries the STFT settings; the original referenced an undefined DebugPreprocessor
 14 |         self.cfg = cfg
 15 |         self.task_cfg = task_cfg
 16 |         manifest_path = cfg.data
 17 |         manifest_path = os.path.join(manifest_path, "train.tsv")
 18 |         with open(manifest_path, "r") as f:
 19 |             lines = f.readlines()
 20 |         self.root_dir = lines[0].strip()
 21 |         files, lengths = [], []
 22 |         for x in lines[1:]:
 23 |             row = x.strip().split('\t')
 24 |             files.append(row[0])
 25 |             lengths.append(row[1])
 26 |         self.files, self.lengths = files, lengths
 27 |         self.extracter = extracter
 28 | 
 29 |     def get_input_dim(self):
 30 |         item = self.__getitem__(0)
 31 |         return item["masked_input"].shape[-1]
 32 | 
 33 |     def mask_time(self, data):
 34 |         mask_label = torch.zeros_like(data)
 35 | 
 36 |         consecutive_min = self.task_cfg.time_mask_consecutive_min
 37 |         consecutive_max = self.task_cfg.time_mask_consecutive_max
 38 |         assert consecutive_min <= consecutive_max
 39 |         assert consecutive_max < data.shape[0]
 40 |         valid_starts = range(len(data)-consecutive_max)
 41 |         masked_steps = [i for i in valid_starts if random.random() < self.task_cfg.time_mask_p]
 42 |         masked_steps = [(i, i+random.randint(consecutive_min, consecutive_max)) for i in masked_steps]
 43 | 
 44 |         for (start,end) in masked_steps:
 45 |             mask_label[start:end,:] = 1
 46 | 
 47 |         masked_data = torch.clone(data)
 48 |         for (start,end) in masked_steps:
 49 |             if random.random() < 0.85: #NOTE: hardcoded probability; 85% of selected spans are zeroed, the rest are left intact
 50 |                 masked_data[start:end,:] = 0
 51 | 
 52 |         return masked_data, mask_label
 53 | 
 54 |     def __len__(self):
 55 |         return len(self.lengths)
 56 | 
 57 |     def __getitem__(self, idx):
 58 |         file_name = self.files[idx]
 59 |         file_name = os.path.join(self.root_dir, file_name)
 60 |         #raw_wave = np.load(file_name)
 61 |         samplerate, data = wavfile.read(file_name)
 62 | 
 63 |         data = data.astype('float32')
 64 |         #rand_len = random.randrange(1000, len(data), 1)
 65 |         rand_len = -1
 66 |         data = data[:rand_len] #rand_len = -1 keeps all but the final sample
 67 |         data = self.extracter(data)
 68 | 
 69 |         masked_data, mask_label = self.mask_time(data)
 70 | 
 71 |         return {"masked_input": masked_data,
 72 |                 "length": data.shape[0],
 73 |                 "mask_label": mask_label,
 74 |                 "target": data}
 75 | 
 76 | 
--------------------------------------------------------------------------------
/datasets/masked_tf_dataset.py:
--------------------------------------------------------------------------------
  1 | from omegaconf import OmegaConf
  2 | import torch
  3 | import random
  4 | from torch.utils import data
  5 | import os
  6 | import numpy as np
  7 | from scipy.io import wavfile
  8 | from datasets import register_dataset
  9 | from preprocessors import STFTPreprocessor
 10 | from util.mask_utils import mask_inputs
 11 | 
 12 | @register_dataset(name="masked_tf_dataset")
 13 | class MaskedTFDataset(data.Dataset):
 14 |     def __init__(self, cfg, task_cfg=None, preprocessor_cfg=None):
 15 |         #TODO
 16 |         #also make a masked_tf_dataset_from_cached variant
 17 |         self.cfg = cfg
 18 |         self.task_cfg = task_cfg
 19 |         manifest_path = cfg.data
 20 |         manifest_path = os.path.join(manifest_path, "manifest.tsv")
 21 |         with open(manifest_path, "r") as f:
 22 |             lines = f.readlines()
 23 |         self.root_dir = lines[0].strip()
 24 |         files, lengths = [], []
 25 |         for x in lines[1:]:
 26 |             row = x.strip().split('\t')
 27 |             files.append(row[0])
 28 |             lengths.append(row[1])
 29 |         self.files, self.lengths = files, lengths
 30 | 
 31 |         self.cached_features = None
 32 | 
 33 |         if 'cached_features' in cfg:
 34 |             self.cached_features = cfg.cached_features
 35 | 
self.initialize_cached_features(cfg.cached_features) 36 | elif preprocessor_cfg.name=="stft": 37 | extracter = STFTPreprocessor(preprocessor_cfg) 38 | self.extracter = extracter 39 | else: 40 | raise RuntimeError("Specify preprocessor") 41 | 42 | def initialize_cached_features(self, cache_root): 43 | cfg_path = os.path.join(cache_root, "config.yaml") 44 | loaded = OmegaConf.load(cfg_path) 45 | assert self.cfg.preprocessor == loaded.data.preprocessor 46 | 47 | manifest_path = os.path.join(cache_root, "manifest.tsv") 48 | with open(manifest_path, "r") as f: 49 | lines = f.readlines() 50 | self.cache_root_dir = lines[0].strip() 51 | orig2cached = {} #Map original file to cached feature file 52 | for x in lines[2:]: 53 | row = x.strip().split('\t') 54 | orig2cached[row[0]] = row[1] 55 | self.orig2cached = orig2cached 56 | 57 | def get_input_dim(self): 58 | item = self.__getitem__(0) 59 | return item["masked_input"].shape[-1] 60 | 61 | 62 | def __len__(self): 63 | return len(self.lengths) 64 | 65 | def get_cached_features(self, file_name): 66 | file_name = self.orig2cached[file_name] 67 | file_name = os.path.join(self.cache_root_dir, file_name) 68 | data = np.load(file_name) 69 | data = np.nan_to_num(data) #For superlet caches 70 | data = torch.FloatTensor(data) 71 | return data 72 | 73 | def __getitem__(self, idx): 74 | file_name = self.files[idx] 75 | file_path = os.path.join(self.root_dir, file_name) 76 | data = np.load(file_path) 77 | 78 | data = data.astype('float32') 79 | #rand_len = random.randrange(1000, len(data), 1) 80 | rand_len = -1 81 | wav = data[:rand_len] 82 | 83 | if self.cached_features: 84 | data = self.get_cached_features(file_name) 85 | else: 86 | data = self.extracter(wav) 87 | 88 | masked_data, mask_label = mask_inputs(data, self.task_cfg) 89 | return {"masked_input": masked_data, 90 | "length": data.shape[0], 91 | "mask_label": mask_label, 92 | "wav": wav, 93 | "target": data} 94 | -------------------------------------------------------------------------------- /datasets/pretraining_multi_elec_wavs_in_mem.py: -------------------------------------------------------------------------------- 1 | #python3 -m data.create_data_dirs +data=pretraining_subject3 +hydra.job.chdir=False 2 | from data.electrode_subject_data import ElectrodeSubjectData 3 | from torch.utils import data 4 | import numpy as np 5 | from datasets import register_dataset 6 | 7 | def get_electrode_subj_data(args): 8 | s = ElectrodeSubjectData(args.subject, args) 9 | return s 10 | 11 | @register_dataset(name="pretraining_wavs_multi_elec_in_mem") 12 | class PretrainingMultiElecWavsInMem(data.Dataset): 13 | #NOTE: this is to be used in pre-training, while the other class in this file is to be used during fine-tuning 14 | #Takes multiple electrodes and makes each one its own example 15 | def __init__(self, args) -> None: 16 | subject_data = get_electrode_subj_data(args) 17 | self.subject_data = subject_data 18 | self.seeg_data = subject_data.neural_data 19 | self.seeg_data = np.transpose(self.seeg_data, [1,0,2]) 20 | self.seeg_data = self.seeg_data.reshape([-1, self.seeg_data.shape[-1]]) 21 | 22 | def __len__(self): 23 | ''' 24 | returns: 25 | Number of words in the dataset 26 | ''' 27 | return self.seeg_data.shape[0] 28 | 29 | def __getitem__(self, idx: int): 30 | 31 | #NOTE: remember not to load to cuda here 32 | target = self.seeg_data[idx] 33 | return { 34 | "input" : target, 35 | } 36 | -------------------------------------------------------------------------------- /datasets/pretraining_wavs_in_mem.py: 
-------------------------------------------------------------------------------- 1 | from data.electrode_subject_data import ElectrodeSubjectData 2 | from torch.utils import data 3 | import numpy as np 4 | from datasets import register_dataset 5 | 6 | def get_electrode_subj_data(args): 7 | s = ElectrodeSubjectData(args.subject, args) 8 | return s 9 | 10 | @register_dataset(name="pretraining_wavs_in_mem") 11 | class PretrainingWavsInMem(data.Dataset): 12 | #NOTE: this is to be used in pre-training, while the other class in this file is to be used during fine-tuning 13 | def __init__(self, args) -> None: 14 | subject_data = get_electrode_subj_data(args) 15 | self.subject_data = subject_data 16 | self.seeg_data = subject_data.neural_data 17 | self.seeg_data = np.transpose(self.seeg_data, [1,0,2]) 18 | 19 | def __len__(self): 20 | ''' 21 | returns: 22 | Number of words in the dataset 23 | ''' 24 | return self.seeg_data.shape[0] 25 | 26 | def __getitem__(self, idx: int): 27 | 28 | #NOTE: remember not to load to cuda here 29 | target = self.seeg_data[idx] 30 | return { 31 | "input" : target, 32 | } 33 | -------------------------------------------------------------------------------- /datasets/raw_wav_file_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | import numpy as np 3 | from datasets import register_dataset 4 | import os 5 | 6 | @register_dataset(name="raw_wav_file_dataset") 7 | class RawWavFileDataset(data.Dataset): 8 | def __init__(self, cfg, task_cfg=None, preprocessor_cfg=None): 9 | self.cfg = cfg 10 | manifest_path = cfg.data 11 | manifest_path = os.path.join(manifest_path, "manifest.tsv") 12 | with open(manifest_path, "r") as f: 13 | lines = f.readlines() 14 | self.root_dir = lines[0].strip() 15 | files, lengths = [], [] 16 | for x in lines[1:]: 17 | row = x.strip().split('\t') 18 | files.append(row[0]) 19 | lengths.append(row[1]) 20 | self.files, self.lengths = files, lengths 21 | 22 | def get_input_dim(self): 23 | item = self.__getitem__(0) 24 | return item["input"].shape[-1] 25 | 26 | def __len__(self): 27 | return len(self.lengths) 28 | 29 | def __getitem__(self, idx): 30 | file_name = self.files[idx] 31 | file_name = os.path.join(self.root_dir, file_name) 32 | data = np.load(file_name) 33 | 34 | wav = data.astype('float32') 35 | 36 | return {"input": wav} 37 | -------------------------------------------------------------------------------- /datasets/single_subject_all_electrode.py: -------------------------------------------------------------------------------- 1 | from data.subject_data import SubjectData 2 | from datasets import register_dataset 3 | from .finetuning_datasets import BaseFinetuning 4 | import pandas as pd 5 | import numpy as np 6 | 7 | @register_dataset(name="single_subject_all_electrode") 8 | class SingleSubjectAllElectrode(BaseFinetuning): 9 | def __init__(self, cfg, task_cfg=None, preprocessor_cfg=None) -> None: 10 | 11 | super().__init__(cfg, preprocessor_cfg=preprocessor_cfg) 12 | s = SubjectData(cfg) 13 | self.regions_file = s.trials[0].regions_file 14 | self.regions = pd.read_csv(self.regions_file) 15 | 16 | self.word_df = s.words 17 | self.seeg_data = s.neural_data 18 | assert len(self.cfg.electrodes) == self.seeg_data.shape[0] 19 | 20 | if cfg.onsets_only: 21 | onset_idxs = self.word_df.loc[self.word_df.is_onset.astype(bool)].index.tolist() 22 | self.word_df = self.word_df.loc[onset_idxs] 23 | self.seeg_data = self.seeg_data[:,onset_idxs] 24 | 25 | all_ordered_labels = 
s.trials[0].get_brain_region_localization() 26 | selected = [(i,e) for i,e in enumerate(all_ordered_labels) if e in cfg.electrodes] 27 | sel_idxs, sel_labels = zip(*selected) 28 | self.electrode_labels = list(sel_labels) 29 | assert len(self.electrode_labels) == len(self.cfg.electrodes) 30 | 31 | def get_source(self, idx, use_cache=True): 32 | wavs = self.seeg_data[:,idx].astype('float32') # a matrix of size [n_electrodes, n_samples] 33 | specs = [] 34 | for j in range(wavs.shape[0]): 35 | wav = wavs[j] 36 | specs.append(self.extracter(wav)) 37 | specs = np.stack(specs) 38 | length = specs.shape[1] 39 | return length, specs, wavs 40 | 41 | def __getitem__(self, idx): 42 | sentence_activities = self.seeg_data[:,idx,:] 43 | length, specs, wav = self.get_source(idx) 44 | 45 | all_sentence_data = { 46 | "labels": self.electrode_labels, 47 | "seeg_data": specs 48 | } 49 | return all_sentence_data 50 | 51 | def __len__(self): 52 | return self.seeg_data.shape[1] 53 | -------------------------------------------------------------------------------- /datasets/tf_unmasked.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from torch.utils import data 4 | import os 5 | import numpy as np 6 | from datasets import register_dataset 7 | from preprocessors import STFTPreprocessor, MoreletPreprocessor, SuperletPreprocessor 8 | 9 | @register_dataset(name="tf_unmasked") 10 | #This is supposed to look like the datasets in finetuning.py 11 | class TFUnmasked(data.Dataset): 12 | def __init__(self, cfg, task_cfg=None, preprocessor_cfg=None): 13 | self.cfg = cfg 14 | self.task_cfg = task_cfg 15 | manifest_path = cfg.data 16 | manifest_path = os.path.join(manifest_path, "manifest.tsv") 17 | with open(manifest_path, "r") as f: 18 | lines = f.readlines() 19 | self.root_dir = lines[0].strip() 20 | files, lengths = [], [] 21 | for x in lines[1:]: 22 | row = x.strip().split('\t') 23 | files.append(row[0]) 24 | lengths.append(row[1]) 25 | self.files, self.lengths = files, lengths 26 | 27 | if preprocessor_cfg.name == "stft": 28 | extracter = STFTPreprocessor() 29 | elif preprocessor_cfg.name == 'morelet': 30 | extracter = MoreletPreprocessor() 31 | elif preprocessor_cfg.name == 'superlet': 32 | extracter = SuperletPreprocessor(preprocessor_cfg) 33 | else: 34 | raise RuntimeError("Specify a preprocessor") 35 | self.extracter = extracter 36 | 37 | def get_input_dim(self): 38 | item = self.__getitem__(0) 39 | return item["input"].shape[-1] 40 | 41 | def __len__(self): 42 | return len(self.lengths) 43 | 44 | def __getitem__(self, idx): 45 | file_name = self.files[idx] 46 | file_name = os.path.join(self.root_dir, file_name) 47 | data = np.load(file_name) 48 | 49 | data = data.astype('float32') 50 | #rand_len = random.randrange(1000, len(data), 1) 51 | rand_len = -1 52 | wav = data[:rand_len] 53 | data = self.extracter(wav) 54 | 55 | return {"input": data, 56 | "label": 1} 57 | 58 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | #to write features to disk 2 | #python3 -m data.write_preprocessed_inputs +data=tf_unmasked +data_prep=cwt_to_disk ++data.data=all_day_data 3 | from multiprocessing import Process 4 | from omegaconf import DictConfig, OmegaConf 5 | import hydra 6 | import models 7 | import tasks 8 | from runner import Runner 9 | from datasets import build_dataset 10 | import logging 11 | import os 12 | from 
pathlib import Path 13 | import numpy as np 14 | from tqdm import tqdm as tqdm 15 | import csv 16 | import yaml 17 | import preprocessors 18 | 19 | #take dataset with wavs and write it to cache 20 | log = logging.getLogger(__name__) 21 | 22 | def singleprocess_save_pretrained_spec_cache(pid, idxs, cache_path, seeg_data, extracter, extracter_cache_path, read_spec_from_cache=True): 23 | #save features for the case where you have the spectrograms cached, and now need to push them through the pretrained model 24 | for idx in tqdm(idxs): 25 | if read_spec_from_cache: 26 | spec = np.load(os.path.join(extracter_cache_path, f"{idx}.npy")) 27 | else: 28 | spec = extracter.spec_preprocessor(seeg_data[idx]) 29 | item = extracter(seeg_data[idx], spec_preprocessed=spec) 30 | save_path = os.path.join(cache_path, f'{idx}.npy') 31 | np.save(save_path, item.numpy()) 32 | 33 | def singleprocess_save_cache(pid, idxs, cache_path, seeg_data, extracter): 34 | for idx in tqdm(idxs): 35 | item = extracter(seeg_data[idx]) 36 | save_path = os.path.join(cache_path, f'{idx}.npy') 37 | np.save(save_path, item.numpy()) 38 | 39 | def save_cache(idxs, cache_path, seeg_data, extracter): 40 | if isinstance(extracter, preprocessors.spec_pretrained.SpecPretrained): 41 | extracter_cache_path = os.path.join(cache_path, "cached_spec") 42 | Path(extracter_cache_path).mkdir(exist_ok=True, parents=True) 43 | if not isinstance(extracter.spec_preprocessor, preprocessors.stft.STFTPreprocessor): 44 | multiprocess_save_cache(idxs, extracter_cache_path, seeg_data, extracter.spec_preprocessor) 45 | singleprocess_save_pretrained_spec_cache(0, idxs, cache_path, seeg_data, extracter, extracter_cache_path) 46 | else: 47 | singleprocess_save_pretrained_spec_cache(0, idxs, cache_path, seeg_data, extracter, extracter_cache_path, read_spec_from_cache=False) 48 | else: 49 | multiprocess_save_cache(idxs, cache_path, seeg_data, extracter) 50 | 51 | def multiprocess_save_cache(idxs, cache_path, seeg_data, extracter): 52 | log.info("Writing data to disk") 53 | 54 | ps = [] 55 | n=32 56 | step = int(len(idxs)/n) + 1 57 | ranges = [(i*step, (i+1)*step) for i in range(n)] 58 | for index in range(n): 59 | start, end = ranges[index] 60 | slice_idxs = idxs[start:end] 61 | log.info(f'Main : create and start process {index} with {start} to {end}') 62 | pid = index 63 | x = Process(target=singleprocess_save_cache, args=(pid, slice_idxs, cache_path, seeg_data, extracter)) 64 | ps.append(x) 65 | x.start() 66 | 67 | for index, process in enumerate(ps): 68 | log.info(f'Main : before joining process {index}') 69 | process.join() 70 | log.info("Main : process %d done", index) 71 | -------------------------------------------------------------------------------- /linear_results/onset_finetuning/linear_results.json: -------------------------------------------------------------------------------- 1 | {"m00185": ["T1b2"]} 2 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from pathlib import Path 4 | 5 | MODEL_REGISTRY = {} 6 | 7 | __all__ = ["build_model"] 8 | 9 | def build_model(cfg, *args, **kwargs): 10 | model_name = cfg.name 11 | assert model_name in MODEL_REGISTRY 12 | model = MODEL_REGISTRY[model_name]() 13 | model.build_model(cfg, *args, **kwargs) 14 | return model 15 | 16 | def register_model(name): 17 | def register_model_cls(cls): 18 | if name in MODEL_REGISTRY: 19 | raise 
ValueError(f'{name} already in registry') 20 | else: 21 | MODEL_REGISTRY[name] = cls 22 | return cls 23 | return register_model_cls 24 | 25 | def import_models(): 26 | for file in os.listdir(os.path.dirname(__file__)): 27 | if file.endswith(".py") and not file.startswith("_"): 28 | module_name = str(Path(file).with_suffix("")) 29 | importlib.import_module('models.'+module_name) 30 | import_models() 31 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class BaseModel(nn.Module): 4 | def __init__(self): 5 | super(BaseModel, self).__init__() 6 | 7 | def build_model(self, cfg): 8 | raise NotImplementedError 9 | 10 | def save_model_weights(self, states): 11 | #expects a new state with "models" key 12 | states["model"] = self.state_dict() 13 | return states 14 | 15 | def load_weights(self, states): 16 | self.load_state_dict(states) 17 | -------------------------------------------------------------------------------- /models/deep_linear_wav_baseline.py: -------------------------------------------------------------------------------- 1 | from models import register_model 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import torch 5 | from models.base_model import BaseModel 6 | from models.transformer_encoder_input import TransformerEncoderInput 7 | 8 | @register_model("deep_linear_wav_baseline") 9 | class DeepLinearWavModel(BaseModel): 10 | def __init__(self): 11 | super(DeepLinearWavModel, self).__init__() 12 | 13 | def forward(self, inputs): 14 | hidden = F.relu(self.linear1(inputs)) 15 | hidden = F.relu(self.linear2(hidden)) 16 | hidden = F.relu(self.linear3(hidden)) 17 | hidden = F.relu(self.linear4(hidden)) 18 | out = (self.linear_out(hidden)) 19 | return out 20 | 21 | def build_model(self, cfg, input_dim): 22 | self.cfg = cfg 23 | self.linear1 = nn.Linear(in_features=input_dim, out_features=1024) 24 | self.linear2 = nn.Linear(in_features=1024, out_features=512) 25 | self.linear3 = nn.Linear(in_features=512, out_features=256) 26 | self.linear4 = nn.Linear(in_features=256, out_features=128) 27 | self.linear_out = nn.Linear(in_features=128, out_features=1) #TODO hardcode out_features 28 | #TODO hardcode in_features 29 | -------------------------------------------------------------------------------- /models/feature_extract_deep_model.py: -------------------------------------------------------------------------------- 1 | from models import register_model 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import torch 5 | from models.base_model import BaseModel 6 | from models.transformer_encoder_input import TransformerEncoderInput 7 | 8 | @register_model("feature_extract_deep_model") 9 | class FeatureExtractDeepModel(BaseModel): 10 | def __init__(self): 11 | super(FeatureExtractDeepModel, self).__init__() 12 | 13 | def forward(self, inputs): 14 | hidden = F.relu(self.linear1(inputs)) 15 | hidden = F.relu(self.linear2(hidden)) 16 | hidden = F.relu(self.linear3(hidden)) 17 | hidden = F.relu(self.linear4(hidden)) 18 | out = (self.linear_out(hidden)) 19 | return out 20 | 21 | def build_model(self, cfg): 22 | self.cfg = cfg 23 | self.linear1 = nn.Linear(in_features=cfg.input_dim, out_features=1024) 24 | self.linear2 = nn.Linear(in_features=1024, out_features=512) 25 | self.linear3 = nn.Linear(in_features=512, out_features=256) 26 | self.linear4 = nn.Linear(in_features=256, out_features=128) 27 | 
self.linear_out = nn.Linear(in_features=128, out_features=1) #TODO hardcode out_features 28 | 29 | -------------------------------------------------------------------------------- /models/feature_extract_hidden.py: -------------------------------------------------------------------------------- 1 | from models import register_model 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import torch 5 | from models.base_model import BaseModel 6 | from models.transformer_encoder_input import TransformerEncoderInput 7 | 8 | @register_model("feature_extract_hidden_model") 9 | class FeatureExtractHiddenModel(BaseModel): 10 | def __init__(self): 11 | super(FeatureExtractHiddenModel, self).__init__() 12 | 13 | def forward(self, inputs): 14 | hidden = F.relu(self.linear_out1(inputs)) 15 | out = self.linear_out(hidden) 16 | return out 17 | 18 | def build_model(self, cfg): 19 | self.cfg = cfg 20 | self.linear_out1 = nn.Linear(in_features=cfg.input_dim, out_features=50) 21 | self.linear_out = nn.Linear(in_features=50, out_features=1) #TODO hardcode out_features 22 | 23 | -------------------------------------------------------------------------------- /models/feature_extract_model.py: -------------------------------------------------------------------------------- 1 | from models import register_model 2 | import torch.nn as nn 3 | import torch 4 | from models.base_model import BaseModel 5 | from models.transformer_encoder_input import TransformerEncoderInput 6 | 7 | @register_model("feature_extract_model") 8 | class FeatureExtractModel(BaseModel): 9 | def __init__(self): 10 | super(FeatureExtractModel, self).__init__() 11 | 12 | def forward(self, inputs): 13 | out = self.linear_out(inputs) 14 | return out 15 | 16 | def build_model(self, cfg): 17 | self.cfg = cfg 18 | self.linear_out = nn.Linear(in_features=cfg.input_dim, out_features=1) #TODO hardcode out_features 19 | #TODO hardcode in_features 20 | -------------------------------------------------------------------------------- /models/finetune_model.py: -------------------------------------------------------------------------------- 1 | from models import register_model 2 | import torch.nn as nn 3 | import torch 4 | from models.base_model import BaseModel 5 | from models.transformer_encoder_input import TransformerEncoderInput 6 | 7 | @register_model("finetune_model") 8 | class FinetuneModel(BaseModel): 9 | def __init__(self): 10 | super(FinetuneModel, self).__init__() 11 | 12 | def forward(self, inputs, pad_mask): 13 | if self.frozen_upstream: 14 | self.upstream.eval() 15 | with torch.no_grad(): 16 | outputs = self.upstream(inputs, pad_mask, intermediate_rep=True) 17 | else: 18 | outputs = self.upstream(inputs, pad_mask, intermediate_rep=True) 19 | middle = int(outputs.shape[1]/2) 20 | outputs = outputs[:,middle-5:middle+5].mean(axis=1) 21 | out = self.linear_out(outputs) 22 | return out 23 | 24 | def build_model(self, cfg, upstream_model): 25 | self.cfg = cfg 26 | self.upstream = upstream_model 27 | self.upstream_cfg = self.upstream.cfg 28 | hidden_dim = self.upstream_cfg.hidden_dim 29 | self.linear_out = nn.Linear(in_features=hidden_dim, out_features=1) #TODO hardcode out_features 30 | self.frozen_upstream = cfg.frozen_upstream 31 | 32 | -------------------------------------------------------------------------------- /models/hidden_linear_wav_model.py: -------------------------------------------------------------------------------- 1 | from models import register_model 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 
import torch
  5 | from models.base_model import BaseModel
  6 | from models.transformer_encoder_input import TransformerEncoderInput
  7 | 
  8 | @register_model("hidden_linear_wav_baseline")
  9 | class HiddenLinearWavBaseline(BaseModel):
 10 |     def __init__(self):
 11 |         super(HiddenLinearWavBaseline, self).__init__()
 12 | 
 13 |     def forward(self, inputs):
 14 |         hidden = F.relu(self.linear1(inputs))
 15 |         out = F.relu(self.linear_out(hidden))
 16 |         return out
 17 | 
 18 |     def build_model(self, cfg, input_dim):
 19 |         self.cfg = cfg
 20 |         self.linear1 = nn.Linear(in_features=input_dim, out_features=768)
 21 |         self.linear_out = nn.Linear(in_features=768, out_features=1) #TODO hardcode out_features
 22 |         #TODO hardcode in_features
 23 | 
--------------------------------------------------------------------------------
/models/linear_spec_baseline.py:
--------------------------------------------------------------------------------
  1 | from models import register_model
  2 | import torch.nn as nn
  3 | import torch
  4 | from models.base_model import BaseModel
  5 | from models.transformer_encoder_input import TransformerEncoderInput
  6 | 
  7 | @register_model("linear_spec_baseline")
  8 | class LinearSpecModel(BaseModel):
  9 |     def __init__(self):
 10 |         super(LinearSpecModel, self).__init__()
 11 | 
 12 |     def forward(self, inputs):
 13 |         out = self.linear_out(inputs)
 14 |         return out
 15 | 
 16 |     def build_model(self, cfg, input_dim):
 17 |         self.cfg = cfg
 18 |         self.linear_out = nn.Linear(in_features=input_dim, out_features=1) #TODO hardcode out_features
 19 |         #TODO hardcode in_features
 20 | 
--------------------------------------------------------------------------------
/models/linear_wav_baseline.py:
--------------------------------------------------------------------------------
  1 | from models import register_model
  2 | import torch.nn as nn
  3 | import torch
  4 | from models.base_model import BaseModel
  5 | from models.transformer_encoder_input import TransformerEncoderInput
  6 | 
  7 | @register_model("linear_wav_baseline")
  8 | class LinearWavModel(BaseModel):
  9 |     def __init__(self):
 10 |         super(LinearWavModel, self).__init__()
 11 | 
 12 |     def forward(self, inputs):
 13 |         out = self.linear_out(inputs)
 14 |         return out
 15 | 
 16 |     def build_model(self, cfg, input_dim):
 17 |         self.cfg = cfg
 18 |         self.linear_out = nn.Linear(in_features=input_dim, out_features=1) #TODO hardcode out_features
 19 |         #TODO hardcode in_features
 20 | 
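The baselines above are all wired up the same way: `build_model` in `models/__init__.py` looks the class up in `MODEL_REGISTRY` under `cfg.name`, instantiates it with no arguments, and forwards any extra arguments (here, `input_dim`) to the instance's own `build_model`. A minimal sketch of that flow, assuming a hypothetical config; the input size is illustrative (2.5 s at the hardcoded 2048 Hz sampling rate), not a value from the repo:

```python
from omegaconf import OmegaConf
import torch
import models  # importing the package runs import_models(), which fills MODEL_REGISTRY

# Hypothetical config: "name" must match a @register_model key.
cfg = OmegaConf.create({"name": "linear_wav_baseline"})
model = models.build_model(cfg, input_dim=5120)  # forwarded to LinearWavModel.build_model

wav_batch = torch.randn(8, 5120)  # [batch, n_samples]
logits = model(wav_batch)         # [batch, 1]: one logit per example
```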
--------------------------------------------------------------------------------
/models/masked_tf_model.py:
--------------------------------------------------------------------------------
  1 | from models import register_model
  2 | import torch.nn as nn
  3 | from models.base_model import BaseModel
  4 | from models.transformer_encoder_input import TransformerEncoderInput
  5 | from models.spec_prediction_head import SpecPredictionHead
  6 | 
  7 | @register_model("masked_tf_model")
  8 | class MaskedTFModel(BaseModel):
  9 |     def __init__(self):
 10 |         super(MaskedTFModel, self).__init__()
 11 | 
 12 |     def forward(self, input_specs, src_key_mask, intermediate_rep=False, rep_from_layer=-1):
 13 |         input_specs, pos_enc = self.input_encoding(input_specs)
 14 |         input_specs = input_specs.transpose(0,1) #nn.Transformer wants [seq, batch, dim]
 15 |         if rep_from_layer==-1:
 16 |             output_specs = self.transformer(input_specs, src_key_padding_mask=src_key_mask)
 17 |         else:
 18 |             raise NotImplementedError
 19 |         output_specs = output_specs.transpose(0,1) #[batch, seq, dim]
 20 |         if intermediate_rep:
 21 |             return output_specs
 22 |         output_specs = self.spec_prediction_head(output_specs)
 23 |         return output_specs, pos_enc
 24 | 
 25 |     def init_weights(self, module):
 26 |         if isinstance(module, nn.Linear):
 27 |             if module.bias is not None:
 28 |                 module.bias.data.zero_()
 29 |         if isinstance(module, nn.LayerNorm):
 30 |             module.bias.data.zero_()
 31 |             module.weight.data.fill_(1.0) #standard LayerNorm init: unit weight, zero bias
 32 | 
 33 |     def build_model(self, cfg):
 34 |         self.cfg = cfg
 35 |         hidden_dim = self.cfg.hidden_dim
 36 |         self.input_encoding = TransformerEncoderInput(cfg)
 37 |         encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=self.cfg.nhead, dim_feedforward=self.cfg.layer_dim_feedforward, activation=self.cfg.layer_activation)
 38 |         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=self.cfg.encoder_num_layers)
 39 |         self.spec_prediction_head = SpecPredictionHead(cfg)
 40 |         self.apply(self.init_weights)
 41 | 
--------------------------------------------------------------------------------
/models/seeg_wav2vec.py:
--------------------------------------------------------------------------------
  1 | from models import register_model
  2 | import torch.nn as nn
  3 | from models.base_model import BaseModel
  4 | from models.transformer_encoder_input import TransformerEncoderInput
  5 | from models.spec_prediction_head import SpecPredictionHead
  6 | 
  7 | @register_model("seeg_wav2vec")
  8 | class SeegWav2Vec(BaseModel):
  9 |     def __init__(self):
 10 |         super(SeegWav2Vec, self).__init__()
 11 | 
 12 |     def forward(self, inputs):
 13 |         #model stub: the wav2vec-style forward pass is not implemented yet
 14 |         raise NotImplementedError
 15 | 
 16 |     def init_weights(self, module):
 17 |         if isinstance(module, nn.Linear):
 18 |             if module.bias is not None:
 19 |                 module.bias.data.zero_()
 20 |         if isinstance(module, nn.LayerNorm):
 21 |             module.bias.data.zero_()
 22 |             module.weight.data.fill_(1.0)
 23 | 
 24 |     def build_model(self, cfg):
 25 |         self.cfg = cfg
 26 |         hidden_dim = self.cfg.hidden_dim
 27 |         self.input_encoding = TransformerEncoderInput(cfg)
 28 |         encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=self.cfg.nhead, dim_feedforward=self.cfg.layer_dim_feedforward, activation=self.cfg.layer_activation)
 29 |         self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=self.cfg.encoder_num_layers)
 30 |         self.spec_prediction_head = SpecPredictionHead(cfg)
 31 |         self.apply(self.init_weights)
 32 | 
 33 | 
--------------------------------------------------------------------------------
/models/spec_prediction_head.py:
--------------------------------------------------------------------------------
  1 | import torch.nn as nn
  2 | 
  3 | class SpecPredictionHead(nn.Module):
  4 |     def __init__(self, cfg):
  5 |         super(SpecPredictionHead, self).__init__()
  6 |         self.hidden_layer = nn.Linear(cfg.hidden_dim, cfg.hidden_dim)
  7 |         if cfg.layer_activation=="gelu":
  8 |             self.act_fn = nn.GELU()
  9 |         else:
 10 |             raise ValueError(f"unsupported layer_activation: {cfg.layer_activation}")
 11 |         self.layer_norm = nn.LayerNorm(cfg.hidden_dim)
 12 |         self.output = nn.Linear(cfg.hidden_dim, cfg.input_dim)
 13 | 
 14 |     def forward(self, hidden):
 15 |         h = self.hidden_layer(hidden)
 16 |         h = self.act_fn(h)
 17 |         h = self.layer_norm(h)
 18 |         h = self.output(h)
 19 |         return h
 20 | 
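`MaskedTFModel` is the masked-spectrogram encoder at the center of the pipeline: `TransformerEncoderInput` projects each spectrogram frame from `input_dim` frequency channels up to `hidden_dim`, adds sinusoidal positional encodings, a stack of `nn.TransformerEncoderLayer`s produces per-frame representations, and `SpecPredictionHead` maps them back to `input_dim` for the masked-reconstruction loss; `intermediate_rep=True` skips the head and returns the hidden states for downstream feature extraction. A shape walkthrough under assumed config values (illustrative only; the shipped values live in `conf/model/masked_tf_model*.yaml`):

```python
from omegaconf import OmegaConf
import torch
import models

# Assumed config values for illustration; see conf/model/ for the real ones.
cfg = OmegaConf.create({"name": "masked_tf_model", "input_dim": 40, "hidden_dim": 768,
                        "nhead": 12, "layer_dim_feedforward": 3072,
                        "layer_activation": "gelu", "encoder_num_layers": 6})
model = models.build_model(cfg)

specs = torch.randn(2, 100, 40)                    # [batch, time, freq_channels]
pad_mask = torch.zeros(2, 100, dtype=torch.bool)   # False = real (unpadded) frame
recon, pos_enc = model(specs, pad_mask)            # recon: [2, 100, 40], fed to the masked loss
feats = model(specs, pad_mask, intermediate_rep=True)  # [2, 100, 768] hidden states
```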
--------------------------------------------------------------------------------
/models/transformer_encoder_input.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn as nn
  3 | import math
  4 | 
  5 | class PositionalEncoding(nn.Module):
  6 |     def __init__(self, d_model, dropout=0.1, max_len=5000):
  7 |         '''
  8 |         From https://discuss.pytorch.org/t/how-to-modify-the-positional-encoding-in-torch-nn-transformer/104308/2
  9 |         '''
 10 |         super(PositionalEncoding, self).__init__()
 11 | 
 12 |         pe = torch.zeros(max_len, d_model)
 13 |         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
 14 |         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
 15 |         pe[:, 0::2] = torch.sin(position * div_term)
 16 |         pe[:, 1::2] = torch.cos(position * div_term)
 17 |         pe = pe.unsqueeze(0)
 18 |         self.register_buffer('pe', pe)
 19 | 
 20 |     def forward(self, seq):
 21 |         #seq is [batch, len, dim]
 22 |         assert len(seq.shape) == 3
 23 |         pos_enc = self.pe[:,:seq.size(1),:]
 24 |         out = seq + pos_enc
 25 |         return out, pos_enc
 26 | 
 27 | class TransformerEncoderInput(nn.Module):
 28 |     def __init__(self, cfg, dropout=0.1):
 29 |         super(TransformerEncoderInput, self).__init__()
 30 |         self.cfg = cfg
 31 |         self.in_proj = nn.Linear(in_features=cfg.input_dim, out_features=cfg.hidden_dim)
 32 |         self.positional_encoding = PositionalEncoding(self.cfg.hidden_dim)
 33 |         self.layer_norm = nn.LayerNorm(cfg.hidden_dim)
 34 |         self.dropout = nn.Dropout(p=dropout)
 35 | 
 36 |     def forward(self, input_specs):
 37 |         input_specs = self.in_proj(input_specs)
 38 |         input_specs, pos_enc = self.positional_encoding(input_specs)
 39 |         input_specs = self.layer_norm(input_specs)
 40 |         input_specs = self.dropout(input_specs)
 41 |         return input_specs, pos_enc
 42 | 
--------------------------------------------------------------------------------
/notebooks/example_wav_1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czlwang/BrainBERT/711983690a4c5502038c70473ef0c4f450e1f7fe/notebooks/example_wav_1.npy
--------------------------------------------------------------------------------
/preprocessors/__init__.py:
--------------------------------------------------------------------------------
  1 | from .stft import STFTPreprocessor
  2 | from .morelet_preprocessor import MoreletPreprocessor
  3 | from .superlet_preprocessor import SuperletPreprocessor
  4 | from .wav_preprocessor import WavPreprocessor
  5 | from .superlet import superlet
  6 | from .spec_pretrained import SpecPretrained
  7 | from .spec_pooled import SpecPooled
  8 | 
  9 | 
 10 | __all__ = ["STFTPreprocessor",
 11 |            "MoreletPreprocessor",
 12 |            "SuperletPreprocessor",
 13 |            "superlet",
 14 |            "WavPreprocessor",
 15 |            "SpecPretrained",
 16 |            "SpecPooled"
 17 |            ]
 18 | 
 19 | def build_preprocessor(preprocessor_cfg):
 20 |     if preprocessor_cfg.name == "stft":
 21 |         extracter = STFTPreprocessor(preprocessor_cfg)
 22 |     elif preprocessor_cfg.name == "superlet":
 23 |         extracter = SuperletPreprocessor(preprocessor_cfg)
 24 |     elif preprocessor_cfg.name == "wav_preprocessor":
 25 |         extracter = WavPreprocessor(preprocessor_cfg)
 26 |     elif preprocessor_cfg.name == "spec_pretrained":
 27 |         extracter = SpecPretrained(preprocessor_cfg)
 28 |     elif preprocessor_cfg.name == "spec_pooled_preprocessor":
 29 |         extracter = SpecPooled(preprocessor_cfg)
 30 |     else:
 31 |         raise ValueError("Specify preprocessor")
 32 |     return extracter
 33 | 
--------------------------------------------------------------------------------
/preprocessors/morelet_preprocessor.py:
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | import torch
  3 | import torch.nn as nn
  4 | from scipy import signal, stats
  5 | import mne
  6 | 
  7 | class MoreletPreprocessor(nn.Module):
  8 |     def get_morelet(self, sig, fs, freqs, 
normalizing=None, **kwargs): 9 | 10 | sig = np.expand_dims(np.expand_dims(sig, 0), 0) 11 | arr = mne.time_frequency.tfr_array_morlet(sig, fs, freqs, decim=60)#NOTE decim hardcode 12 | arr = arr[0,0] 13 | morelet = stats.zscore(np.abs(arr)[:,10:-10], axis=1) 14 | morelet = morelet.transpose(1,0) 15 | return morelet 16 | 17 | def __init__(self): 18 | super(MoreletPreprocessor, self).__init__() 19 | 20 | def forward(self, wav): 21 | freqs = list(np.arange(10,200,4)) 22 | morelet = self.get_morelet(wav, 2048, freqs) #TODO hardcode sampling rate 23 | return morelet 24 | -------------------------------------------------------------------------------- /preprocessors/spec_pooled.py: -------------------------------------------------------------------------------- 1 | from .stft import STFTPreprocessor 2 | from .morelet_preprocessor import MoreletPreprocessor 3 | from .superlet_preprocessor import SuperletPreprocessor 4 | import torch 5 | import torch.nn as nn 6 | import models 7 | import os 8 | 9 | #This preprocssor combines a spectrogram preprocessor with a feature extracter (transformer) 10 | 11 | def build_preprocessor(spec_name, preprocessor_cfg): 12 | if spec_name == "stft": 13 | extracter = STFTPreprocessor(preprocessor_cfg) 14 | elif spec_name == "superlet": 15 | extracter = SuperletPreprocessor(preprocessor_cfg) 16 | return extracter 17 | 18 | class SpecPooled(nn.Module): 19 | def __init__(self, cfg): 20 | super(SpecPooled, self).__init__() 21 | self.spec_preprocessor = build_preprocessor(cfg.spec_name, cfg) 22 | 23 | def forward(self, wav, spec_preprocessed=None): 24 | if spec_preprocessed is None: 25 | spec = self.spec_preprocessor(wav) 26 | else: 27 | spec = torch.FloatTensor(spec_preprocessed) 28 | inputs = spec.unsqueeze(0) #[batch, time, num_freq_channels] 29 | outputs = inputs 30 | middle = int(outputs.shape[1]/2) 31 | out = outputs[:,middle-5:middle+5].mean(axis=1) 32 | #out = outputs.mean(axis=1) 33 | out = out.squeeze(0) 34 | return out 35 | -------------------------------------------------------------------------------- /preprocessors/spec_pretrained.py: -------------------------------------------------------------------------------- 1 | from .stft import STFTPreprocessor 2 | from .morelet_preprocessor import MoreletPreprocessor 3 | from .superlet_preprocessor import SuperletPreprocessor 4 | import torch 5 | import torch.nn as nn 6 | import models 7 | import os 8 | 9 | #This preprocssor combines a spectrogram preprocessor with a feature extracter (transformer) 10 | 11 | def build_preprocessor(spec_name, preprocessor_cfg): 12 | if spec_name == "stft": 13 | extracter = STFTPreprocessor(preprocessor_cfg) 14 | elif spec_name == "superlet": 15 | extracter = SuperletPreprocessor(preprocessor_cfg) 16 | return extracter 17 | 18 | class SpecPretrained(nn.Module): 19 | def __init__(self, cfg): 20 | super(SpecPretrained, self).__init__() 21 | self.spec_preprocessor = build_preprocessor(cfg.spec_name, cfg) 22 | 23 | self.cfg = cfg 24 | ckpt_path = cfg.upstream_ckpt 25 | init_state = torch.load(ckpt_path) 26 | upstream_cfg = init_state["model_cfg"] 27 | if upstream_cfg.name=='debug_model': 28 | upstream_cfg.name='masked_tf_model' 29 | self.upstream = models.build_model(upstream_cfg) 30 | #model.module.load_weights(states) 31 | states = init_state["model"] 32 | self.upstream.load_weights(states) 33 | 34 | def forward(self, wav, spec_preprocessed=None): 35 | if spec_preprocessed is None: 36 | spec = self.spec_preprocessor(wav) 37 | else: 38 | spec = torch.FloatTensor(spec_preprocessed) 39 | 
inputs = spec.unsqueeze(0) #[batch, time, num_freq_channels] 40 | pad_mask = torch.zeros(1, spec.shape[0], dtype=bool) 41 | self.upstream.eval() 42 | middle = int(inputs.shape[1]/2) 43 | with torch.no_grad(): 44 | #clip=50 45 | #outputs = self.upstream(inputs[:,middle-clip:middle+clip], pad_mask[:, middle-clip:middle+clip], intermediate_rep=True) 46 | rep_from_layer = -1 47 | if "rep_from_layer" in self.cfg: 48 | rep_from_layer = self.cfg.rep_from_layer 49 | outputs = self.upstream(inputs, pad_mask, intermediate_rep=True, rep_from_layer=rep_from_layer) 50 | middle = int(outputs.shape[1]/2) 51 | out = outputs[:,middle-5:middle+5].mean(axis=1) 52 | #out = outputs.mean(axis=1) 53 | out = out.squeeze(0) 54 | return out 55 | -------------------------------------------------------------------------------- /preprocessors/stft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy import signal, stats 5 | 6 | def _first(arr, axis): 7 | #from https://github.com/scipy/scipy/blob/v1.9.0/scipy/stats/_stats_py.py#L2662-L2730 8 | """Return arr[..., 0:1, ...] where 0:1 is in the `axis` position.""" 9 | return np.take_along_axis(arr, np.array(0, ndmin=arr.ndim), axis) 10 | 11 | def zscore(a, axis): 12 | #from https://github.com/scipy/scipy/blob/v1.9.0/scipy/stats/_stats_py.py#L2662-L2730 13 | mn = a.mean(axis=axis, keepdims=True) 14 | std = a.std(axis=axis, ddof=0, keepdims=True) 15 | 16 | std[(std==0)] = 1.0 #this is a hack. I should eventually find where the bad data is 17 | z = (a - mn) / std 18 | return z 19 | 20 | class STFTPreprocessor(nn.Module): 21 | def get_stft(self, x, fs, show_fs=-1, normalizing=None, **kwargs): 22 | f, t, Zxx = signal.stft(x, fs, **kwargs) 23 | 24 | if "return_onesided" in kwargs and kwargs["return_onesided"] == True: 25 | Zxx = Zxx[:show_fs] 26 | f = f[:show_fs] 27 | else: 28 | pass #TODO 29 | #Zxx = np.concatenate([Zxx[:,:,:show_fs], Zxx[:,:,-show_fs:]], axis=-1) 30 | #f = np.concatenate([f[:show_fs], f[-show_fs:]], axis=-1) 31 | 32 | Zxx = np.abs(Zxx) 33 | 34 | if normalizing=="zscore": 35 | Zxx = zscore(Zxx, axis=-1)#TODO is this order correct? 
I put it this way to prevent input nans 36 | if (Zxx.std() == 0).any(): 37 | Zxx = np.ones_like(Zxx) 38 | Zxx = Zxx[:,10:-10] 39 | elif normalizing=="db": 40 | Zxx = np.log(Zxx) 41 | 42 | if np.isnan(Zxx).any(): 43 | Zxx = np.nan_to_num(Zxx, nan=0.0) 44 | 45 | return f, t, torch.Tensor(np.transpose(Zxx)) 46 | 47 | def __init__(self, cfg): 48 | super(STFTPreprocessor, self).__init__() 49 | self.cfg = cfg 50 | 51 | def forward(self, wav): 52 | _,_,linear = self.get_stft(wav, 2048, show_fs=self.cfg.freq_channel_cutoff, nperseg=self.cfg.nperseg, noverlap=self.cfg.noverlap, normalizing=self.cfg.normalizing, return_onesided=True) #TODO hardcode sampling rate 53 | return linear 54 | -------------------------------------------------------------------------------- /preprocessors/superlet_preprocessor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from scipy import signal, stats 5 | from .superlet import superlet 6 | 7 | class SuperletPreprocessor(nn.Module): 8 | def get_superlet(self, s_data, order_min=2, order_max=12, c_1=3, foi=None): 9 | #s_data is [.., n_samples] 10 | s_data = np.transpose(s_data) 11 | def scale_from_period(period): 12 | return period / (2 * np.pi) 13 | 14 | fs = 2048 # sampling frequency 15 | # frequencies of interest in Hz 16 | if foi is None: 17 | foi = np.linspace(5, 200, 50) 18 | scales = scale_from_period(1 / foi) 19 | 20 | spec = superlet( 21 | s_data, 22 | samplerate=fs, 23 | scales=scales, 24 | order_max=order_max, 25 | order_min=order_min, 26 | c_1=c_1, 27 | adaptive=True, 28 | ) 29 | spec = np.abs(spec) 30 | decim = self.decim 31 | spec = spec[:,::decim] 32 | clip=5 33 | time = s_data.shape[0]/fs 34 | t = np.linspace(0,time,spec.shape[1]) 35 | t = t[clip:-clip] 36 | spec = stats.zscore(spec[:,clip:-clip], axis=1) 37 | spec = spec.transpose(1,0) 38 | spec = np.nan_to_num(spec) 39 | spec = torch.FloatTensor(spec) 40 | return t, foi, spec 41 | 42 | def __init__(self, cfg): 43 | super(SuperletPreprocessor, self).__init__() 44 | self.cfg = cfg 45 | self.c1 = cfg.c1 46 | self.order_max = cfg.order_max 47 | self.order_min = cfg.order_min 48 | self.decim = cfg.decim 49 | 50 | def forward(self, wav): 51 | foi = np.linspace(self.cfg.min_f,self.cfg.max_f,self.cfg.n_f_steps) 52 | t, fs, spec = self.get_superlet(wav, order_min=self.order_min, order_max=self.order_max, c_1=self.c1, foi=foi) 53 | return spec 54 | -------------------------------------------------------------------------------- /preprocessors/wav_preprocessor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy import signal, stats 5 | 6 | class WavPreprocessor(nn.Module): 7 | def __init__(self, cfg): 8 | super(WavPreprocessor, self).__init__() 9 | self.cfg = cfg 10 | 11 | def forward(self, wav): 12 | middle = int(len(wav)/2) 13 | sr = self.cfg.sample_rate 14 | if "clip_seconds" in self.cfg: 15 | clip_window = int(sr*self.cfg.clip_seconds) 16 | assert clip_window*2 < len(wav) 17 | return wav[middle-clip_window:middle+clip_window] 18 | return wav 19 | -------------------------------------------------------------------------------- /pretrain/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czlwang/BrainBERT/711983690a4c5502038c70473ef0c4f450e1f7fe/pretrain/__init__.py 
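The preprocessors in the files above form the repo's feature-extraction menu: the STFT and superlet preprocessors produce z-scored time-frequency images, `SpecPretrained` and `SpecPooled` additionally push a spectrogram through (or pool it around the center of) the pretrained encoder, and `WavPreprocessor` simply crops a fixed window around the midpoint of the raw trace. A small sketch of the cropping arithmetic, with hypothetical settings (2048 Hz matches the sampling rate hardcoded elsewhere; `clip_seconds` is illustrative):

```python
from omegaconf import OmegaConf
import numpy as np
from preprocessors import WavPreprocessor

# Hypothetical settings: keep one second on each side of the midpoint.
cfg = OmegaConf.create({"sample_rate": 2048, "clip_seconds": 1})
proc = WavPreprocessor(cfg)

wav = np.random.randn(5 * 2048)  # 5 s of signal
clipped = proc(wav)              # 2 * 2048 = 4096 samples, centered on the middle
assert clipped.shape[0] == 4096
```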
-------------------------------------------------------------------------------- /pretrain/spec2vec/spec2vec.py: -------------------------------------------------------------------------------- 1 | import models 2 | 3 | @register_model 4 | class Spec2Vec(): 5 | def __init__(self): 6 | print("Spec2Vec") 7 | -------------------------------------------------------------------------------- /pretrain_data/manifests/manifest.tsv: -------------------------------------------------------------------------------- 1 | /storage//BrainBERT/pretrain_data 2 | subject_1/trial001/0.npy 10240 3 | subject_1/trial001/1.npy 10240 4 | subject_1/trial001/2.npy 10240 5 | subject_1/trial001/3.npy 10240 6 | subject_1/trial001/4.npy 10240 7 | subject_2/trial001/0.npy 10240 8 | subject_2/trial001/1.npy 10240 9 | subject_2/trial001/2.npy 10240 10 | subject_2/trial001/3.npy 10240 11 | subject_2/trial001/4.npy 10240 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | omegaconf 2 | hydra-core 3 | scipy 4 | h5py 5 | python-dateutil 6 | tqdm 7 | pandas 8 | mne 9 | psutil 10 | tensorboardX 11 | git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git 12 | scikit-learn 13 | tensorboard 14 | torch_optimizer 15 | tables 16 | -------------------------------------------------------------------------------- /run_tests.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from omegaconf import DictConfig, OmegaConf 3 | import hydra 4 | import models 5 | import tasks 6 | from runner import Runner 7 | import logging 8 | import os 9 | from data.electrode_selection import get_clean_laplacian_electrodes 10 | import json 11 | from pathlib import Path 12 | 13 | log = logging.getLogger(__name__) 14 | 15 | def run_subject_test(data_cfg, brain_runs, electrodes, cfg): 16 | data_cfg_copy = data_cfg.copy() 17 | cache_path = None 18 | if "cache_input_features" in data_cfg_copy: 19 | cache_path = data_cfg_copy.cache_input_features 20 | 21 | subject_test_results = {} 22 | for e in electrodes: 23 | data_cfg_copy.electrodes = [e] 24 | data_cfg_copy.brain_runs = brain_runs 25 | if cache_path is not None: 26 | #cache_path needs to identify the pretrained model 27 | e_cache_path = os.path.join(cache_path, data_cfg_copy.subject, data_cfg_copy.name ,e) 28 | log.info(f"logging input features in {e_cache_path}") 29 | data_cfg_copy.cache_input_features = e_cache_path 30 | cfg.data = data_cfg_copy 31 | task = tasks.setup_task(cfg.task) 32 | task.load_datasets(cfg.data, cfg.preprocessor) 33 | model = task.build_model(cfg.model) 34 | criterion = task.build_criterion(cfg.criterion) 35 | runner = Runner(cfg.exp.runner, task, model, criterion) 36 | best_model = runner.train() 37 | test_results = runner.test(best_model) 38 | subject_test_results[e] = test_results 39 | return subject_test_results 40 | 41 | def write_summary(all_test_results, out_path): 42 | out_json = os.path.join(out_path, "all_test_results.json") 43 | with open(out_json, "w") as f: 44 | json.dump(all_test_results, f) 45 | 46 | out_json = os.path.join(out_path, "summary.json") 47 | all_rocs = [] 48 | for s in all_test_results: 49 | for e in all_test_results[s]: 50 | all_rocs.append(all_test_results[s][e]["roc_auc"]) 51 | 52 | summary_results = {"avg_roc_auc": np.mean(all_rocs), "std_roc_auc": np.std(all_rocs)} 53 | with open(out_json, "w") as f: 54 | json.dump(summary_results, f) 55 | 56 | 
log.info(f"Wrote test results to {out_path}") 57 | 58 | @hydra.main(config_path="conf") 59 | def main(cfg: DictConfig) -> None: 60 | log.info(f"Run testing for all electrodes in all test_subjects") 61 | log.info(OmegaConf.to_yaml(cfg, resolve=True)) 62 | out_dir = os.getcwd() 63 | log.info(f'Working directory {os.getcwd()}') 64 | if "out_dir" in cfg.test: 65 | out_dir = cfg.test.out_dir 66 | log.info(f'Output directory {out_dir}') 67 | 68 | test_split_path = cfg.test.test_split_path 69 | with open(test_split_path, "r") as f: 70 | test_splits = json.load(f) 71 | 72 | test_electrodes = None #For the topk. Omit this argument if you want everything 73 | if "test_electrodes_path" in cfg.test and cfg.test.test_electrodes_path != "None": #very hacky 74 | test_electrodes_path = cfg.test.test_electrodes_path 75 | test_electrodes_path = os.path.join(test_electrodes_path, cfg.data.name) 76 | test_electrodes_path = os.path.join(test_electrodes_path, "linear_results.json") 77 | with open(test_electrodes_path, "r") as f: 78 | test_electrodes = json.load(f) 79 | 80 | data_cfg = cfg.data 81 | all_test_results = {} 82 | for subj in test_splits: 83 | subj_test_results = {} 84 | Path(out_dir).mkdir(exist_ok=True, parents=True) 85 | out_path = os.path.join(out_dir, "all_test_results") 86 | log.info(f"Subject {subj}") 87 | data_cfg.subject = subj 88 | if test_electrodes is not None: 89 | if subj not in test_electrodes: 90 | continue 91 | electrodes = test_electrodes[subj] 92 | else: 93 | electrodes = get_clean_laplacian_electrodes(subj) 94 | subject_test_results = run_subject_test(data_cfg, test_splits[subj], electrodes, cfg) 95 | all_test_results[subj] = subject_test_results 96 | subj_test_results[subj] = subject_test_results 97 | 98 | out_json_path = os.path.join(out_path, subj) 99 | Path(out_json_path).mkdir(exist_ok=True, parents=True) 100 | out_json = os.path.join(out_json_path, "subj_test_results.json") 101 | with open(out_json, "w") as f: 102 | json.dump(subj_test_results, f) 103 | log.info(f"Wrote test results to {out_json}") 104 | write_summary(all_test_results, out_path) 105 | 106 | if __name__ == "__main__": 107 | main() 108 | -------------------------------------------------------------------------------- /run_train.py: -------------------------------------------------------------------------------- 1 | #example 2 | #python3 run_train.py +exp=spec2vec ++exp.runner.device=cuda ++exp.runner.multi_gpu=True ++exp.runner.num_workers=16 +data=masked_spec +model=debug_model +data.data=/storage/czw/self_supervised_seeg/all_electrode_data/manifests 3 | from omegaconf import DictConfig, OmegaConf 4 | import hydra 5 | import models 6 | import tasks 7 | from runner import Runner 8 | import logging 9 | import os 10 | 11 | log = logging.getLogger(__name__) 12 | 13 | @hydra.main(config_path="conf") 14 | def main(cfg: DictConfig) -> None: 15 | log.info("Training") 16 | log.info(OmegaConf.to_yaml(cfg, resolve=True)) 17 | log.info(f'Working directory {os.getcwd()}') 18 | task = tasks.setup_task(cfg.task) 19 | task.load_datasets(cfg.data, cfg.preprocessor) 20 | model = task.build_model(cfg.model) 21 | criterion = task.build_criterion(cfg.criterion) 22 | runner = Runner(cfg.exp.runner, task, model, criterion) 23 | best_model = runner.train() 24 | runner.test(best_model) 25 | 26 | if __name__ == "__main__": 27 | main() 28 | -------------------------------------------------------------------------------- /runner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from 
tqdm.contrib.logging import logging_redirect_tqdm
  3 | import numpy as np
  4 | import torch.nn as nn
  5 | import os
  6 | from tqdm import tqdm
  7 | import torch
  8 | import tasks
  9 | import torch.multiprocessing as mp
 10 | import torch.distributed as dist
 11 | import logging
 12 | from tensorboardX import SummaryWriter
 13 | from schedulers import build_scheduler
 14 | import torch_optimizer as torch_optim
 15 | 
 16 | 
 17 | log = logging.getLogger(__name__)
 18 | 
 19 | class Runner():
 20 |     def __init__(self, cfg, task, model, criterion):
 21 |         self.cfg = cfg
 22 |         self.model = model
 23 |         self.task = task
 24 |         self.evaluator = None
 25 |         self.device = cfg.device
 26 |         self.criterion = criterion
 27 |         self.exp_dir = os.getcwd()
 28 |         self.output_tb = cfg.get("output_tb", True)
 29 |         self.logger = None
 30 |         if self.output_tb:
 31 |             self.logger = SummaryWriter(self.exp_dir)
 32 | 
 33 |         if cfg.multi_gpu:
 34 |             self.model = torch.nn.DataParallel(self.model)
 35 |             log.info(f'Use {torch.cuda.device_count()} GPUs')
 36 |         assert not(cfg.device=='cpu' and cfg.multi_gpu)
 37 |         self.model.to(self.device)
 38 |         self.optim = self._init_optim(self.cfg)
 39 |         self.scheduler = build_scheduler(self.cfg.scheduler, self.optim)
 40 |         total_steps = self.cfg.total_steps
 41 |         self.progress = tqdm(total=total_steps, dynamic_ncols=True, desc="overall")
 42 | 
 43 |         if 'start_from_ckpt' in cfg:
 44 |             self.load_from_ckpt()
 45 | 
 46 |     def load_from_ckpt(self):
 47 |         ckpt_path = self.cfg.start_from_ckpt
 48 |         init_state = torch.load(ckpt_path)
 49 |         self.task.load_model_weights(self.model, init_state['model'], self.cfg.multi_gpu)
 50 |         self.optim.load_state_dict(init_state["optim"])
 51 |         self.scheduler.load_state_dict(init_state["scheduler"]) #matches the "scheduler" key written by save_checkpoints below
 52 | 
 53 |     def _init_optim(self, args):
 54 |         if args.optim == "SGD":
 55 |             optim = torch.optim.SGD(self.model.parameters(), lr=args.lr, momentum = 0.9)
 56 |         elif args.optim == 'Adam':
 57 |             optim = torch.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0.01)
 58 |         elif args.optim == 'AdamW':
 59 |             optim = torch.optim.AdamW(self.model.parameters(), lr=args.lr)
 60 |         elif args.optim == 'AdamW_finetune':
 61 |             linear_out_params = list(self.model.linear_out.parameters()) if not self.cfg.multi_gpu else list(self.model.module.linear_out.parameters()) #materialized with list(): map(id, ...) below would otherwise exhaust the generator before AdamW sees these params
 62 |             ignored_params = list(map(id, linear_out_params))
 63 |             base_params = filter(lambda p: id(p) not in ignored_params,
 64 |                                  self.model.parameters())
 65 | 
 66 |             optim = torch.optim.AdamW([
 67 |                 {'params': base_params},
 68 |                 {'params': linear_out_params, 'lr': args.lr}
 69 |             ], lr=args.lr*0.1)
 70 |         elif args.optim == 'LAMB':
 71 |             optim = torch_optim.Lamb(self.model.parameters(), lr=args.lr)
 72 |         else:
 73 |             raise ValueError(f"no valid optim name: {args.optim}")
 74 |         return optim
 75 | 
 76 |     def output_logs(self, train_logging_outs, val_logging_outs):
 77 |         global_step = self.progress.n
 78 |         train_logging_outs['lr'] = self.scheduler.get_lr()
 79 |         standard_metrics = ["lr", "loss", "grad_norm"]
 80 |         all_standard_metrics = {}
 81 |         def add_prefix(prefix, outs):
 82 |             for k,v in outs.items():
 83 |                 if k in standard_metrics:
 84 |                     all_standard_metrics[f'{prefix}_{k}'] = v
 85 |         add_prefix('train', train_logging_outs)
 86 |         add_prefix('val', val_logging_outs)
 87 | 
 88 |         log.info(all_standard_metrics)
 89 | 
 90 |         if self.logger is not None:
 91 |             for k,v in all_standard_metrics.items():
 92 |                 self.logger.add_scalar(k, v, global_step=global_step)
 93 |         self.task.output_logs(train_logging_outs, val_logging_outs, self.logger, global_step)
 94 | 
 95 |     def get_valid_outs(self):
 96 |         valid_loader = self.get_batch_iterator(self.task.valid_set, 
self.cfg.valid_batch_size, shuffle=self.cfg.shuffle, num_workers=self.cfg.num_workers) 97 | valid_logging_outs = self.task.get_valid_outs(self.model, valid_loader, self.criterion, self.device) 98 | return valid_logging_outs 99 | 100 | def save_checkpoint_last(self, states, best_val=False): 101 | cwd = os.getcwd() 102 | if best_val: 103 | save_path = os.path.join(cwd, 'checkpoint_best.pth') 104 | else: 105 | save_path = os.path.join(cwd, 'checkpoint_last.pth') 106 | log.info(f'Saving checkpoint to {save_path}') 107 | torch.save(states, save_path) 108 | log.info(f'Saved checkpoint to {save_path}') 109 | 110 | def save_checkpoints(self, best_val=False): 111 | all_states = {} 112 | all_states = self.task.save_model_weights(self.model, all_states, self.cfg.multi_gpu) 113 | all_states['optim'] = self.optim.state_dict() 114 | all_states['scheduler'] = self.scheduler.get_state_dict() 115 | if self.cfg.multi_gpu: 116 | all_states['model_cfg'] = self.model.module.cfg 117 | else: 118 | all_states['model_cfg'] = self.model.cfg 119 | self.save_checkpoint_last(all_states) 120 | if best_val: 121 | self.save_checkpoint_last(all_states, best_val) 122 | 123 | def run_epoch(self, train_loader, total_loss, best_state): 124 | epoch_loss = [] 125 | for batch in train_loader: 126 | if self.progress.n >= self.progress.total: 127 | break 128 | self.model.train() 129 | logging_out = self.task.train_step(batch, self.model, self.criterion, self.optim, self.scheduler, self.device, self.cfg.grad_clip) 130 | total_loss.append(logging_out["loss"]) 131 | epoch_loss.append(logging_out["loss"]) 132 | log_step = self.progress.n % self.cfg.log_step == 0 or self.progress.n == self.progress.total - 1 133 | 134 | ckpt_step = False 135 | if self.cfg.checkpoint_step > -1: 136 | ckpt_step = self.progress.n % self.cfg.checkpoint_step == 0 or self.progress.n == self.progress.total - 1 137 | 138 | best_model, best_val = best_state 139 | valid_logging_outs = {} 140 | if ckpt_step or log_step: 141 | self.model.eval() 142 | valid_logging_outs = self.get_valid_outs() 143 | if log_step: 144 | logging_out["loss"] = np.mean(total_loss) 145 | self.output_logs(logging_out, valid_logging_outs) 146 | total_loss = [] 147 | if ckpt_step: 148 | if valid_logging_outs["loss"] < best_val["loss"]: 149 | self.save_checkpoints(best_val=True) 150 | best_val = valid_logging_outs 151 | best_model = copy.deepcopy(self.model) 152 | else: 153 | self.save_checkpoints() 154 | self.progress.update(1) 155 | return total_loss, (best_model, best_val) 156 | 157 | def scheduler_step(self): 158 | pass 159 | 160 | def train(self): 161 | train_loader = self.get_batch_iterator(self.task.train_set, self.cfg.train_batch_size, shuffle=self.cfg.shuffle, num_workers=self.cfg.num_workers, persistent_workers=self.cfg.num_workers>0) 162 | 163 | total_loss = [] 164 | best_val = {"loss": float("inf")} 165 | best_model = None 166 | best_state = (best_model, best_val) 167 | with logging_redirect_tqdm(): 168 | if self.cfg.checkpoint_step > -1: 169 | self.save_checkpoints() 170 | while self.progress.n < self.progress.total: 171 | total_loss, best_state = self.run_epoch(train_loader, total_loss, best_state) 172 | best_model, best_val = best_state 173 | self.progress.close() 174 | return best_model 175 | 176 | def test(self, best_model_weights): 177 | test_loader = self.get_batch_iterator(self.task.test_set, self.cfg.valid_batch_size, shuffle=self.cfg.shuffle, num_workers=self.cfg.num_workers, persistent_workers=self.cfg.num_workers>0) 178 | 179 | test_outs = 
        test_outs = self.task.get_valid_outs(self.model, test_loader, self.criterion, self.device)
        log.info(f"test_results {test_outs}")
        return test_outs

    def get_batch_iterator(self, dataset, batch_size, **kwargs):
        return self.task.get_batch_iterator(dataset, batch_size, **kwargs)
-------------------------------------------------------------------------------- /schedulers/__init__.py: --------------------------------------------------------------------------------
from .reduce_on_plateau import ReduceOnPlateau
from .ramp_up import RampUp

__all__ = ["build_scheduler"]

def build_scheduler(cfg, optim):
    name = cfg.name
    if name == "reduce_on_plateau":
        return ReduceOnPlateau(cfg, optim)
    elif name == "ramp_up":
        return RampUp(cfg, optim)
    else:
        raise ValueError("Scheduler name not found")
-------------------------------------------------------------------------------- /schedulers/base_scheduler.py: --------------------------------------------------------------------------------
class BaseScheduler():
    def __init__(self):
        pass

    def step(self, *args, **kwargs):
        raise NotImplementedError

    def load_state_dict(self, init_state):
        self.scheduler.load_state_dict(init_state)

    def get_state_dict(self):
        return self.scheduler.state_dict()

    def get_lr(self):
        return self.scheduler._last_lr[0]
-------------------------------------------------------------------------------- /schedulers/ramp_up.py: --------------------------------------------------------------------------------
from .base_scheduler import BaseScheduler
from torch.optim.lr_scheduler import StepLR
from warmup_scheduler import GradualWarmupScheduler

class RampUp(BaseScheduler):
    def __init__(self, cfg, optim):
        '''
        https://github.com/ildoonet/pytorch-gradual-warmup-lr
        '''
        super(RampUp, self).__init__()
        self.cfg = cfg
        warmup = int(self.cfg.warmup * self.cfg.total_steps)
        step_size = max(1, int((self.cfg.total_steps - warmup) / 100))
        scheduler_steplr = StepLR(optim, step_size=step_size, gamma=0.99)
        scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=warmup, after_scheduler=scheduler_steplr)

        # this zero gradient update is needed to avoid a warning message, issue #8.
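        # (PyTorch warns when lr_scheduler.step() is called before any optimizer.step();
        # the dummy zero_grad()/step() pair below performs one no-op update so the
        # warmup scheduler's first step does not trigger that warning.)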
        optim.zero_grad()
        optim.step()
        self.scheduler = scheduler_warmup

    def step(self, loss):
        self.scheduler.step()
-------------------------------------------------------------------------------- /schedulers/reduce_on_plateau.py: --------------------------------------------------------------------------------
import torch
from .base_scheduler import BaseScheduler

class ReduceOnPlateau(BaseScheduler):
    def __init__(self, cfg, optim):
        super(ReduceOnPlateau, self).__init__()

        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, 'min', patience=300)
        self.scheduler.step(100)  # TODO hack: an initial step so that _last_lr is populated for get_lr()

    def step(self, loss):
        self.scheduler.step(loss)
-------------------------------------------------------------------------------- /tasks/__init__.py: --------------------------------------------------------------------------------
import importlib
import os
from pathlib import Path

TASK_REGISTRY = {}

__all__ = ["setup_task"]

def setup_task(cfg):
    task_name = cfg.name
    assert task_name in TASK_REGISTRY, f"{task_name} not in task registry"
    task = TASK_REGISTRY[task_name]
    return task.setup_task(cfg)

def register_task(name):
    def register_task_cls(cls):
        if name in TASK_REGISTRY:
            raise ValueError(f'{name} already in registry')
        else:
            TASK_REGISTRY[name] = cls
        return cls
    return register_task_cls

def import_tasks():
    for file in os.listdir(os.path.dirname(__file__)):
        if file.endswith(".py") and not file.startswith("_"):
            module_name = str(Path(file).with_suffix(""))
            importlib.import_module('tasks.' + module_name)

import_tasks()
-------------------------------------------------------------------------------- /tasks/base_task.py: --------------------------------------------------------------------------------
import models
import criterions
from torch.utils import data
import torch
from datasets import build_dataset
from tasks.utils import split_dataset

class BaseTask():
    def __init__(self, cfg):
        self.cfg = cfg

    def build_model(self, cfg):
        return models.build_model(cfg)

    def load_datasets(self, data_cfg, preprocessor_cfg):
        #create train/val/test datasets
        dataset = build_dataset(data_cfg, task_cfg=self.cfg, preprocessor_cfg=preprocessor_cfg)

        train_set, val_set, test_set = split_dataset(dataset, data_cfg)
        self.dataset = dataset
        self.train_set = train_set
        self.valid_set = val_set
        self.test_set = test_set

    def train_step(self, batch, model, criterion, optimizer, scheduler, device, grad_clip=None):
        loss, logging_out = criterion(model, batch, device)
        loss.backward()
        if grad_clip:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            logging_out["grad_norm"] = grad_norm.item()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step(loss)

        return logging_out

    def build_criterion(self, cfg):
        return criterions.build_criterion(cfg)

    def get_batch_iterator(self, dataset, batch_size, shuffle=True, **kwargs):
        return data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)

    def get_valid_outs(self):
        raise NotImplementedError

    def save_model_weights(self, model, states, multi_gpu):
        #expects a state dict; fills in the "model" key
        if multi_gpu:
            return model.module.save_model_weights(states)
        return model.save_model_weights(states)

    def load_model_weights(self, model, states, multi_gpu):
        if multi_gpu:
            model.module.load_weights(states)
        else:
            model.load_weights(states)
-------------------------------------------------------------------------------- /tasks/baseline_wav_task.py: --------------------------------------------------------------------------------
#usage
#python3 run_train.py +exp=spec2vec ++exp.runner.device=cuda ++exp.runner.multi_gpu=False ++exp.runner.num_workers=0 +data=timestamp_data +model=debug_finetune_model ++exp.task.name=debug_finetune_task ++exp.criterion.name=debug_finetune_criterion ++exp.runner.total_steps=1000 ++model.frozen_upstream=True ++exp.runner.checkpoint_step=-1
import logging
import numpy as np
import models
from torch.utils import data
import torch
from tasks import register_task
from tasks.base_task import BaseTask
from tasks.batch_utils import baseline_wav_collator
from util.tensorboard_utils import plot_tensorboard_line
from sklearn.metrics import roc_auc_score, f1_score

log = logging.getLogger(__name__)

@register_task(name="baseline_wav_task")
class BaselineWavTask(BaseTask):
    def __init__(self, cfg):
        super(BaselineWavTask, self).__init__(cfg)

    def build_model(self, cfg):
        #assert hasattr(self, "dataset")
        input_dim = self.dataset.get_input_dim()
        return models.build_model(cfg, input_dim)

    @classmethod
    def setup_task(cls, cfg):
        return cls(cfg)

    def get_valid_outs(self, model, valid_loader, criterion, device):
        model.eval()
        all_outs = {"loss": 0}
        predicts, labels = [], []
        with torch.no_grad():
            for batch in valid_loader:
                batch["input"] = batch["input"].to(device)
                _, valid_outs = criterion(model, batch, device, return_predicts=True)

                predicts.append(valid_outs["predicts"])
                labels.append(batch["labels"])
                all_outs["loss"] += valid_outs["loss"]
        labels = np.array([x for y in labels for x in y])
        predicts = [np.array([p]) if len(p.shape) == 0 else p for p in predicts]
        predicts = np.concatenate(predicts)
        roc_auc = roc_auc_score(labels, predicts)
        all_outs["loss"] /= len(valid_loader)
        all_outs["roc_auc"] = roc_auc
        f1 = f1_score(labels, np.round(predicts))
        all_outs["f1"] = f1
        return all_outs

    def get_batch_iterator(self, dataset, batch_size, shuffle=True, **kwargs):
        return data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=baseline_wav_collator, **kwargs)

    def output_logs(self, train_logging_outs, val_logging_outs, writer, global_step):
        val_auc_roc = val_logging_outs["roc_auc"]
        val_f1 = val_logging_outs["f1"]
        if writer is not None:
            writer.add_scalar("valid_roc_auc", val_auc_roc, global_step)
            writer.add_scalar("valid_f1", val_f1, global_step)
        log.info(f'valid_roc_auc: {val_auc_roc}')

        image = train_logging_outs["images"]["wav"]
        label = train_logging_outs["images"]["wav_label"]
        tb_image = plot_tensorboard_line(image, title=label)
        if writer is not None:
            writer.add_image("raw_wave", tb_image, global_step)
-------------------------------------------------------------------------------- /tasks/batch_utils.py: --------------------------------------------------------------------------------
from torch.nn.utils.rnn import pad_sequence
import torch

def make_pad_mask(batched_input, lengths):
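    # builds a [batch, len] boolean mask that is True at padded positions,
    # given the true (unpadded) length of each sequence in the batch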
    pad_mask = torch.ones(batched_input.shape[:-1])  # [batch, len]

    for i in range(pad_mask.shape[0]):
        pad_mask[i, lengths[i]:] = 0

    pad_mask = ~pad_mask.bool()
    return pad_mask

def spec_collator(batch):
    input_specs = [b["masked_input"] for b in batch]
    mask_labels = [b["mask_label"] for b in batch]
    targets = [b["target"] for b in batch]
    lengths = [b["length"] for b in batch]
    wavs = [b["wav"] for b in batch]

    batched_input = pad_sequence(input_specs, batch_first=True)
    batched_target = pad_sequence(targets, batch_first=True)
    batched_mask_label = pad_sequence(mask_labels, batch_first=True)

    attn_mask = make_pad_mask(batched_input, lengths)

    batch = {"attn_mask": attn_mask,
             "masked_input": batched_input,
             "target": batched_target,
             "mask_label": batched_mask_label,
             "wavs": wavs}
    return batch

def wav_collator(batch):
    wavs = [torch.Tensor(b["input"]).unsqueeze(0) for b in batch]
    wavs = pad_sequence(wavs, batch_first=True)
    return {"input": wavs,
            }

def baseline_wav_collator(batch):
    labels = [b["label"] for b in batch]
    wavs = [torch.Tensor(b["input"]) for b in batch]
    wavs = pad_sequence(wavs, batch_first=True)

    return {"input": wavs,
            "labels": labels,
            }

def finetune_collator(batch):
    specs = [b["input"] for b in batch]
    specs = pad_sequence(specs, batch_first=True)
    labels = [b["label"] for b in batch]
    wavs = [b["wav"] for b in batch]

    lengths = [b["length"] for b in batch]
    pad_mask = make_pad_mask(specs, lengths)

    return {"input": specs,
            "labels": labels,
            "wavs": wavs,
            "pad_mask": pad_mask}

def feature_extracter_collator(batch):
    specs = [b["input"] for b in batch]
    specs = pad_sequence(specs, batch_first=True)
    labels = [b["label"] for b in batch]
    wavs = [b["wav"] for b in batch]

    return {"input": specs,
            "labels": labels,
            "wavs": wavs}
-------------------------------------------------------------------------------- /tasks/feature_extract_task.py: --------------------------------------------------------------------------------
import logging
import numpy as np
import models
from torch.utils import data
import torch
from tasks import register_task
from tasks.base_task import BaseTask
from tasks.batch_utils import feature_extracter_collator
from util.tensorboard_utils import plot_tensorboard_line
from sklearn.metrics import roc_auc_score, f1_score

log = logging.getLogger(__name__)

@register_task(name="feature_extract_task")
class FeatureExtractTask(BaseTask):
    def __init__(self, cfg):
        super(FeatureExtractTask, self).__init__(cfg)

    def build_model(self, cfg):
        assert hasattr(self, "dataset")
        input_dim = self.dataset.get_input_dim()
        return models.build_model(cfg)

    @classmethod
    def setup_task(cls, cfg):
        return cls(cfg)

    def get_valid_outs(self, model, valid_loader, criterion, device):
        model.eval()
        all_outs = {"loss": 0}
        predicts, labels = [], []
        with torch.no_grad():
            for batch in valid_loader:
                batch["input"] = batch["input"].to(device)
                _, valid_outs = criterion(model, batch, device, return_predicts=True)

                predicts.append(valid_outs["predicts"])
labels.append(batch["labels"]) 39 | all_outs["loss"] += valid_outs["loss"] 40 | labels = np.array([x for y in labels for x in y]) 41 | predicts = [np.array([p]) if len(p.shape)==0 else p for p in predicts] 42 | predicts = np.concatenate(predicts) 43 | roc_auc = roc_auc_score(labels, predicts) 44 | f1 = f1_score(labels, np.round(predicts)) 45 | all_outs["loss"] /= len(valid_loader) 46 | all_outs["roc_auc"] = roc_auc 47 | all_outs["f1"] = f1 48 | return all_outs 49 | 50 | def get_batch_iterator(self, dataset, batch_size, shuffle=True, **kwargs): 51 | return data.DataLoader(dataset, batch_size=batch_size, collate_fn=feature_extracter_collator, **kwargs) 52 | 53 | def output_logs(self, train_logging_outs, val_logging_outs, writer, global_step): 54 | val_auc_roc = val_logging_outs["roc_auc"] 55 | val_f1 = val_logging_outs["f1"] 56 | if writer is not None: 57 | writer.add_scalar("valid_roc_auc", val_auc_roc, global_step) 58 | writer.add_scalar("valid_f1", val_f1, global_step) 59 | log.info(f'valid_roc_auc: {val_auc_roc}, valid_f1: {val_f1}') 60 | 61 | image = train_logging_outs["images"]["wav"] 62 | label = train_logging_outs["images"]["wav_label"] 63 | tb_image = plot_tensorboard_line(image, title=label) 64 | if writer is not None: 65 | 66 | writer.add_image("raw_wave", tb_image, global_step) 67 | 68 | -------------------------------------------------------------------------------- /tasks/finetune_task.py: -------------------------------------------------------------------------------- 1 | #usage 2 | #python3 run_train.py +exp=spec2vec ++exp.runner.device=cuda ++exp.runner.multi_gpu=False ++exp.runner.num_workers=0 +data=timestamp_data +model=debug_finetune_model ++exp.task.name=debug_finetune_task ++exp.criterion.name=debug_finetune_criterion ++exp.runner.total_steps=1000 ++model.frozen_upstream=True ++exp.runner.checkpoint_step=-1 3 | import logging 4 | import numpy as np 5 | import models 6 | from torch.utils import data 7 | import torch 8 | from tasks import register_task 9 | from tasks.base_task import BaseTask 10 | from tasks.batch_utils import finetune_collator 11 | from util.tensorboard_utils import plot_tensorboard_line 12 | from sklearn.metrics import roc_auc_score 13 | 14 | log = logging.getLogger(__name__) 15 | 16 | @register_task(name="finetune_task") 17 | class FinetuneTask(BaseTask): 18 | def __init__(self, cfg): 19 | super(FinetuneTask, self).__init__(cfg) 20 | 21 | def build_model(self, cfg): 22 | ckpt_path = cfg.upstream_ckpt 23 | init_state = torch.load(ckpt_path) 24 | upstream_cfg = init_state["model_cfg"] 25 | if upstream_cfg.name=='debug_model': 26 | upstream_cfg.name='masked_tf_model' 27 | upstream = models.build_model(upstream_cfg) 28 | assert hasattr(self, "dataset") 29 | input_dim = self.dataset.get_input_dim() 30 | assert input_dim == upstream_cfg["input_dim"] 31 | states = init_state["model"] 32 | self.load_model_weights(upstream, states, False) 33 | return models.build_model(cfg, upstream) 34 | 35 | @classmethod 36 | def setup_task(cls, cfg): 37 | return cls(cfg) 38 | 39 | def get_valid_outs(self, model, valid_loader, criterion, device): 40 | model.eval() 41 | all_outs = {"loss":0} 42 | predicts, labels = [], [] 43 | with torch.no_grad(): 44 | for batch in valid_loader: 45 | batch["input"] = batch["input"].to(device) 46 | _, valid_outs = criterion(model, batch, device, return_predicts=True) 47 | 48 | predicts.append(valid_outs["predicts"]) 49 | labels.append(batch["labels"]) 50 | all_outs["loss"] += valid_outs["loss"] 51 | labels = np.array([x for y in labels for x 
        predicts = [np.array([p]) if len(p.shape) == 0 else p for p in predicts]
        predicts = np.concatenate(predicts)
        roc_auc = roc_auc_score(labels, predicts)
        all_outs["loss"] /= len(valid_loader)
        all_outs["roc_auc"] = roc_auc
        return all_outs

    def get_batch_iterator(self, dataset, batch_size, shuffle=True, **kwargs):
        return data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=finetune_collator, **kwargs)

    def output_logs(self, train_logging_outs, val_logging_outs, writer, global_step):
        val_auc_roc = val_logging_outs["roc_auc"]
        if writer is not None:
            writer.add_scalar("valid_roc_auc", val_auc_roc, global_step)
        log.info(f'valid_roc_auc: {val_auc_roc}')

        image = train_logging_outs["images"]["wav"]
        label = train_logging_outs["images"]["wav_label"]
        tb_image = plot_tensorboard_line(image, title=label)
        if writer is not None:
            writer.add_image("raw_wave", tb_image, global_step)
-------------------------------------------------------------------------------- /tasks/seeg_wav_task.py: --------------------------------------------------------------------------------
#usage
# python3 run_train.py +exp=seeg_wav2vec ++exp.runner.device=cuda ++exp.runner.multi_gpu=True ++exp.runner.num_workers=16 +data=pretrain_wavs_from_disk +model=seeg_wav2vec +data.data=/storage/czw/LanguageEcog/semantics/manifest
import models
from torch.utils import data
import torch
from tasks import register_task
from tasks.base_task import BaseTask
from tasks.batch_utils import wav_collator
from util.tensorboard_utils import plot_tensorboard_spectrogram, plot_tensorboard_line

@register_task(name="seeg_wav_task")
class SeegWavTask(BaseTask):
    def __init__(self, cfg):
        super(SeegWavTask, self).__init__(cfg)

    @classmethod
    def setup_task(cls, cfg):
        return cls(cfg)

    def get_valid_outs(self, model, valid_loader, criterion, device):
        model.eval()
        all_outs = {"loss": 0}
        with torch.no_grad():
            for batch in valid_loader:
                _, valid_outs = criterion(model, batch, device)
                all_outs["loss"] += valid_outs["loss"]
        all_outs["loss"] /= len(valid_loader)
        return all_outs

    def build_model(self, cfg):
        assert hasattr(self, "dataset")
        #input_dim = self.dataset.get_input_dim()
        #assert input_dim == cfg.input_dim
        return models.build_model(cfg)

    def get_batch_iterator(self, dataset, batch_size, shuffle=True, **kwargs):
        return data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=wav_collator, **kwargs)

    def output_logs(self, train_logging_outs, val_logging_outs, writer, global_step):
        for k in train_logging_outs["images"]:
            image = train_logging_outs["images"][k]
            if k == "wav":
                tb_image = plot_tensorboard_line(image, title="wav")
            else:
                tb_image = plot_tensorboard_spectrogram(image)
            if writer is not None:
                writer.add_image(k, tb_image, global_step)
-------------------------------------------------------------------------------- /tasks/spec_pretrain.py: --------------------------------------------------------------------------------
#python3 run_train.py +exp=spec2vec ++exp.runner.device=cuda ++exp.runner.multi_gpu=True ++exp.runner.num_workers=16 +data=masked_tf_dataset +model=debug_model +data.data=/storage/czw/self_supervised_seeg/all_day_data/manifests ++exp.runner.train_batch_size=64
import models
from torch.utils import data
import torch
from tasks import register_task
from tasks.base_task import BaseTask
from tasks.batch_utils import spec_collator
from util.tensorboard_utils import plot_tensorboard_spectrogram, plot_tensorboard_line

@register_task(name="spec_pretrain")
class SpecPretrain(BaseTask):
    def __init__(self, cfg):
        super(SpecPretrain, self).__init__(cfg)

    @classmethod
    def setup_task(cls, cfg):
        return cls(cfg)

    def get_valid_outs(self, model, valid_loader, criterion, device):
        model.eval()
        all_outs = {"loss": 0, "content_aware_loss": 0, "l1_loss": 0}
        with torch.no_grad():
            for batch in valid_loader:
                _, valid_outs = criterion(model, batch, device)
                all_outs["loss"] += valid_outs["loss"]
                all_outs["content_aware_loss"] += valid_outs["content_aware_loss"]
                all_outs["l1_loss"] += valid_outs["l1_loss"]
        for key in all_outs:
            all_outs[key] /= len(valid_loader)
        return all_outs

    def build_model(self, cfg):
        assert hasattr(self, "dataset")
        input_dim = self.dataset.get_input_dim()
        assert input_dim == cfg.input_dim
        return models.build_model(cfg)

    def get_batch_iterator(self, dataset, batch_size, shuffle=True, **kwargs):
        return data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=spec_collator, **kwargs)

    def output_logs(self, train_logging_outs, val_logging_outs, writer, global_step):
        for k in train_logging_outs["images"]:
            image = train_logging_outs["images"][k]
            if k == "wav":
                tb_image = plot_tensorboard_line(image, title="wav")
            else:
                tb_image = plot_tensorboard_spectrogram(image)
            if writer is not None:
                writer.add_image(k, tb_image, global_step)

        if writer is not None:
            loss_metrics = ["l1_loss", "content_aware_loss", "content_l1"]
            all_loss_metrics = {}
            def add_prefix(prefix, outs):
                for k, v in outs.items():
                    if k in loss_metrics:
                        all_loss_metrics[f'{prefix}_{k}'] = v
            add_prefix('train', train_logging_outs)
            add_prefix('val', val_logging_outs)
            for k, v in all_loss_metrics.items():
                writer.add_scalar(k, v, global_step=global_step)
-------------------------------------------------------------------------------- /tasks/utils.py: --------------------------------------------------------------------------------
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

def split_dataset(dataset, args):
    val_split = args.get("val_split", 0)
    test_split = args.get("test_split", 0)
    train_split = args.get("train_split", 1 - val_split - test_split)
    assert val_split + test_split + train_split <= 1
    assert train_split > 0
    all_idxs = list(range(len(dataset)))
    train_idxs, test_val_idxs = train_test_split(all_idxs, test_size=val_split + test_split, random_state=42)
    train_idxs = train_idxs[:int(len(all_idxs) * train_split)]
    train_fewshot = args.get("train_fewshot", -1)
    if train_fewshot > 0:
        # optionally restrict training to a small number of examples
        train_idxs = train_idxs[:train_fewshot]

    train_set = Subset(dataset, train_idxs)

    val_idxs, test_idxs = train_test_split(test_val_idxs, test_size=test_split / (val_split + test_split), random_state=42)
    val_set = Subset(dataset, val_idxs)
    test_set = Subset(dataset, test_idxs)
    return train_set, val_set, test_set
-------------------------------------------------------------------------------- /testing/collect_dataset_stats.py: --------------------------------------------------------------------------------
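# Tallies the positive/negative label counts of each finetuning dataset for every test subject.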
import numpy as np
from omegaconf import DictConfig, OmegaConf
import hydra
import models
import tasks
from runner import Runner
import logging
import os
from data.electrode_selection import get_clean_laplacian_electrodes
import json
from pathlib import Path

log = logging.getLogger(__name__)

def get_electrode_data(data_cfg, brain_runs, electrodes, cfg):
    data_cfg_copy = data_cfg.copy()
    cache_path = None
    if "cache_input_features" in data_cfg_copy:
        cache_path = data_cfg_copy.cache_input_features

    positive_count, negative_count = 0, 0
    e = electrodes[0]  # only the first electrode is used to count labels
    data_cfg_copy.electrodes = [e]
    data_cfg_copy.brain_runs = brain_runs
    if cache_path is not None:
        #cache_path needs to identify the pretrained model
        e_cache_path = os.path.join(cache_path, data_cfg_copy.subject, data_cfg_copy.name, e)
        log.info(f"logging input features in {e_cache_path}")
        data_cfg_copy.cache_input_features = e_cache_path
    cfg.data = data_cfg_copy
    task = tasks.setup_task(cfg.task)
    task.load_datasets(cfg.data, cfg.preprocessor)
    labels = np.array([item["label"] for item in task.dataset])
    positive_count += sum(labels == 1)
    negative_count += sum(labels == 0)
    return positive_count, negative_count

def write_summary(all_test_results, out_path):
    out_json = os.path.join(out_path, "all_test_results.json")
    with open(out_json, "w") as f:
        json.dump(all_test_results, f)

    out_json = os.path.join(out_path, "summary.json")
    all_rocs = []
    for s in all_test_results:
        for e in all_test_results[s]:
            all_rocs.append(all_test_results[s][e]["roc_auc"])

    summary_results = {"avg_roc_auc": np.mean(all_rocs), "std_roc_auc": np.std(all_rocs)}
    with open(out_json, "w") as f:
        json.dump(summary_results, f)

    log.info(f"Wrote test results to {out_path}")

def get_dataset_stats(data_cfg, test_splits, cfg):
    positive_count, negative_count = [], []
    subj_results = {}
    for subj in test_splits:
        log.info(f"Subject {subj}")
        data_cfg.subject = subj
        electrodes = get_clean_laplacian_electrodes(subj)
        p, n = get_electrode_data(data_cfg, test_splits[subj], electrodes, cfg)
        positive_count.append(p)
        negative_count.append(n)
        subj_results[subj] = {"positive": p.item(), "negative": n.item()}
    return np.array(positive_count).mean(), np.array(negative_count).mean(), subj_results

@hydra.main(config_path="../conf")
def main(cfg: DictConfig) -> None:
    log.info("Get dataset stats")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))
    out_dir = os.getcwd()
    log.info(f'Working directory {os.getcwd()}')
    if "out_dir" in cfg.test:
        out_dir = cfg.test.out_dir
    log.info(f'Output directory {out_dir}')

    test_split_path = cfg.test.test_split_path
    with open(test_split_path, "r") as f:
        test_splits = json.load(f)

    data_cfg = cfg.data
    all_test_results = {}
    features = ["onset_finetuning", "speech_finetuning", "rms_finetuning", "pitch_finetuning"]
    for feature in features:
        print("feature", feature)
        data_cfg.name = feature
        p, n, s = get_dataset_stats(data_cfg, test_splits, cfg)
        print("positive", p, "negative", n)
        all_test_results[feature] = s

    Path(out_dir).mkdir(exist_ok=True, parents=True)
    with open(os.path.join(out_dir, "all_results.json"), "w") as f:
        json.dump(all_test_results, f)

if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- /testing/effective_dimensionality.py: --------------------------------------------------------------------------------
from tasks.batch_utils import finetune_collator
import json
from omegaconf import DictConfig, OmegaConf
import hydra
import logging
import os
import torch
import random
import numpy as np
import models
import umap
from data.electrode_selection import get_clean_laplacian_electrodes
from pathlib import Path
from datasets import build_dataset
from tqdm import tqdm
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from torch.utils.data import Subset
from torch.utils import data

log = logging.getLogger(__name__)

def load_model_weights(model, states, multi_gpu):
    if multi_gpu:
        model.module.load_weights(states)
    else:
        model.load_weights(states)

def make_scatter_plot(vecs, labels, dataset, name="scatter"):
    #labels must be numeric
    unique_colors = np.unique(labels)
    colors = np.array(labels)

    cmap = plt.get_cmap("tab10")
    fig, ax = plt.subplots()
    for color in unique_colors:
        cvecs = vecs[colors == color]
        ax.scatter(cvecs[:, 0],
                   cvecs[:, 1],
                   color=cmap(color),
                   label=color,
                   s=1)
    ax.legend(markerscale=5)
    plt.savefig(f'{name}.png')

def get_effective_dim(contexts, dataset, args):
    assert args.dim_reduce == "pca"
    pca = PCA(n_components=args.n_components)
    reduced = pca.fit_transform(contexts)
    #if args.dim_reduce=="tsne":
    #    tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=1000)
    #    reduced = tsne.fit_transform(contexts)
    #if args.dim_reduce=="umap":
    #    reducer = umap.UMAP()
    #    reduced = reducer.fit_transform(contexts)
    ratios = pca.explained_variance_ratio_
    # count the leading components needed to explain 95% of the variance
    dim = 0
    dist = {i: 0 for i in range(len(ratios))}
    while dim < len(ratios):
        percent = np.sum(ratios[:dim])
        if percent > 0.95:
            break
        dist[dim] = percent.item()
        dim += 1
    return dim, dist

def build_model(cfg):
    ckpt_path = cfg.upstream_ckpt
    init_state = torch.load(ckpt_path)
    upstream_cfg = init_state["model_cfg"]
    upstream = models.build_model(upstream_cfg)
    return upstream

def get_embeddings(dataset, model, raw_spec=False):
    embeds, labels = [], []
    if model is not None:
        model.eval()
    subset = Subset(dataset, [x for x in range(500)])
    loader = data.DataLoader(subset, batch_size=64, collate_fn=finetune_collator)

    for batch in tqdm(loader):
        if raw_spec:
            out = batch["input"]
        else:
            inputs = batch["input"].to('cuda')
            mask = torch.zeros((inputs.shape[:2])).bool().to('cuda')
            with torch.no_grad():
                out = model.forward(inputs, mask, intermediate_rep=True)
        # average over the time dimension to get one embedding per example
        embed = out.mean(axis=1)
        embeds.append(embed.cpu().numpy())
        labels.extend(batch["labels"])
    embeds = np.concatenate(embeds)
    return embeds, labels

@hydra.main(config_path="../conf")
def main(cfg: DictConfig) -> None:
    log.info("Find effective dimensionality")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))
    log.info(f'Working directory {os.getcwd()}')

    raw_spec = cfg.test.raw_spec
    model = None
    if not raw_spec:
        model = build_model(cfg.test)
        model = torch.nn.DataParallel(model)
        model.to('cuda')
        ckpt_path = cfg.test.upstream_ckpt
        init_state = torch.load(ckpt_path)
        load_model_weights(model, init_state['model'], True)

    log.info(f'Use {torch.cuda.device_count()} GPUs')

    test_split_path = cfg.test.test_split_path
    with open(test_split_path, "r") as f:
        test_splits = json.load(f)

    all_results = {}
    Path(cfg.test.out_dir).mkdir(parents=True, exist_ok=True)
    for subject in test_splits:
        electrodes = get_clean_laplacian_electrodes(subject)
        all_results[subject] = {}
        subj_data_cfg = cfg.data.copy()
        subj_data_cfg.subject = subject
        random.shuffle(electrodes)
        for e in electrodes:
            data_cfg = subj_data_cfg.copy()
            data_cfg.electrodes = [e]
            dataset = build_dataset(data_cfg, preprocessor_cfg=cfg.preprocessor)
            embeddings, labels = get_embeddings(dataset, model, cfg.test.raw_spec)
            dim, dist = get_effective_dim(embeddings, dataset, cfg.test)
            all_results[subject][e] = dim
            e_out_dir = os.path.join(cfg.test.out_dir, subject, e)
            Path(e_out_dir).mkdir(parents=True, exist_ok=True)
            print(dim)
            with open(os.path.join(e_out_dir, "dim_results.json"), "w") as f:
                json.dump(dist, f)
    with open(os.path.join(cfg.test.out_dir, "dim_results.json"), "w") as f:
        json.dump(all_results, f)

if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- /testing/make_finetuning_datasets_stats.py: --------------------------------------------------------------------------------
import json

featurename2label = {"onset_finetuning": "Onset",
                     "speech_finetuning": "Speech",
                     "pitch_finetuning": "Pitch",
                     "rms_finetuning": "Volume",
                     "brightness_finetuning": "brightness"}

subject2label = {"subject1": "subject-1",
                 "subject2": "subject-2",
                 "subject3": "subject-3",
                 "subject4": "subject-4",
                 "subject5": "subject-5",
                 "subject6": "subject-6",
                 "subject7": "subject-7"
                 }

path = "dataset_stats/all_results.json"
with open(path, "r") as f:
    all_results = json.load(f)

features = [f for f in all_results]
print(" & ".join(features) + "\\\\")

for feature in features:
    print(feature)
    p, n, total = 0, 0, 0
    for subject in all_results[feature]:
        feature_results = all_results[feature][subject]
        pi = feature_results["positive"]
        ni = feature_results["negative"]
        p += pi
        n += ni
        total += pi + ni
    print("positive", p, "negative", n, "total", total)

for subject in all_results["rms_finetuning"]:
    row_nums = [subject2label[subject]]
    for feature in features:
        feature_results = all_results[feature][subject]
        p = feature_results["positive"]
        n = feature_results["negative"]
        row_nums.append(f'{p:,}')
        row_nums.append(f'{n:,}')
    row_str = " & ".join(row_nums) + "\\\\"
    print(row_str)
-------------------------------------------------------------------------------- /testing/process_linear_results.py: --------------------------------------------------------------------------------
# reads the ROC-AUC of linear results and stores the top k
#usage
#python3 -m testing.process_linear_results +test=process_linear_results
from omegaconf import DictConfig, OmegaConf
import hydra
import logging
import os
from data.electrode_selection import get_clean_laplacian_electrodes
import json
from .utils import run_electrode_test
import numpy as np
import glob
from pathlib import Path

log = logging.getLogger(__name__)

@hydra.main(config_path="../conf", version_base=None)
def main(cfg: DictConfig) -> None:
    log.info("Processing linear results")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))
    log.info(f'Working directory {os.getcwd()}')
    results_path = cfg.test.linear_results_path
    results_files = glob.glob(os.path.join(results_path, "*", "*"))
    all_results = []
    for path in results_files:
        subj = path.split("/")[-2]
        with open(path, "r") as f:
            subj_res = json.load(f)
        subj_res = subj_res[subj]
        for e in subj_res.keys():
            name = f'{subj}_{e}'
            roc_auc = subj_res[e]['roc_auc']
            if "single_subject" not in cfg.test or cfg.test.single_subject == subj:
                all_results.append((name, roc_auc))
    k = cfg.test.topk
    topk = sorted(all_results, key=lambda x: x[1])[-k:]
    topk_results = {}
    for name, _ in topk:
        subj, e = name.split("_")
        if subj not in topk_results:
            topk_results[subj] = []
        topk_results[subj].append(e)
    Path(cfg.test.out_dir).mkdir(exist_ok=True, parents=True)
    out_path = os.path.join(cfg.test.out_dir, "linear_results.json")
    with open(out_path, "w") as f:
        json.dump(topk_results, f)

if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- /testing/run_fewshot_training_tests.py: --------------------------------------------------------------------------------
#usage
#python3 -m testing.run_fewshot_training_tests +exp=finetune ++exp.runner.num_workers=0 +data=onset_finetuning +model=deep_linear_wav_baseline +task=baseline_wav_task +criterion=baseline_criterion ++exp.runner.scheduler.name=reduce_on_plateau ++exp.runner.log_step=100 +preprocessor=wav_preprocessor ++data.electrodes=["T1cIe11"] ++exp.runner.total_steps=10 ++data.delta=-2.5 ++data.duration=5.0 ++data.cached_data_array=/storage/czw/self_supervised_seeg/cached_data_arrays ++data.name="onset_finetuning" ++data.train_fewshot=??? ++data.reload_caches=???
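# (note: the `???` values above are OmegaConf mandatory-value markers; this script
# assigns data.train_fewshot and data.reload_caches itself before they are read)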
from omegaconf import DictConfig, OmegaConf
import hydra
import logging
import os
from data.electrode_selection import get_clean_laplacian_electrodes
import json
from .utils import run_electrode_test
import numpy as np
from pathlib import Path

log = logging.getLogger(__name__)

@hydra.main(config_path="../conf")
def main(cfg: DictConfig) -> None:
    log.info("Run testing for all training_data_percentages for one electrode in one test_subject")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))
    log.info(f'Working directory {os.getcwd()}')

    out_dir = cfg.test.out_dir
    log.info(f'Out directory {out_dir}')
    Path(out_dir).mkdir(exist_ok=True, parents=True)

    all_test_results = {}
    train_num = range(cfg.test.test_ex_min, cfg.test.test_ex_max, cfg.test.test_ex_step)
    cfg.data.reload_caches = False  # reuse any existing data caches across runs
    for train_n in train_num:
        cfg.data.train_fewshot = train_n
        test_result = run_electrode_test(cfg)
        all_test_results[train_n] = test_result

    out_json = os.path.join(out_dir, "all_test_results.json")
    with open(out_json, "w") as f:
        json.dump(all_test_results, f)
    log.info(f"Wrote test results to {out_json}")

if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- /testing/run_single_electrode_tests.py: --------------------------------------------------------------------------------
#usage
#python3 -m testing.run_single_electrode_tests +exp=finetune ++exp.runner.num_workers=0 +data=onset_finetuning +model=deep_linear_wav_baseline +task=baseline_wav_task +criterion=baseline_criterion ++exp.runner.scheduler.name=reduce_on_plateau ++exp.runner.log_step=100 +preprocessor=wav_preprocessor ++data.electrodes=["T1cIe11"] ++exp.runner.total_steps=10 ++data.delta=-2.5 ++data.duration=5.0 ++data.cached_data_array=/storage/czw/self_supervised_seeg/cached_data_arrays ++data.name="onset_finetuning" ++data.reload_caches=???
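# (unlike run_fewshot_training_tests.py above, which sweeps data.train_fewshot over a
# range, this script runs the train/test cycle once with the configuration as given)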
from omegaconf import DictConfig, OmegaConf
import hydra
import logging
import os
from data.electrode_selection import get_clean_laplacian_electrodes
import json
from .utils import run_electrode_test
import numpy as np
from pathlib import Path

log = logging.getLogger(__name__)

@hydra.main(config_path="../conf")
def main(cfg: DictConfig) -> None:
    log.info("Run testing for one electrode in one test_subject")
    log.info(OmegaConf.to_yaml(cfg, resolve=True))
    log.info(f'Working directory {os.getcwd()}')

    cfg.data.reload_caches = False  # reuse any existing data caches
    all_test_results = run_electrode_test(cfg)

    out_json = os.path.join(os.getcwd(), "all_test_results.json")
    with open(out_json, "w") as f:
        json.dump(all_test_results, f)
    log.info(f"Wrote test results to {out_json}")
    Path(cfg.test.output_dir).mkdir(exist_ok=True, parents=True)
    out_json = os.path.join(cfg.test.output_dir, "all_test_results.json")
    with open(out_json, "w") as f:
        json.dump(all_test_results, f)
    log.info(f"Wrote test results to {out_json}")

if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- /testing/select_fewshot_learning_electrode.py: --------------------------------------------------------------------------------
#python3 -m testing.select_fewshot_learning_electrode
import numpy as np
import glob
import json
import os

superlet_full_brain = "/storage/czw/self_supervised_seeg/full_brain_test_outs/superlet_large_pretrained/onset_finetuning/all_test_results/"
linear_full_brain = "/storage/czw/self_supervised_seeg/outputs/2022-08-31/01-02-47/all_test_results/"

#collect linear results
def collect_results(results_path):
    results_files = glob.glob(os.path.join(results_path, "*", "*"))
    all_results = {}
    for path in results_files:
        subj = path.split("/")[-2]
        with open(path, "r") as f:
            subj_res = json.load(f)
        subj_res = subj_res[subj]
        all_results[subj] = subj_res
    return all_results

linear_results = collect_results(linear_full_brain)
superlet_results = collect_results(superlet_full_brain)

def rank_results(results):
    all_results = []
    for s in results:
        elecs = results[s]
        for e in elecs:
            roc = elecs[e]["roc_auc"]
            all_results.append(((s, e), roc))
    all_results = sorted(all_results, key=lambda x: -x[1])
    ranked = [(x[0], i) for i, x in enumerate(all_results)]
    return sorted(ranked, key=lambda x: x[0])

ranked_linear = rank_results(linear_results)
ranked_superlet = rank_results(superlet_results)
assert len(ranked_linear) == len(ranked_superlet)
all_ranks = []
for i in range(len(ranked_linear)):
    se1, rank1 = ranked_superlet[i]
    se2, rank2 = ranked_linear[i]
    assert se1 == se2
    # combine the two rankings; a smaller total rank means the electrode scores
    # well under both the superlet model and the linear baseline
    all_ranks.append((se1, rank1 + rank2, rank1, rank2))
print(sorted(all_ranks, key=lambda x: x[1])[:3])

#all_results = []
#for s in linear_results:
#    elecs = linear_results[s]
#    for e in elecs:
#        l_roc = elecs[e]["roc_auc"]
#        s_roc = superlet_results[s][e]["roc_auc"]
#        all_results.append(((s,e), l_roc+s_roc, l_roc, s_roc))
#        #all_results.append(((s,e), np.sqrt(l_roc*s_roc)))
#
#print(sorted(all_results, key=lambda x: x[1])[-1])
-------------------------------------------------------------------------------- /testing/utils.py: --------------------------------------------------------------------------------
import os
import logging
import models
import tasks
from runner import Runner

log = logging.getLogger(__name__)

def run_electrode_test(cfg):
    #run a test for a single electrode
    test_results = []
    orig_cfg = cfg.copy()
    for i in range(cfg.test.test_runs):
        cfg = orig_cfg.copy()  #data.cfg.cached_transcript_aligns gets modified downstream
        if i > 0:
            cfg.data.reload_caches = False  #don't need to cache after first run
        task = tasks.setup_task(cfg.task)
        task.load_datasets(cfg.data, cfg.preprocessor)
        model = task.build_model(cfg.model)
        criterion = task.build_criterion(cfg.criterion)
        runner = Runner(cfg.exp.runner, task, model, criterion)
        best_model = runner.train()
        test_results.append(runner.test(best_model))
    return test_results

def run_subject_test(data_cfg, brain_runs, electrodes, cfg):
    cache_path = None
    if "cache_input_features" in data_cfg:
        cache_path = data_cfg.cache_input_features

    subject_test_results = {}
    for e in electrodes:
        data_cfg.electrodes = [e]
        data_cfg.brain_runs = brain_runs
        if cache_path is not None:
            #cache_path needs to identify the pretrained model
            e_cache_path = os.path.join(cache_path, data_cfg.subject, data_cfg.name, e)
            log.info(f"logging input features in {e_cache_path}")
            data_cfg.cache_input_features = e_cache_path
        cfg.data = data_cfg
        test_results = run_electrode_test(cfg)
        subject_test_results[e] = test_results
    return subject_test_results
-------------------------------------------------------------------------------- /util/mask_utils.py: --------------------------------------------------------------------------------
import torch
import random
import numpy as np

def get_mask_fill_value(data):
    return 0
    #return data.min() - 1

def create_masked_intervals(data, consecutive_min, consecutive_max, mask_p, axis="time"):
    assert consecutive_min <= consecutive_max
    assert consecutive_max < data.shape[0]

    if axis == "time":
        valid_starts = range(len(data) - consecutive_max)
    else:
        valid_starts = range(data.shape[1] - consecutive_max)
    # sample non-overlapping (start, end) intervals to mask
    masked_steps = []
    for i in valid_starts:
        if random.random() < mask_p:
            if len(masked_steps) == 0 or masked_steps[-1][1] < i:
                masked_steps.append((i, i + random.randint(consecutive_min, consecutive_max)))
    return masked_steps, valid_starts

def fixed_mask_inputs(data, task_cfg):
    mask_label = torch.zeros_like(data)

    masked_steps, valid_starts = create_masked_intervals(data, task_cfg.time_mask_consecutive_min, task_cfg.time_mask_consecutive_max, task_cfg.time_mask_p, axis="time")
    for (start, end) in masked_steps:
        mask_label[start:end, :] = 1

    masked_data = torch.clone(data)
    mask_fill_value = get_mask_fill_value(data)
    for (start, end) in masked_steps:
        dice = random.random()
        if dice < 0.1:  #TODO look at attentions
            pass
        elif dice < 0.2:
            # replace the masked span with a random span from the same input
            random_replace_start = valid_starts[random.randint(0, len(valid_starts) - 1)]
            diff = end - start
            masked_data[start:end, :] = data[random_replace_start:random_replace_start + diff, :]
        else:
            masked_data[start:end, :] = mask_fill_value

    masked_steps, valid_starts = create_masked_intervals(data, task_cfg.freq_mask_consecutive_min, task_cfg.freq_mask_consecutive_max, task_cfg.freq_mask_p, axis="freq")
    for (start, end) in masked_steps:
        mask_label[:, start:end] = 1
    for (start, end) in masked_steps:
        dice = random.random()
        if dice < 0.1:  #TODO look at attentions
            pass
        elif dice < 0.2:
            random_replace_start = valid_starts[random.randint(0, len(valid_starts) - 1)]
            diff = end - start
            masked_data[:, start:end] = data[:, random_replace_start:random_replace_start + diff]
        else:
            masked_data[:, start:end] = mask_fill_value
    return masked_data, mask_label

def variable_mask_time(data, task_cfg):
    decim = 60
    sample_rate = 2048
    max_size_in_secs = 0.250
    max_size_in_samples = max_size_in_secs * sample_rate / decim

    min_size_in_samples = random.randint(1, 2)

    # frequency-dependent mask widths: lower frequencies get wider time masks
    fs = np.linspace(task_cfg.min_f, task_cfg.max_f, task_cfg.n_freq_steps)
    window_sizes = [int(max(min_size_in_samples, 200 / (25 + f))) for f in fs]

    max_size = max(window_sizes)
    valid_starts = list(np.arange(max_size, data.shape[0] - max_size))  #remember that mask is centered on time position

    def fill_in_time_mask(array, position, value=None, value_slice=None):
        #value -- what value to fill the template with
        assert not (value is not None and value_slice is not None)
        arr_len = array.shape[0]
        if value_slice is not None:
            for i in range(len(window_sizes)):
                array[max(0, position - window_sizes[i]):min(arr_len, position + window_sizes[i]), i] = value_slice[i]
        else:
            for i in range(len(window_sizes)):
                array[max(0, position - window_sizes[i]):min(arr_len, position + window_sizes[i]), i] = value
        return array

    def take_time_mask(array, position):
        arr_len = array.shape[0]
        value_slice = []
        for i in range(len(window_sizes)):
            value_slice.append(array[max(0, position - window_sizes[i]):min(arr_len, position + window_sizes[i]), i])
        return value_slice

    masked_positions = []
    max_window = 2 * max(window_sizes)
    for pos in valid_starts:
        if random.random() < task_cfg.mask_p:
            if len(masked_positions) == 0 or abs(masked_positions[-1] - pos) > max_window + 1:
                masked_positions.append(pos)
    masked_data = torch.clone(data)
    mask_label = torch.zeros_like(data)
    for position in masked_positions:
        dice = random.random()
        if dice < 0.1:  #TODO look at attentions
            pass
        elif dice < 0.2:
            random_position = valid_starts[random.randint(0, len(valid_starts) - 1)]
            value_slice = take_time_mask(data, random_position)
            masked_data = fill_in_time_mask(masked_data, position, value_slice=value_slice)
        else:
            mask_fill_value = get_mask_fill_value(data)
            masked_data = fill_in_time_mask(masked_data, position, value=mask_fill_value)
        mask_label = fill_in_time_mask(mask_label, position, 1)
    return masked_data, mask_label

def variable_mask_freq(data, task_cfg):
    fs = np.linspace(task_cfg.min_f, task_cfg.max_f, task_cfg.n_freq_steps)
    # higher frequencies get wider frequency masks
    mask_sizes = [max(1, int(4.9 * (f) / 250)) for f in fs]
    idx2mask_size = {i: s for i, s in enumerate(mask_sizes)}
    valid_starts = list(range(data.shape[1] - max(mask_sizes)))
    masked_positions = [i for i in valid_starts if random.random() < task_cfg.mask_p]

    masked_data = torch.clone(data)
    mask_label = torch.zeros_like(data)

    mask_fill_value = get_mask_fill_value(data)
    for position in masked_positions:
        diff = idx2mask_size[position]
        dice = random.random()
        if dice < 0.1:  #TODO look at attentions
            pass
        elif dice < 0.2:
            random_replace_start = valid_starts[random.randint(0, len(valid_starts) - 1)]
            masked_data[:, position:position + diff] = data[:, random_replace_start:random_replace_start + diff]
        else:
            masked_data[:, position:position + diff] = mask_fill_value
        mask_label[:, position:position + diff] = 1
    return masked_data, mask_label

def variable_mask(data, task_cfg):
    masked_data, mask_label = variable_mask_time(data, task_cfg)

    masked_data, freq_mask_label = variable_mask_freq(masked_data, task_cfg)
    mask_label += freq_mask_label
    return masked_data, mask_label

def mask_inputs(data, task_cfg):
    if task_cfg.mask_type == "fixed":
        return fixed_mask_inputs(data, task_cfg)
    elif task_cfg.mask_type == "variable":
        return variable_mask(data, task_cfg)
    else:
        raise ValueError(f"Unknown mask_type: {task_cfg.mask_type}")
-------------------------------------------------------------------------------- /util/tensorboard_utils.py: --------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np

def plot_tensorboard_line(wav, title=None):
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.plot(wav)
    plt.xlabel("time")
    plt.ylabel("voltage")
    if title:
        plt.title(title)
    plt.tight_layout()

    fig.canvas.draw()
    data = plot_to_tensorboard(fig)
    plt.close()
    return data

def plot_tensorboard_spectrogram(spec):
    # expects [len, channels]; transposed for display as [channels, len]
    spec = spec.transpose(1, 0)
    spec = spec.detach().cpu()
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(spec, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = plot_to_tensorboard(fig)
    plt.close()
    return data

def plot_to_tensorboard(fig):
    """
    From https://martin-mundt.com/tensorboard-figures/
    """

    # Draw figure on canvas
    fig.canvas.draw()

    # Convert the figure to a numpy array, read the pixel values and reshape the array
    img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))

    # Normalize into 0-1 range for TensorBoard(X).
    # Swap axes for newer versions where the API expects colors in the first dim
    img = img / 255.0
    img = np.swapaxes(img, 0, 2)  # if your TensorFlow + TensorBoard versions are >= 1.8
    img = np.transpose(img, axes=[0, 2, 1])  # (3, W, H) -> (3, H, W)

    # Add figure in numpy "image" to TensorBoard writer
    plt.close(fig)
    return img
--------------------------------------------------------------------------------