├── README.md
├── architecture.png
├── config.py
├── data_raw.py
├── environment.yml
├── eval_decoding_raw.py
├── model_decoding_raw.py
├── overview.png
├── scripts
│   ├── eval_decoding_raw.sh
│   ├── prepare_dataset_raw.sh
│   └── train_decoding_raw.sh
├── train_decoding_raw.py
└── util
    ├── __pycache__
    │   └── data_loading_helpers_modified.cpython-38.pyc
    ├── construct_dataset_mat_to_pickle_v1_withRaw.py
    ├── construct_dataset_mat_to_pickle_v2_withRaw.py
    └── data_loading_helpers_modified.py

/README.md:
--------------------------------------------------------------------------------
[Figure: Overview]

Deep Representation Learning for Open Vocabulary Electroencephalography-to-Text Decoding

Hamza Amrani, Daniela Micucci, Paolo Napoletano

hamza.amrani@unimib.it, daniela.micucci@unimib.it, paolo.napoletano@unimib.it

[![Paper](https://img.shields.io/badge/paper-arxiv.2312.09430-B31B1B.svg)](https://arxiv.org/abs/2312.09430)

## Abstract

Previous research has demonstrated the potential of using pre-trained language models for decoding open vocabulary Electroencephalography (EEG) signals captured through a non-invasive Brain-Computer Interface (BCI). However, the impact of embedding EEG signals in the context of language models and the effect of subjectivity remain unexplored, leading to uncertainty about the best approach to enhance decoding performance. Additionally, current evaluation metrics used to assess decoding effectiveness are predominantly syntactic and do not provide insights into the comprehensibility of the decoded output for human understanding. We present an end-to-end architecture for non-invasive brain recordings that brings modern representation learning approaches to neuroscience. Our proposal introduces the following innovations: 1) an end-to-end deep learning architecture for open vocabulary EEG decoding, incorporating a subject-dependent representation learning module for raw EEG encoding, a BART language model, and a GPT-4 sentence refinement module; 2) a more comprehensive sentence-level evaluation metric based on the BERTScore; 3) an ablation study that analyses the contributions of each module within our proposal, providing valuable insights for future research. We evaluate our approach on two publicly available datasets, ZuCo v1.0 and v2.0, comprising EEG recordings of 30 subjects engaged in natural reading tasks. Our model achieves a BLEU-1 score of 42.75%, a ROUGE-1-F of 33.28%, and a BERTScore-F of 53.86%, improving on the previous state of the art by 1.40%, 2.59%, and 3.20%, respectively.

## Architecture

[Figure: Architecture]
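
The third stage of the pipeline, the GPT-4 sentence refinement module, sends each decoded sentence to the model with a reconstruction prompt and then post-processes the reply. As a sketch of that post-processing only (no API call is made), the helpers below mirror the logic of `chatgpt_refinement` in `eval_decoding_raw.py`: the user prompt wraps the text in square brackets, brackets are stripped from the reply, and a short reply containing `False` (the refusal signal requested by the system prompt) falls back to the unrefined sentence. The names `build_user_prompt` and `postprocess_refinement` are hypothetical and are not part of the repository.

```python
def build_user_prompt(corrupted_text: str) -> str:
    """Build the user message sent to GPT-4 (hypothetical helper).

    Mirrors the HumanMessage content used by chatgpt_refinement.
    """
    return f"Reconstruct the following text: [{corrupted_text}]."


def postprocess_refinement(model_reply: str, corrupted_text: str) -> str:
    """Clean a GPT-4 reply the way chatgpt_refinement does (hypothetical helper)."""
    # Remove every square bracket, e.g. when the reply echoes the [...] wrapping.
    cleaned = model_reply.replace("[", "").replace("]", "")
    # The system prompt asks for "[False]" when reconstruction fails;
    # a short reply (< 10 chars) containing "False" is treated as that refusal.
    if len(cleaned) < 10 and "False" in cleaned:
        return corrupted_text
    return cleaned


if __name__ == "__main__":
    decoded = "He was was a member the the Democratic Party"
    print(build_user_prompt(decoded))
    print(postprocess_refinement("[He was a member of the Democratic Party.]", decoded))
    print(postprocess_refinement("[False]", decoded))
```

Because the refusal check uses a fixed 10-character threshold, a genuinely short refinement that happens to contain the word "False" would also be discarded; that trade-off is inherited from the original code.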

## Code

This repo is based on the [EEG-to-Text](https://github.com/MikeWangWZHL/EEG-To-Text) repository.

### Download ZuCo datasets
- Download ZuCo v1.0 'Matlab files' for 'task1-SR', 'task2-NR', and 'task3-TSR' from https://osf.io/q3zws/files/ under the 'OSF Storage' root, unzip them, and move all `.mat` files to `/dataset/ZuCo/task1-SR/Matlab_files`, `/dataset/ZuCo/task2-NR/Matlab_files`, and `/dataset/ZuCo/task3-TSR/Matlab_files`, respectively.
- Download ZuCo v2.0 'Matlab files' for 'task1-NR' from https://osf.io/2urht/files/ under the 'OSF Storage' root, unzip them, and move all `.mat` files to `/dataset/ZuCo/task2-NR-2.0/Matlab_files`.

### Preprocess datasets
Run `bash ./scripts/prepare_dataset_raw.sh` to preprocess the `.mat` files and prepare the sentiment labels.

For each task, all `.mat` files will be converted into one `.pickle` file stored in `/dataset/ZuCo//-dataset.pickle`.

### Usage Example
To train an EEG-To-Text decoding model, run `bash ./scripts/train_decoding_raw.sh`.

To evaluate the trained EEG-To-Text decoding model from above, run `bash ./scripts/eval_decoding_raw.sh`.
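
The dataset code behind these scripts (`ZuCo_dataset` in `data_raw.py`) uses a positional split in its default `unique_sent` setting: for every subject, the first `int(0.8 * N)` sentences go to training, the next `int(0.1 * N)` to dev, and the remainder to test, where `N` is the sentence count of the first subject in each task pickle. The sketch below reproduces that index arithmetic and counts sentences per subject in a preprocessed pickle, which can help confirm that preprocessing produced what you expect. It assumes the pickle maps subject IDs to lists of sentence objects, as the dataset code implies; `split_ranges` and `count_sentences` are hypothetical helpers, not repository functions.

```python
import pickle
from typing import Dict, List


def split_ranges(total_num_sentence: int) -> Dict[str, range]:
    """Reproduce the unique_sent train/dev/test boundaries (hypothetical helper)."""
    train_divider = int(0.8 * total_num_sentence)
    dev_divider = train_divider + int(0.1 * total_num_sentence)
    return {
        "train": range(0, train_divider),
        "dev": range(train_divider, dev_divider),
        # Whatever rounding leaves over ends up in the test split.
        "test": range(dev_divider, total_num_sentence),
    }


def count_sentences(pickle_path: str) -> Dict[str, int]:
    """Count sentences per subject in a preprocessed task pickle.

    Assumes the pickle holds {subject_id: [sentence_obj_or_None, ...]}.
    """
    with open(pickle_path, "rb") as handle:
        dataset: Dict[str, List] = pickle.load(handle)
    return {subject: len(sentences) for subject, sentences in dataset.items()}


if __name__ == "__main__":
    for phase, idx in split_ranges(400).items():
        print(phase, idx.start, idx.stop)
```

Because the split is by sentence position rather than by subject, the same sentences appear in each split for every subject; the `unique_subj` setting in `data_raw.py` is the alternative that holds out whole subjects instead.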

## Citation

```
@article{amrani2024deep,
  title={Deep Representation Learning for Open Vocabulary Electroencephalography-to-Text Decoding},
  author={Amrani, Hamza and Micucci, Daniela and Napoletano, Paolo},
  journal={IEEE Journal of Biomedical and Health Informatics},
  year={2024},
  publisher={IEEE}
}
```

--------------------------------------------------------------------------------
/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamzaamrani/EEG-to-Text-Decoding/0b8cf03b040e37f9e573fced95ad47667b194557/architecture.png

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import argparse

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def get_config(case):
    if case == 'train_decoding':
        # args config for training EEG-To-Text decoder
        parser = argparse.ArgumentParser(description='Specify config args for training EEG-To-Text decoder')

        parser.add_argument('-m', '--model_name', help='choose from {BrainTranslator, BrainTranslatorNaive}', default = "BrainTranslator" ,required=True)
        parser.add_argument('-t', '--task_name', help='choose from {task1,task1_task2, task1_task2_task3,task1_task2_taskNRv2}', default = "task1", required=True)

        parser.add_argument('-1step', '--one_step', dest='skip_step_one', action='store_true')
        parser.add_argument('-2step', '--two_step', dest='skip_step_one', action='store_false')

        parser.add_argument('-pre', '--pretrained', dest='use_random_init', action='store_false')
        parser.add_argument('-rand', '--rand_init', dest='use_random_init', action='store_true')

        parser.add_argument('-load1', '--load_step1_checkpoint', dest='load_step1_checkpoint', action='store_true')
        parser.add_argument('-no-load1', '--not_load_step1_checkpoint', dest='load_step1_checkpoint', action='store_false')

        parser.add_argument('-ne1', '--num_epoch_step1', type = int, help='num_epoch_step1', default = 20, required=True)
        parser.add_argument('-ne2', '--num_epoch_step2', type = int, help='num_epoch_step2', default = 30, required=True)
        parser.add_argument('-lr1', '--learning_rate_step1', type = float, help='learning_rate_step1', default = 0.00005, required=True)
        parser.add_argument('-lr2', '--learning_rate_step2', type = float, help='learning_rate_step2', default = 0.0000005, required=True)
        parser.add_argument('-b', '--batch_size', type = int, help='batch_size', default = 32, required=True)

        parser.add_argument('-s', '--save_path', help='checkpoint save path', default = './checkpoints/decoding', required=True)
        parser.add_argument('-subj', '--subjects', help='use all subjects or specify a particular one', default = 'ALL', required=False)
        parser.add_argument('-eeg', '--eeg_type', help='choose from {GD, FFD, TRT}', default = 'GD', required=False)
        parser.add_argument('-band', '--eeg_bands', nargs='+', help='specify frequency bands', default = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] , required=False)
        parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0')

        args = vars(parser.parse_args())

    elif case == 'eval_decoding':
        # args config for evaluating EEG-To-Text decoder
        parser = argparse.ArgumentParser(description='Specify config args for evaluating EEG-To-Text decoder')
        parser.add_argument('-checkpoint', '--checkpoint_path', help='specify model checkpoint' ,required=True)
        parser.add_argument('-conf', '--config_path', help='specify training config json' ,required=True)
        parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0')
        args = vars(parser.parse_args())


    return args

--------------------------------------------------------------------------------
/data_raw.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import pickle
from torch.utils.data import Dataset
import json
import matplotlib.pyplot as plt
from glob import glob
from transformers import BartTokenizer, BertTokenizer
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence


ZUCO_SENTIMENT_LABELS = json.load(open('./dataset/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json'))
SST_SENTIMENT_LABELS = json.load(open('./dataset/stanfordsentiment/ternary_dataset.json'))

from scipy.signal import butter, lfilter
from scipy.signal import freqz
def butter_bandpass_filter(signal, lowcut, highcut, fs=500, order=5):
    nyq = 0.5 * fs
    low = lowcut/nyq
    high = highcut/nyq
    b,a = butter(order, [low, high], btype='band')
    y = lfilter(b, a, signal, axis=-1)

    return torch.Tensor(y).float()

def normalize_1d(input_tensor):
    # normalize a 1d tensor
    mean = torch.mean(input_tensor)
    std = torch.std(input_tensor)
    input_tensor = (input_tensor - mean)/std
    return input_tensor

def get_input_sample(sent_obj, tokenizer, eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'], max_len = 56, add_CLS_token = False, subj='unspecified',raw_eeg=False):

    def get_word_embedding_eeg_tensor(word_obj, eeg_type, bands):
        frequency_features = []
        content=word_obj['content']
        for band in bands:
            frequency_features.append(word_obj['word_level_EEG'][eeg_type][eeg_type+band])
        word_eeg_embedding = np.concatenate(frequency_features)
        if len(word_eeg_embedding) != 105*len(bands):
            print(f'expect word: {content} of subj: {subj} eeg embedding dim to be {105*len(bands)}, but got {len(word_eeg_embedding)}, return None')
            return None
        # assert len(word_eeg_embedding) == 105*len(bands)
        return_tensor = torch.from_numpy(word_eeg_embedding)
        return normalize_1d(return_tensor)

    def get_word_raweeg_tensor(word_obj):
        word_raw_eeg = word_obj['rawEEG'][0] #1000
        return_tensor = torch.from_numpy(word_raw_eeg)
        return return_tensor

    def get_sent_eeg(sent_obj, bands):
        sent_eeg_features = []
        for band in bands:
            key = 'mean'+band
            sent_eeg_features.append(sent_obj['sentence_level_EEG'][key])
        sent_eeg_embedding = np.concatenate(sent_eeg_features)
        assert len(sent_eeg_embedding) == 105*len(bands)
        return_tensor = torch.from_numpy(sent_eeg_embedding)
        return normalize_1d(return_tensor)


    if sent_obj is None:
        # print(f' - skip bad sentence')
        return None

    input_sample = {}
    # get target label
    target_string = sent_obj['content']

    target_tokenized = tokenizer(target_string, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt', return_attention_mask = True)
    input_sample['target_ids'] = target_tokenized['input_ids'][0]


    # get sentence level EEG features
    sent_level_eeg_tensor = get_sent_eeg(sent_obj, bands)
    if torch.isnan(sent_level_eeg_tensor).any():
        return None
    input_sample['sent_level_EEG'] = sent_level_eeg_tensor

    # get sentiment label
    # handle some weird cases
    if 'emp11111ty' in target_string:
        target_string = target_string.replace('emp11111ty','empty')
    if 'film.1' in target_string:
        target_string = target_string.replace('film.1','film.')

    if target_string in ZUCO_SENTIMENT_LABELS:
        input_sample['sentiment_label'] = torch.tensor(ZUCO_SENTIMENT_LABELS[target_string]+1) # 0:Negative, 1:Neutral, 2:Positive
    else:
        input_sample['sentiment_label'] = torch.tensor(-100)

    # get input embeddings
    word_embeddings = []
    word_raw_embeddings = []
    word_contents = []

    """add CLS token embedding at the front"""
    if add_CLS_token:
        # must match the 105*len(bands) word embeddings, otherwise torch.stack fails below
        word_embeddings.append(torch.ones(105*len(bands)))


    for word in sent_obj['word']:
        # add each word's EEG embedding as Tensors
        word_level_eeg_tensor = get_word_embedding_eeg_tensor(word, eeg_type, bands = bands)
        if raw_eeg:
            try:
                word_level_raw_eeg_tensor = get_word_raweeg_tensor(word)
            except:
                print('error in raw eeg')
                print(word['content'])
                print(sent_obj['content'])
                print()
                return None
        # check none, for v2 dataset
        if word_level_eeg_tensor is None:
            return None
        # check nan:
        if torch.isnan(word_level_eeg_tensor).any():
            # print()
            # print('[NaN ERROR] problem sent:',sent_obj['content'])
            # print('[NaN ERROR] problem word:',word['content'])
            # print('[NaN ERROR] problem word feature:',word_level_eeg_tensor)
            # print()
            return None

        word_contents.append(word['content'])
        word_embeddings.append(word_level_eeg_tensor)

        if raw_eeg:
            word_level_raw_eeg_tensor = word_level_raw_eeg_tensor[:,:104]
            word_raw_embeddings.append(word_level_raw_eeg_tensor)

    if len(word_embeddings)<1:
        return None


    # pad to max_len
    n_eeg_representations = len(word_embeddings)
    while len(word_embeddings) < max_len:
        # TODO: FBCSP
        word_embeddings.append(torch.zeros(105*len(bands)))
        if raw_eeg:
            word_raw_embeddings.append(torch.zeros(1,104))

    word_contents_tokenized = tokenizer(' '.join(word_contents), padding='max_length', max_length=max_len, truncation=True, return_tensors='pt', return_attention_mask = True)

    input_sample['word_contents'] = word_contents_tokenized['input_ids'][0]
    input_sample['word_contents_attn'] = word_contents_tokenized['attention_mask'][0] #bart

    input_sample['input_embeddings'] = torch.stack(word_embeddings) # max_len * (105*num_bands)

    if raw_eeg:
        input_sample['input_raw_embeddings'] = word_raw_embeddings


    # mask out padding tokens
    input_sample['input_attn_mask'] = torch.zeros(max_len) # 0 is masked out

    if add_CLS_token:
        input_sample['input_attn_mask'][:len(sent_obj['word'])+1] = torch.ones(len(sent_obj['word'])+1) # 1 is not masked
    else:
        input_sample['input_attn_mask'][:len(sent_obj['word'])] = torch.ones(len(sent_obj['word'])) # 1 is not masked

    # inverted padding mask for a different use case (PyTorch transformer modules): 1 is masked out
    input_sample['input_attn_mask_invert'] = torch.ones(max_len) # 1 is masked out

    if add_CLS_token:
        input_sample['input_attn_mask_invert'][:len(sent_obj['word'])+1] = torch.zeros(len(sent_obj['word'])+1) # 0 is not masked
    else:
        input_sample['input_attn_mask_invert'][:len(sent_obj['word'])] = torch.zeros(len(sent_obj['word'])) # 0 is not masked

    # mask out target padding for computing cross entropy loss
    input_sample['target_mask'] = target_tokenized['attention_mask'][0]
    input_sample['seq_len'] = len(sent_obj['word'])

    # clean 0 length data
    if input_sample['seq_len'] == 0:
        print('discard length zero instance: ', target_string)
        return None

    # subject
    input_sample['subject']= subj

    return input_sample

class ZuCo_dataset(Dataset):
    def __init__(self, input_dataset_dicts, phase, tokenizer, subject = 'ALL', eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'],raweeg=False, setting = 'unique_sent', is_add_CLS_token = False):
        self.inputs = []
        self.tokenizer = tokenizer

        if not isinstance(input_dataset_dicts,list):
            input_dataset_dicts = [input_dataset_dicts]
        print(f'[INFO]loading {len(input_dataset_dicts)} task datasets')
        for input_dataset_dict in input_dataset_dicts:
            if subject == 'ALL':
                subjects = list(input_dataset_dict.keys())
                print('[INFO]using subjects: ', subjects)
            else:
                subjects = [subject]

            total_num_sentence = len(input_dataset_dict[subjects[0]])

            train_divider = int(0.8*total_num_sentence)
            dev_divider = train_divider + int(0.1*total_num_sentence)

            print(f'train divider = {train_divider}')
            print(f'dev divider = {dev_divider}')

            if setting == 'unique_sent':
                # take first 80% as trainset, 10% as dev and 10% as test
                if phase == 'train':
                    print('[INFO]initializing a train set...')
                    for key in subjects:
                        for i in range(train_divider):
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, subj=key,raw_eeg=raweeg)
                            if input_sample is not None:
                                input_sample['subject']=key
                                self.inputs.append(input_sample)
                elif phase == 'dev':
                    print('[INFO]initializing a dev set...')
                    for key in subjects:
                        for i in range(train_divider,dev_divider):
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, subj=key,raw_eeg=raweeg)
                            if input_sample is not None:
                                input_sample['subject']=key
                                self.inputs.append(input_sample)
                elif phase == 'all':
                    print('[INFO]initializing all dataset...')
                    for key in subjects:
                        for i in range(int(1*total_num_sentence)):
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, subj=key,raw_eeg=raweeg)
                            if input_sample is not None:
                                input_sample['subject']=key
                                self.inputs.append(input_sample)
                elif phase == 'test':
                    print('[INFO]initializing a test set...')
                    for key in subjects:
                        for i in range(dev_divider,total_num_sentence):
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, subj=key,raw_eeg=raweeg)
                            if input_sample is not None:
                                input_sample['subject']=key
                                self.inputs.append(input_sample)
            elif setting == 'unique_subj':
                print('WARNING!!! only implemented for SR v1 dataset ')
                # subject ['ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW'] for train
                # subject ['ZMG'] for dev
                # subject ['ZPH'] for test
                # raw_eeg is passed here too, since __getitem__ always reads 'input_raw_embeddings'
                if phase == 'train':
                    print(f'[INFO]initializing a train set using {setting} setting...')
                    for i in range(total_num_sentence):
                        for key in ['ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH','ZKW']:
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, subj=key,raw_eeg=raweeg)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
                if phase == 'dev':
                    print(f'[INFO]initializing a dev set using {setting} setting...')
                    for i in range(total_num_sentence):
                        for key in ['ZMG']:
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, subj=key,raw_eeg=raweeg)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
                if phase == 'test':
                    print(f'[INFO]initializing a test set using {setting} setting...')
                    for i in range(total_num_sentence):
                        for key in ['ZPH']:
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, subj=key,raw_eeg=raweeg)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
            print('++ adding task to dataset, now we have:', len(self.inputs))

        #print('[INFO]input tensor size:', self.inputs[0]['input_embeddings'].size())
        #print()

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_sample = self.inputs[idx]
        return (
            input_sample['input_embeddings'],
            input_sample['seq_len'],
            input_sample['input_attn_mask'],
            input_sample['input_attn_mask_invert'],
            input_sample['target_ids'],
            input_sample['target_mask'],
            input_sample['sentiment_label'],
            input_sample['sent_level_EEG'],
            input_sample['input_raw_embeddings'],
            input_sample['word_contents'],
            input_sample['word_contents_attn'],
            input_sample['subject']
        )


'''sanity test'''
if __name__ == '__main__':

    check_dataset = 'ZuCo'#'stanford_sentiment'

    if check_dataset == 'ZuCo':
        whole_dataset_dicts = []

        dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle'
        with open(dataset_path_task1, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))

        dataset_path_task2 = './dataset/ZuCo/task2-NR/pickle/task2-NR-dataset.pickle'
        with open(dataset_path_task2, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))

        dataset_path_task2_v2 = './dataset/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset.pickle'
        with open(dataset_path_task2_v2, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))

        print()
        for key in whole_dataset_dicts[0]:
            print(f'task2_v2, sentence num in {key}:',len(whole_dataset_dicts[0][key]))
        print()

        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        dataset_setting = 'unique_sent'
        subject_choice = 'ALL'
        print(f'![Debug]using {subject_choice}')
        eeg_type_choice = 'GD'
        print(f'[INFO]eeg type {eeg_type_choice}')
        bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2']
        print(f'[INFO]using bands {bands_choice}')
        train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, raweeg=True)
        dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, raweeg=True)
        test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, raweeg=True)

        print('trainset size:',len(train_set))
        print('devset size:',len(dev_set))
        print('testset size:',len(test_set))

--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: EEGToTextOpenVoc
channels:
  - pytorch
  - dglteam
  - nvidia
  - conda-forge
  - anaconda
  - intel
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=4.5=1_gnu
  - attrs=21.2.0=pyhd3eb1b0_0
  - autopep8=1.6.0=pyhd3eb1b0_1
  - blas=1.0=mkl
  - brotlipy=0.7.0=py38h27cfd23_1003
  - bzip2=1.0.8=h7b6447c_0
  - c-ares=1.17.1=h27cfd23_0
  - ca-certificates=2022.10.11=h06a4308_0
  - catalogue=1.0.0=py38_1
  - certifi=2022.9.24=py38h06a4308_0
  - cffi=1.14.6=py38h400218f_0
  - chardet=4.0.0=py38h06a4308_1003
  - click=8.0.1=pyhd3eb1b0_0
  - conda=4.10.3=py38h578d9bd_0
  - conda-package-handling=1.7.3=py38h27cfd23_1
  - cryptography=3.4.7=py38hd23ed53_0
  - cudatoolkit=11.1.74=h6bb024c_0
  - cycler=0.10.0=py38_0
  - cymem=2.0.5=py38h2531618_0
  - cython-blis=0.4.1=py38h7b6447c_1
  - dbus=1.13.18=hb2f20db_0
  - decorator=5.0.9=pyhd3eb1b0_0
  - dgl-cuda10.2=0.5.2=py38_0
  - expat=2.4.1=h2531618_2
  - ffmpeg=4.2.2=h20bf706_0
  - fontconfig=2.13.1=h6c09931_0
  - freetype=2.10.4=h5ab3b9f_0
  - glib=2.69.0=h5202010_0
  - gmp=6.2.1=h2531618_2
  - gnutls=3.6.15=he1e5248_0
  - gst-plugins-base=1.14.0=h8213a91_2
  - gstreamer=1.14.0=h28cd5cc_2
  - icu=58.2=he6710b0_3
  - idna=2.10=pyhd3eb1b0_0
  - importlib_metadata=3.10.0=hd3eb1b0_0
  - intel-openmp=2021.3.0=h06a4308_3350
  - jpeg=9b=h024ee3a_2
  - kiwisolver=1.3.1=py38h2531618_0
  - krb5=1.19.2=hac12032_0
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.35.1=h7274673_9
  - libcurl=7.78.0=h0b77cf5_0
  - libedit=3.1.20210216=h27cfd23_1
  - libev=4.33=h7b6447c_0
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.3.0=h5101ec6_17
  - libgfortran-ng=7.5.0=ha8ba4b0_17
  - libgfortran4=7.5.0=ha8ba4b0_17
  - libgomp=9.3.0=h5101ec6_17
  - libidn2=2.3.2=h7f8727e_0
  - libnghttp2=1.41.0=hf8bcb03_2
  - libopus=1.3.1=h7b6447c_0
  - libpng=1.6.37=hbc83047_0
  - libssh2=1.9.0=h1ba5d50_1
  - libstdcxx-ng=9.3.0=hd4cf53a_17
  - libtasn1=4.16.0=h27cfd23_0
  - libtiff=4.2.0=h85742a9_0
  - libunistring=0.9.10=h27cfd23_0
  - libuuid=1.0.3=h1bed415_2
  - libuv=1.40.0=h7b6447c_0
  - libvpx=1.7.0=h439df22_0
  - libwebp-base=1.2.0=h27cfd23_0
  - libxcb=1.14=h7b6447c_0
  - libxml2=2.9.12=h03d6c58_0
  - lz4-c=1.9.3=h295c915_1
  - matplotlib=3.3.2=0
  - matplotlib-base=3.3.2=py38h817c723_0
  - mkl=2020.2=256
  - mkl-service=2.3.0=py38he904b0f_0
  - mkl_fft=1.3.0=py38h54f3939_0
  - mkl_random=1.1.1=py38h0573a6f_0
  - murmurhash=1.0.5=py38h2531618_0
  - ncurses=6.2=he6710b0_1
  - nettle=3.7.3=hbbd107a_1
  - networkx=2.5=py_0
  - ninja=1.10.2=hff7bd54_1
  - olefile=0.46=py_0
  - openh264=2.1.0=hd408876_0
  - openjpeg=2.3.0=h05c96fa_1
  - openssl=1.1.1q=h7f8727e_0
  - pcre=8.45=h295c915_0
  - pip=21.0.1=py38h06a4308_0
  - plac=0.9.6=py38_1
  - preshed=3.0.5=py38h2531618_4
  - pycodestyle=2.10.0=py38h06a4308_0
  - pycosat=0.6.3=py38h7b6447c_1
  - pycparser=2.20=py_2
  - pycurl=7.43.0.6=py38h1ba5d50_0
  - pymongo=3.11.2=py38h2531618_0
  - pyopenssl=20.0.1=pyhd3eb1b0_1
  - pyparsing=2.4.7=pyhd3eb1b0_0
  - pyqt=5.9.2=py38h05f1152_4
  - pyrsistent=0.17.3=py38h7b6447c_0
  - pysocks=1.7.1=py38h06a4308_0
  - python=3.8.5=h7579374_1
  - python-dateutil=2.8.2=pyhd3eb1b0_0
  - python_abi=3.8=2_cp38
  - qt=5.9.7=h5867ecd_1
  - readline=8.1=h27cfd23_0
  - requests=2.25.1=pyhd3eb1b0_0
  - ruamel.yaml=0.17.10=py38h497a2fe_0
  - ruamel.yaml.clib=0.2.2=py38h497a2fe_2
  - ruamel_yaml=0.15.100=py38h27cfd23_0
  - setuptools=52.0.0=py38h06a4308_0
  - sip=4.19.13=py38he6710b0_0
  - six=1.16.0=pyhd3eb1b0_0
  - spacy=2.3.2=py38hfd86e86_0
  - sqlite=3.36.0=hc218d9a_0
  - srsly=1.0.5=py38h2531618_0
  - thinc=7.4.1=py38hfd86e86_0
  - threadpoolctl=2.2.0=pyhbf3da8f_0
  - tk=8.6.10=hbc83047_0
  - toml=0.10.2=pyhd3eb1b0_0
  - torchaudio=0.9.0=py38
  - urllib3=1.26.6=pyhd3eb1b0_1
  - wasabi=0.8.2=pyhd3eb1b0_0
  - wheel=0.37.0=pyhd3eb1b0_0
  - x264=1!157.20191217=h7b6447c_0
  - xz=5.2.5=h7b6447c_0
  - yaml=0.2.5=h7b6447c_0
  - zipp=3.5.0=pyhd3eb1b0_0
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.4.9=haebb681_0
  - pip:
    - absl-py==1.4.0
    - accelerate==0.17.1
    - aiohttp==3.8.4
    - aiosignal==1.3.1
    - antlr4-python3-runtime==4.9.3
    - anyio==3.6.2
    - appdirs==1.4.4
    - argon2-cffi==21.3.0
    - argon2-cffi-bindings==21.2.0
    - arrow==1.2.3
    - asttokens==2.2.1
    - async-timeout==4.0.2
    - backcall==0.2.0
    - beautifulsoup4==4.11.2
    - bert-score==0.3.13
    - black==23.1.0
    - bleach==6.0.0
    - boto3==1.26.94
    - botocore==1.29.94
    - brevitas==0.8.0
    - cachetools==5.3.0
    - charset-normalizer==3.1.0
    - cloudpickle==2.2.1
    - cmake==3.26.0
    - comm==0.1.2
    - dataclasses-json==0.5.7
    - datasets==2.10.1
    - debugpy==1.6.6
    - defusedxml==0.7.1
    - dependencies==2.0.1
    - detectron2==0.6
    - dill==0.3.6
    - executing==1.2.0
    - fastjsonschema==2.16.2
    - fasttext==0.9.2
    - filelock==3.9.0
    - fqdn==1.5.1
    - frozenlist==1.3.3
    - fsspec==2023.3.0
    - fst-pso==1.8.1
    - funcy==1.18
    - future==0.18.3
    - fuzzy-match==0.0.1
    - fuzzytm==2.0.5
    - fvcore==0.1.5.post20221221
    - gensim==4.3.0
    - geojson==2.5.0
    - google-auth==2.16.0
    - google-auth-oauthlib==0.4.6
    - greenlet==2.0.2
    - grpcio==1.51.1
    - h5py==2.10.0
    - huggingface-hub==0.13.2
    - hydra-core==1.3.1
    - importlib-metadata==6.0.0
    - importlib-resources==5.10.2
    - inflect==6.0.2
    - inquirerpy==0.3.4
    - iopath==0.1.9
    - ipykernel==6.21.2
    - ipython==8.10.0
    - ipython-genutils==0.2.0
    - ipywidgets==8.0.4
    - isoduration==20.11.0
    - jedi==0.18.2
    - jinja2==3.1.2
    - jmespath==1.0.1
    - joblib==1.2.0
    - jsonpointer==2.3
    - jsonschema==4.17.3
    - jupyter==1.0.0
    - jupyter-client==8.0.2
    - jupyter-console==6.5.1
    - jupyter-core==5.2.0
    - jupyter-events==0.6.3
    - jupyter-server==2.3.0
    - jupyter-server-terminals==0.4.4
    - jupyterlab-pygments==0.2.2
    - jupyterlab-widgets==3.0.5
    - langchain==0.0.125
    - lit==15.0.7
    - mako==1.2.4
    - markdown==3.4.1
    - markupsafe==2.1.2
    - marshmallow==3.19.0
    - marshmallow-enum==1.5.1
    - matplotlib-inline==0.1.6
    - miniful==0.0.6
    - mistune==2.0.5
    - mpmath==1.3.0
    - multidict==6.0.4
    - multiprocess==0.70.14
    - mypy-extensions==1.0.0
    - nbclassic==0.5.1
    - nbclient==0.7.2
    - nbconvert==7.2.9
    - nbformat==5.7.3
    - nest-asyncio==1.5.6
    - nltk==3.8.1
    - notebook==6.5.2
    - notebook-shim==0.2.2
    - numexpr==2.8.4
    - numpy==1.23.5
    - nvidia-cublas-cu11==11.10.3.66
    - nvidia-cuda-cupti-cu11==11.7.101
    - nvidia-cuda-nvrtc-cu11==11.7.99
    - nvidia-cuda-runtime-cu11==11.7.99
    - nvidia-cudnn-cu11==8.5.0.96
    - nvidia-cufft-cu11==10.9.0.58
    - nvidia-curand-cu11==10.2.10.91
    - nvidia-cusolver-cu11==11.4.0.1
    - nvidia-cusparse-cu11==11.7.4.91
    - nvidia-nccl-cu11==2.14.3
    - nvidia-nvtx-cu11==11.7.91
    - oauthlib==3.2.2
    - omegaconf==2.3.0
    - openai==0.27.2
    - opt-einsum==3.3.0
    - packaging==23.0
    - pandas==1.5.3
    - pandocfilters==1.5.0
    - parso==0.8.3
    - pathspec==0.11.0
    - pexpect==4.8.0
    - pfzy==0.3.4
    - pickleshare==0.7.5
    - pillow==9.4.0
    - pkgutil-resolve-name==1.3.10
    - platformdirs==3.0.0
    - plotly==5.13.0
    - portalocker==2.7.0
    - powerlaw==1.5
    - prometheus-client==0.16.0
    - prompt-toolkit==3.0.36
    - protobuf==3.20.3
    - psutil==5.9.4
    - ptyprocess==0.7.0
    - pure-eval==0.2.2
    - pyarrow==11.0.0
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pybind11==2.10.4
    - pycocotools==2.0.6
    - pycuda==2022.2.2
    - pydantic==1.10.6
    - pyfume==0.2.25
    - pygments==2.14.0
    - pyldavis==3.3.1
    - pyowm==3.3.0
    - pyro-api==0.1.2
    - pyro-ppl==1.8.4
    - python-graphviz==0.20.1
    - python-json-logger==2.0.6
    - pytools==2022.1.14
    - pytorch-pretrained-bert==0.6.2
    - pytz==2022.7.1
    - pyyaml==6.0
    - pyzmq==25.0.0
    - qtconsole==5.4.0
    - qtpy==2.3.0
    - quant-cuda==0.0.0
    - regex==2022.10.31
    - requests-oauthlib==1.3.1
    - responses==0.18.0
    - rfc3339-validator==0.1.4
    - rfc3986-validator==0.1.1
    - rouge==1.0.1
    - rouge-score==0.1.2
    - rsa==4.9
    - s3transfer==0.6.0
    - sacremoses==0.0.53
    - safetensors==0.3.0
    - scikit-learn==1.2.1
    - scipy==1.10.1
    - send2trash==1.8.0
    - sentence-transformers==2.2.2
    - sentencepiece==0.1.97
    - simpful==2.10.0
    - sklearn==0.0.post1
    - smart-open==6.3.0
    - sniffio==1.3.0
    - soupsieve==2.4
    - sqlalchemy==1.4.47
    - stack-data==0.6.2
    - sympy==1.11.1
    - tabulate==0.9.0
    - tenacity==8.2.1
    - tensorboard==2.12.0
    - tensorboard-data-server==0.7.0
    - tensorboard-plugin-wit==1.8.1
    - termcolor==2.2.0
    - terminado==0.17.1
    - tinycss2==1.2.1
    - tokenizers==0.13.2
    - tomli==2.0.1
    - torch==1.13.1
    - torch-summary==1.4.5
    - torchcontrib==0.0.2
    - torchvision==0.15.1
    - torchviz==0.0.2
    - tornado==6.2
    - tqdm==4.65.0
    - traitlets==5.9.0
    - transformers==4.28.0.dev0
    - triton==2.0.0
    - typing-extensions==4.5.0
    - typing-inspect==0.8.0
    - uri-template==1.2.0
    - wcwidth==0.2.6
    - webcolors==1.12
    - webencodings==0.5.1
    - websocket-client==1.5.1
    - weightwatcher==0.7.1.5
    - werkzeug==2.2.3
    - widgetsnbextension==4.0.5
    - wordcloud==1.8.2.2
    - xxhash==3.2.0
    - yacs==0.1.8
    - yarl==1.8.2
prefix: /home/hamza/miniconda3/envs/EEGToTextOpenVoc

--------------------------------------------------------------------------------
/eval_decoding_raw.py:
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | import pickle 7 | import json 8 | import matplotlib.pyplot as plt 9 | from glob import glob 10 | 11 | from transformers import BartTokenizer, BartForConditionalGeneration 12 | from data_raw import ZuCo_dataset 13 | from model_decoding_raw import BrainTranslator 14 | from nltk.translate.bleu_score import corpus_bleu 15 | from rouge import Rouge 16 | from config import get_config 17 | 18 | from torch.nn.utils.rnn import pad_sequence 19 | 20 | 21 | from langchain.chat_models import ChatOpenAI 22 | from langchain.schema import ( 23 | HumanMessage, 24 | SystemMessage 25 | ) 26 | os.environ["OPENAI_API_KEY"] = "" # GPT4 APIs key 27 | 28 | 29 | # LLMs: Get predictions from ChatGPT 30 | def chatgpt_refinement(corrupted_text): 31 | llm = ChatOpenAI(temperature=0.2, model_name="gpt-4", max_tokens=256) 32 | 33 | messages = [ 34 | SystemMessage(content="As a text reconstructor, your task is to restore corrupted sentences to their original form while making minimum changes. You should adjust the spaces and punctuation marks as necessary. Do not introduce any additional information. 
If you are unable to reconstruct the text, respond with [False]."), 35 | HumanMessage(content=f"Reconstruct the following text: [{corrupted_text}].") 36 | ] 37 | 38 | output = llm(messages).content 39 | output = output.replace('[','').replace(']','') 40 | 41 | if len(output)<10 and 'False' in output: 42 | return corrupted_text 43 | 44 | return output 45 | 46 | def eval_model(dataloaders, device, tokenizer, criterion, model, output_all_results_path = './results_raw/temp.txt' ): 47 | # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html 48 | 49 | gpt=False 50 | 51 | print("Saving to: ", output_all_results_path) 52 | model.eval() # Set model to evaluate mode 53 | running_loss = 0.0 54 | 55 | # Iterate over data. 56 | sample_count = 0 57 | 58 | target_tokens_list = [] 59 | target_string_list = [] 60 | pred_tokens_list = [] 61 | pred_string_list = [] 62 | refine_tokens_list = [] 63 | refine_string_list = [] 64 | 65 | with open(output_all_results_path,'w') as f: 66 | for _, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG, input_raw_embeddings, input_raw_embeddings_lengths, word_contents, word_contents_attn, subject_batch in dataloaders['test']: 67 | 68 | # load in batch 69 | input_embeddings_batch = input_raw_embeddings.to( 70 | device).float() 71 | input_embeddings_lengths_batch = torch.stack([torch.tensor( 72 | a.clone().detach()) for a in input_raw_embeddings_lengths], 0).to(device) 73 | input_masks_batch = torch.stack(input_masks, 0).to(device) 74 | input_mask_invert_batch = torch.stack( 75 | input_mask_invert, 0).to(device) 76 | target_ids_batch = torch.stack(target_ids, 0).to(device) 77 | word_contents_batch = torch.stack( 78 | word_contents, 0).to(device) 79 | word_contents_attn_batch = torch.stack( 80 | word_contents_attn, 0).to(device) 81 | 82 | subject_batch = np.array(subject_batch) 83 | 84 | target_tokens = tokenizer.convert_ids_to_tokens(target_ids_batch[0].tolist(), 
skip_special_tokens = True) 85 | target_string = tokenizer.decode(target_ids_batch[0], skip_special_tokens = True) 86 | 87 | 88 | f.write(f'target string: {target_string}\n') 89 | target_tokens_string = "[" 90 | for el in target_tokens: 91 | target_tokens_string = target_tokens_string + str(el) + " " 92 | target_tokens_string += "]" 93 | f.write(f'target tokens: {target_tokens_string}\n') 94 | 95 | # add to the lists used for the corpus-level metrics after the loop 96 | target_tokens_list.append([target_tokens]) 97 | target_string_list.append(target_string) 98 | 99 | """replace padding ids in target_ids with -100""" 100 | target_ids_batch[target_ids_batch == tokenizer.pad_token_id] = -100 101 | 102 | # forward 103 | seq2seqLMoutput = model( 104 | input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch, input_embeddings_lengths_batch, word_contents_batch, word_contents_attn_batch, False, subject_batch, device) 105 | 106 | """calculate loss""" 107 | loss = criterion(seq2seqLMoutput.permute(0,2,1), target_ids_batch.long()) 108 | 109 | # get predicted tokens 110 | logits = seq2seqLMoutput 111 | probs = logits[0].softmax(dim = 1) 112 | # print('probs size:', probs.size()) 113 | values, predictions = probs.topk(1) 114 | # print('predictions before squeeze:',predictions.size()) 115 | predictions = torch.squeeze(predictions) 116 | predicted_string = tokenizer.decode(predictions).split(tokenizer.eos_token)[0].replace(tokenizer.bos_token, '') # keep the text before the first </s> and drop <s> 117 | # print('predicted string:',predicted_string) 118 | f.write(f'predicted string: {predicted_string}\n') 119 | 120 | # convert to int list 121 | predictions = predictions.tolist() 122 | truncated_prediction = [] 123 | for t in predictions: 124 | if t != tokenizer.eos_token_id: 125 | truncated_prediction.append(t) 126 | else: 127 | break 128 | pred_tokens = tokenizer.convert_ids_to_tokens(truncated_prediction, skip_special_tokens = True) 129 | # print('predicted tokens:',pred_tokens) 130 | pred_tokens_list.append(pred_tokens) 131 |
pred_string_list.append(predicted_string) 132 | 133 | # optional GPT-4 refinement, then re-tokenize the refined string for BLEU 134 | if gpt: 135 | predicted_string_chatgpt = chatgpt_refinement(predicted_string).replace('\n','') 136 | f.write(f'refined string: {predicted_string_chatgpt}\n') 137 | refine_tokens_list.append(tokenizer.convert_ids_to_tokens(tokenizer(predicted_string_chatgpt)['input_ids'], skip_special_tokens=True)) 138 | refine_string_list.append(predicted_string_chatgpt) 139 | 140 | 141 | pred_tokens_string = "[" 142 | for el in pred_tokens: 143 | pred_tokens_string = pred_tokens_string + str(el) + " " 144 | pred_tokens_string += "]" 145 | f.write(f'predicted tokens (truncated): {pred_tokens_string}\n') 146 | f.write(f'################################################\n\n\n') 147 | 148 | sample_count += 1 149 | # statistics 150 | running_loss += loss.item() * input_embeddings_batch.size()[0] # batch loss 151 | 152 | 153 | epoch_loss = running_loss / dataset_sizes['test_set'] 154 | print('test loss: {:.4f}'.format(epoch_loss)) 155 | 156 | 157 | print("Predicted outputs") 158 | """ calculate corpus bleu score """ 159 | weights_list = [(1.0,),(0.5,0.5),(1./3.,1./3.,1./3.),(0.25,0.25,0.25,0.25)] 160 | for weight in weights_list: 161 | corpus_bleu_score = corpus_bleu(target_tokens_list, pred_tokens_list, weights = weight) 162 | print(f'corpus BLEU-{len(list(weight))} score:', corpus_bleu_score) 163 | print() 164 | """ calculate rouge score """ 165 | rouge = Rouge() 166 | rouge_scores = rouge.get_scores(pred_string_list,target_string_list, avg = True) 167 | print(rouge_scores) 168 | print() 169 | """ calculate bertscore""" 170 | from bert_score import score 171 | P, R, F1 = score(pred_string_list,target_string_list, lang='en', device="cuda:0", model_type="bert-large-uncased") 172 | print(f"bert_score P: {np.mean(np.array(P))}") 173 | print(f"bert_score R: {np.mean(np.array(R))}") 174 | print(f"bert_score F1: {np.mean(np.array(F1))}") 175 | print("*************************************") 176
| 177 | if gpt: 178 | print() 179 | print("Refined outputs with GPT4") 180 | """ calculate corpus bleu score """ 181 | weights_list = [(1.0,),(0.5,0.5),(1./3.,1./3.,1./3.),(0.25,0.25,0.25,0.25)] 182 | for weight in weights_list: 183 | # print('weight:',weight) 184 | corpus_bleu_score = corpus_bleu(target_tokens_list, refine_tokens_list, weights = weight) 185 | print(f'corpus BLEU-{len(list(weight))} score:', corpus_bleu_score) 186 | print() 187 | """ calculate rouge score """ 188 | rouge = Rouge() 189 | rouge_scores = rouge.get_scores(refine_string_list,target_string_list, avg = True) 190 | print(rouge_scores) 191 | print() 192 | """ calculate bertscore""" 193 | from bert_score import score 194 | P, R, F1 = score(refine_string_list,target_string_list, lang='en', device="cuda:0", model_type="bert-large-uncased") 195 | print(f"bert_score P: {np.mean(np.array(P))}") 196 | print(f"bert_score R: {np.mean(np.array(R))}") 197 | print(f"bert_score F1: {np.mean(np.array(F1))}") 198 | print("*************************************") 199 | print("*************************************") 200 | 201 | 202 | if __name__ == '__main__': 203 | ''' get args''' 204 | args = get_config('eval_decoding') 205 | 206 | ''' load training config''' 207 | training_config = json.load(open(args['config_path'])) 208 | 209 | batch_size = 1 210 | 211 | subject_choice = training_config['subjects'] 212 | print(f'[INFO]subjects: {subject_choice}') 213 | eeg_type_choice = training_config['eeg_type'] 214 | print(f'[INFO]eeg type: {eeg_type_choice}') 215 | bands_choice = training_config['eeg_bands'] 216 | print(f'[INFO]using bands: {bands_choice}') 217 | 218 | 219 | 220 | dataset_setting = 'unique_sent' 221 | task_name = training_config['task_name'] 222 | model_name = training_config['model_name'] 223 | 224 | output_all_results_path = f'./results_raw/{task_name}-{model_name}-all_decoding_results.txt' 225 | ''' set random seeds ''' 226 | seed_val = 312 227 | np.random.seed(seed_val) 228 | 
torch.manual_seed(seed_val) 229 | torch.cuda.manual_seed_all(seed_val) 230 | 231 | ''' set up device ''' 232 | # use cuda 233 | if torch.cuda.is_available(): 234 | dev = args['cuda'] 235 | else: 236 | dev = "cpu" 237 | # CUDA_VISIBLE_DEVICES=0,1 238 | device = torch.device(dev) 239 | print(f'[INFO]using device {dev}') 240 | 241 | 242 | ''' set up dataloader ''' 243 | whole_dataset_dicts = [] 244 | if 'task1' in task_name: 245 | dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset_wRaw.pickle' 246 | with open(dataset_path_task1, 'rb') as handle: 247 | whole_dataset_dicts.append(pickle.load(handle)) 248 | if 'task2' in task_name: 249 | dataset_path_task2 = './dataset/ZuCo/task2-NR/pickle/task2-NR-dataset_wRaw.pickle' 250 | with open(dataset_path_task2, 'rb') as handle: 251 | whole_dataset_dicts.append(pickle.load(handle)) 252 | if 'task3' in task_name: 253 | dataset_path_task3 = './dataset/ZuCo/task3-TSR/pickle/task3-TSR-dataset_wRaw.pickle' 254 | with open(dataset_path_task3, 'rb') as handle: 255 | whole_dataset_dicts.append(pickle.load(handle)) 256 | if 'taskNRv2' in task_name: 257 | dataset_path_taskNRv2 = './dataset/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset_wRaw.pickle' 258 | with open(dataset_path_taskNRv2, 'rb') as handle: 259 | whole_dataset_dicts.append(pickle.load(handle)) 260 | print() 261 | 262 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') 263 | 264 | # test dataset 265 | test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, raweeg=True) 266 | 267 | dataset_sizes = {"test_set":len(test_set)} 268 | print('[INFO]test_set size: ', len(test_set)) 269 | 270 | # Collate function: pads the raw EEG and keeps the true (unpadded) length of each word's EEG segment 271 | def pad_and_sort_batch(data_loader_batch): 272 | """ 273 | Zero-pads every word's raw EEG along the time axis to the longest word in the batch (all sentences must have the same number of padded words) and returns it as a (batch, words, time, features) tensor.
274 | The per-word sample lengths are returned right after it; all other fields pass through unchanged. No sorting is performed despite the function name. 275 | """ 276 | input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG, input_raw_embeddings, word_contents, word_contents_attn, subject = tuple( 277 | zip(*data_loader_batch)) 278 | 279 | raw_eeg = [] 280 | input_raw_embeddings_lengths = [] 281 | for sentence in input_raw_embeddings: 282 | input_raw_embeddings_lengths.append( 283 | torch.Tensor([a.size(0) for a in sentence])) 284 | raw_eeg.append(pad_sequence( 285 | sentence, batch_first=True, padding_value=0).permute(1, 0, 2)) 286 | 287 | input_raw_embeddings = pad_sequence( 288 | raw_eeg, batch_first=True, padding_value=0).permute(0, 2, 1, 3) 289 | 290 | return input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG, input_raw_embeddings, input_raw_embeddings_lengths, word_contents, word_contents_attn, subject # lengths 291 | 292 | 293 | test_dataloader = DataLoader(test_set, batch_size = 1, shuffle=False, num_workers=4,collate_fn=pad_and_sort_batch) 294 | 295 | dataloaders = {'test':test_dataloader} 296 | 297 | ''' set up model ''' 298 | checkpoint_path = args['checkpoint_path'] 299 | pretrained_bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large') 300 | 301 | if model_name == 'BrainTranslator': 302 | model = BrainTranslator(pretrained_bart, in_feature=1024, decoder_embedding_size=1024, 303 | additional_encoder_nhead=8, additional_encoder_dim_feedforward=4096) 304 | 305 | model.load_state_dict(torch.load(checkpoint_path, map_location=device)) 306 | model.to(device) 307 | 308 | criterion = nn.CrossEntropyLoss() 309 | 310 | ''' eval ''' 311 | eval_model(dataloaders, device, tokenizer, criterion, model, output_all_results_path = output_all_results_path) -------------------------------------------------------------------------------- /model_decoding_raw.py:
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch.utils.data 4 | 5 | from torch.nn.utils.rnn import pack_padded_sequence 6 | 7 | def cross_entropy(preds, targets, reduction='none'): 8 | log_softmax = nn.LogSoftmax(dim=-1) 9 | loss = (-targets * log_softmax(preds)).sum(1) 10 | if reduction == "none": 11 | return loss 12 | elif reduction == "mean": 13 | return loss.mean() 14 | 15 | class ProjectionHead(nn.Module): 16 | def __init__( 17 | self, 18 | embedding_dim, 19 | projection_dim=1024, 20 | dropout=0.1 21 | ): 22 | super().__init__() 23 | self.projection = nn.Linear(embedding_dim, projection_dim) 24 | self.gelu = nn.GELU() 25 | self.fc = nn.Linear(projection_dim, projection_dim) 26 | self.dropout = nn.Dropout(dropout) 27 | 28 | def forward(self, x): 29 | projected = self.projection(x) 30 | x = self.gelu(projected) 31 | x = self.fc(x) 32 | x = self.dropout(x) 33 | x = x + projected 34 | return x 35 | 36 | class BrainTranslator(nn.Module): 37 | def __init__(self, bart, in_feature = 840, decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048): 38 | super(BrainTranslator, self).__init__() 39 | 40 | # Embedded EEG raw features 41 | self.hidden_dim=512 42 | self.feature_embedded = FeatureEmbedded(input_dim=104, hidden_dim=self.hidden_dim) 43 | self.fc = ProjectionHead(embedding_dim=in_feature,projection_dim=in_feature,dropout=0.1) #nn.Linear(in_feature, in_feature) 44 | 45 | # conv1d 46 | self.conv1d_point = nn.Conv1d(1, 64, 1, stride=1) 47 | 48 | SUBJECTS = ['ZAB', 'ZDM', 'ZDN', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW', 'ZMG', 'ZPH', 49 | 'YSD', 'YFS', 'YMD', 'YAC', 'YFR', 'YHS', 'YLS', 'YDG', 'YRH', 'YRK', 'YMS', 'YIS', 'YTL', 'YSL', 'YRP', 'YAG', 'YDR', 'YAK'] 50 | self.subjects_map_id = {} 51 | for i in range(len(SUBJECTS)): 52 | self.subjects_map_id[SUBJECTS[i]]=i 53 | 54 | # learnable subject matrices 55 
| self.subject_matrices = nn.ParameterList() # a ParameterList (not a plain Python list) registers the matrices, so the optimizer updates them and they are saved in the state_dict 56 | for i in range(len(SUBJECTS)): 57 | self.subject_matrices.append( nn.Parameter(torch.randn(64, 1)) ) 58 | 59 | # Brain transformer encoder 60 | self.pos_embedding = nn.Parameter(torch.randn(1, 56, in_feature)) 61 | self.encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead, dim_feedforward = additional_encoder_dim_feedforward, dropout=0.1, activation="gelu", batch_first=True) 62 | self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=12) 63 | self.layernorm_embedding = nn.LayerNorm(in_feature, eps=1e-05) 64 | 65 | self.brain_projection = ProjectionHead(embedding_dim=in_feature,projection_dim=1024,dropout=0.2) 66 | 67 | # BART 68 | self.bart = bart 69 | 70 | def freeze_pretrained_bart(self): 71 | for name, param in self.named_parameters(): 72 | param.requires_grad = True 73 | if ('bart' in name): 74 | param.requires_grad = False 75 | 76 | def freeze_pretrained_brain(self): 77 | for name, param in self.named_parameters(): 78 | param.requires_grad = False 79 | if ('bart' in name): 80 | param.requires_grad = True 81 | 82 | def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted, lenghts_words, word_contents_batch, word_contents_attn_batch, stepone, subject_batch, device, features=False): 83 | feature_embedding = self.feature_embedded(input_embeddings_batch, lenghts_words, device) 84 | if len(feature_embedding.shape)==2: 85 | feature_embedding = torch.unsqueeze(feature_embedding,0) 86 | encoded_embedding = self.fc(feature_embedding) 87 | 88 | # subject layer 89 | encoded_embedding_subject=[] 90 | for i in range(encoded_embedding.shape[0]): 91 | tmp = torch.unsqueeze(encoded_embedding[i,:,:],1) 92 | tmp = self.conv1d_point(tmp) 93 | tmp = torch.swapaxes(tmp, 1, 2) 94 | mat_subject = self.subject_matrices[self.subjects_map_id[subject_batch[i]]].to(device) 95 | tmp = torch.matmul(tmp, mat_subject) 96 | tmp = torch.squeeze(tmp) 97 |
encoded_embedding_subject.append(tmp) 98 | if len(encoded_embedding_subject) == 1: 99 | encoded_embedding_subject = torch.unsqueeze(encoded_embedding_subject[0],0) 100 | else: 101 | encoded_embedding_subject = torch.stack(encoded_embedding_subject,0).to(device) 102 | 103 | brain_embedding = encoded_embedding_subject + self.pos_embedding 104 | brain_embedding = self.encoder(brain_embedding, src_key_padding_mask = input_masks_invert) 105 | brain_embedding = self.layernorm_embedding(brain_embedding) 106 | 107 | brain_embedding = self.brain_projection(brain_embedding) 108 | 109 | if stepone==True: 110 | words_embedding = self.bart.model.encoder.embed_tokens(word_contents_batch) 111 | loss = nn.MSELoss() 112 | return loss(brain_embedding, words_embedding) 113 | else: 114 | out = self.bart(inputs_embeds = brain_embedding, attention_mask = input_masks_batch, return_dict = True, labels = target_ids_batch_converted) 115 | if features==True: 116 | return out.logits, brain_embedding 117 | 118 | return out.logits 119 | 120 | 121 | class FeatureEmbedded(nn.Module): 122 | def __init__(self, input_dim=105, hidden_dim=512, num_layers=2, is_bidirectional=True): 123 | super(FeatureEmbedded, self).__init__() 124 | 125 | self.input_dim=input_dim 126 | self.hidden_dim=hidden_dim 127 | self.num_layers=num_layers 128 | self.is_bidirectional=is_bidirectional 129 | 130 | self.lstm = nn.GRU(input_size=self.input_dim, 131 | hidden_size=self.hidden_dim, 132 | num_layers=self.num_layers, 133 | batch_first=True, 134 | dropout=0.2, 135 | bidirectional=self.is_bidirectional 136 | ) 137 | for name, param in self.lstm.named_parameters(): 138 | if 'bias' in name: 139 | nn.init.constant_(param, 0.0) 140 | elif 'weight_ih' in name: 141 | nn.init.kaiming_normal_(param) 142 | elif 'weight_hh' in name: 143 | nn.init.orthogonal_(param) 144 | 145 | def forward(self, x, lenghts, device): 146 | 147 | sentence_embedding_batch=[] 148 | for x_sentence, lenghts_sentence in zip(x,lenghts): 149 | 150 | lstm_input 
= pack_padded_sequence(x_sentence, lenghts_sentence.cpu().numpy(), batch_first=True, enforce_sorted=False) 151 | lstm_outs, hidden = self.lstm(lstm_input) 152 | lstm_outs, _ = nn.utils.rnn.pad_packed_sequence(lstm_outs) 153 | 154 | if not self.is_bidirectional: 155 | sentence_embedding = [] 156 | for i in range(lenghts_sentence.shape[0]): 157 | sentence_embedding.append(lstm_outs[int(lenghts_sentence[i]-1),i,:]) 158 | sentence_embedding=torch.stack(sentence_embedding,0) #lstm_outs[-1] 159 | else: 160 | sentence_embedding = [] 161 | for i in range(lenghts_sentence.shape[0]): 162 | sentence_embedding.append(lstm_outs[int(lenghts_sentence[i]-1),i,:]) 163 | sentence_embedding=torch.stack(sentence_embedding,0) 164 | 165 | sentence_embedding_batch.append(sentence_embedding) 166 | 167 | return torch.squeeze(torch.stack(sentence_embedding_batch,0)).to(device) -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamzaamrani/EEG-to-Text-Decoding/0b8cf03b040e37f9e573fced95ad47667b194557/overview.png -------------------------------------------------------------------------------- /scripts/eval_decoding_raw.sh: -------------------------------------------------------------------------------- 1 | python3 eval_decoding_raw.py \ 2 | --checkpoint_path ./checkpoints/decoding_raw_104_h/last/task1_task2_taskNRv2_finetune_BrainTranslator_skipstep1_b4_25_25_5e-05_5e-05_unique_sent_best.pt \ 3 | --config_path ./config/decoding_raw/task1_task2_taskNRv2_finetune_BrainTranslator_skipstep1_b4_25_25_5e-05_5e-05_unique_sent.json \ 4 | -cuda cuda:0 5 | 6 | -------------------------------------------------------------------------------- /scripts/prepare_dataset_raw.sh: -------------------------------------------------------------------------------- 1 | echo "This script constructs .pickle files from the ZuCo .mat files, including raw EEG
signals." 2 | echo "This script also generates the ternary sentiment_labels.json file for ZuCo task1-SR v1.0 and ternary_dataset.json from the filtered StanfordSentimentTreebank" 3 | echo "Note: the sentences in ZuCo task1-SR do not overlap with sentences in the filtered StanfordSentimentTreebank" 4 | echo "Note: this process can take a while, please be patient..." 5 | 6 | python3 ./util/construct_dataset_mat_to_pickle_v1_withRaw.py -t task1-SR 7 | python3 ./util/construct_dataset_mat_to_pickle_v1_withRaw.py -t task2-NR 8 | python3 ./util/construct_dataset_mat_to_pickle_v1_withRaw.py -t task3-TSR 9 | python3 ./util/construct_dataset_mat_to_pickle_v2_withRaw.py 10 | 11 | python3 ./util/get_sentiment_labels.py 12 | python3 ./util/get_SST_ternary_dataset.py 13 | -------------------------------------------------------------------------------- /scripts/train_decoding_raw.sh: -------------------------------------------------------------------------------- 1 | python3 train_decoding_raw.py --model_name BrainTranslator \ 2 | --task_name task1_task2_taskNRv2 \ 3 | --two_step \ 4 | --pretrained \ 5 | --not_load_step1_checkpoint \ 6 | --num_epoch_step1 25 \ 7 | --num_epoch_step2 25 \ 8 | -lr1 0.00005 \ 9 | -lr2 0.00005 \ 10 | -b 1 \ 11 | -s ./checkpoints/decoding_raw \ 12 | -cuda cuda:0 -------------------------------------------------------------------------------- /train_decoding_raw.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | from torch.optim import lr_scheduler 8 | from torch.utils.data import DataLoader 9 | import pickle 10 | import json 11 | from glob import glob 12 | import time 13 | from tqdm import tqdm 14 | from transformers import BartTokenizer, BartForConditionalGeneration, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification 15 | from
transformers import MBartForConditionalGeneration, MBart50TokenizerFast 16 | 17 | from data_raw import ZuCo_dataset 18 | from model_decoding_raw import BrainTranslator 19 | from config import get_config 20 | 21 | from torch.nn.utils.rnn import pad_sequence 22 | 23 | from nltk.translate.bleu_score import corpus_bleu 24 | from rouge import Rouge 25 | from bert_score import score 26 | 27 | import warnings 28 | warnings.filterwarnings('ignore') 29 | from transformers import logging 30 | logging.set_verbosity_error() 31 | torch.autograd.set_detect_anomaly(True) 32 | 33 | from torch.utils.tensorboard import SummaryWriter 34 | LOG_DIR = "runs_h" 35 | train_writer = SummaryWriter(os.path.join(LOG_DIR, "train")) 36 | val_writer = SummaryWriter(os.path.join(LOG_DIR, "train_full")) 37 | dev_writer = SummaryWriter(os.path.join(LOG_DIR, "dev_full")) 38 | 39 | 40 | SUBJECTS = ['ZAB', 'ZDM', 'ZDN', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW', 'ZMG', 'ZPH', 41 | 'YSD', 'YFS', 'YMD', 'YAC', 'YFR', 'YHS', 'YLS', 'YDG', 'YRH', 'YRK', 'YMS', 'YIS', 'YTL', 'YSL', 'YRP', 'YAG', 'YDR', 'YAK'] 42 | 43 | def train_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best='./checkpoints/decoding_raw/best/temp_decoding.pt', checkpoint_path_last='./checkpoints/decoding_raw/last/temp_decoding.pt', stepone=False): 44 | since = time.time() 45 | 46 | best_loss = float('inf') 47 | 48 | train_losses = [] 49 | val_losses = [] 50 | 51 | index_plot= 0 52 | index_plot_dev=0 53 | 54 | for epoch in range(num_epochs): 55 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 56 | print(f"lr: {scheduler.get_last_lr()}") 57 | print('-' * 10) 58 | 59 | # Each epoch has a training, a validation (dev) and a test phase 60 | for phase in ['train', 'dev', 'test']: 61 | if phase == 'train': 62 | model.train() # Set model to training mode 63 | else: 64 | model.eval() # Set model to evaluate mode 65 | 66 | running_loss = 0.0 67 | 68 | if phase == 'test': 69 | target_tokens_list = [] 70 |
target_string_list = [] 71 | pred_tokens_list = [] 72 | pred_string_list = [] 73 | 74 | # Iterate over data. 75 | with tqdm(dataloaders[phase], unit="batch") as tepoch: 76 | for batch_idx, (_, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG, input_raw_embeddings, input_raw_embeddings_lengths, word_contents, word_contents_attn, subject_batch) in enumerate(tepoch): 77 | 78 | # load in batch 79 | #input_embeddings_batch = torch.stack(input_embeddings_fbcsp).to(device).float() 80 | input_embeddings_batch = input_raw_embeddings.float().to(device) 81 | input_embeddings_lengths_batch = torch.stack([torch.tensor(a.clone().detach()) for a in input_raw_embeddings_lengths], 0).to(device) 82 | input_masks_batch = torch.stack(input_masks, 0).to(device) 83 | input_mask_invert_batch = torch.stack(input_mask_invert, 0).to(device) 84 | target_ids_batch = torch.stack(target_ids, 0).to(device) 85 | word_contents_batch = torch.stack(word_contents, 0).to(device) 86 | word_contents_attn_batch = torch.stack(word_contents_attn, 0).to(device) 87 | 88 | subject_batch = np.array(subject_batch) 89 | 90 | target_string_list_bertscore = [] 91 | 92 | if phase == 'test' and stepone==False: 93 | target_tokens = tokenizer.convert_ids_to_tokens( 94 | target_ids_batch[0].tolist(), skip_special_tokens=True) 95 | target_string = tokenizer.decode( 96 | target_ids_batch[0], skip_special_tokens=True) 97 | # add to list for later calculate metrics 98 | target_tokens_list.append([target_tokens]) 99 | target_string_list.append(target_string) 100 | 101 | # zero the parameter gradients 102 | optimizer.zero_grad() 103 | 104 | seq2seqLMoutput = model( 105 | input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch, input_embeddings_lengths_batch, word_contents_batch, word_contents_attn_batch, stepone, subject_batch, device) 106 | 107 | """replace padding ids in target_ids with -100""" 108 | target_ids_batch[target_ids_batch == 109 | 
tokenizer.pad_token_id] = -100 110 | 111 | """calculate loss""" 112 | if stepone==True: 113 | loss = seq2seqLMoutput 114 | else: 115 | loss = criterion(seq2seqLMoutput.permute(0, 2, 1), target_ids_batch.long()) 116 | 117 | if phase == 'test' and stepone==False: 118 | logits = seq2seqLMoutput 119 | probs = logits[0].softmax(dim=1) 120 | values, predictions = probs.topk(1) 121 | predictions = torch.squeeze(predictions) 122 | predicted_string = tokenizer.decode(predictions).split( 123 | tokenizer.eos_token)[0].replace(tokenizer.bos_token, '') # keep the text before the first </s> and drop <s> 124 | 125 | # convert to int list 126 | predictions = predictions.tolist() 127 | truncated_prediction = [] 128 | for t in predictions: 129 | if t != tokenizer.eos_token_id: 130 | truncated_prediction.append(t) 131 | else: 132 | break 133 | pred_tokens = tokenizer.convert_ids_to_tokens( 134 | truncated_prediction, skip_special_tokens=True) 135 | pred_tokens_list.append(pred_tokens) 136 | pred_string_list.append(predicted_string) 137 | 138 | # backward + optimize only if in training phase 139 | if phase == 'train': 140 | loss.backward() 141 | optimizer.step() 142 | 143 | # statistics 144 | running_loss += loss.item() * \ 145 | input_embeddings_batch.size()[0] # batch loss 146 | 147 | tepoch.set_postfix(loss=loss.item(), lr=scheduler.get_last_lr()) 148 | 149 | if phase == 'train': 150 | val_writer.add_scalar("train_full", loss.item(), index_plot) #(epoch+1)*batch_idx) 151 | index_plot=index_plot+1 152 | if phase == 'dev': 153 | dev_writer.add_scalar("dev_full", loss.item(), index_plot_dev) #(epoch+1)*batch_idx) 154 | index_plot_dev=index_plot_dev+1 155 | 156 | if phase == 'train': 157 | scheduler.step() 158 | 159 | epoch_loss = running_loss / dataset_sizes[phase] 160 | 161 | if phase == 'train': 162 | # keep the latest weights after every training epoch 163 | torch.save(model.state_dict(), checkpoint_path_last) 164 | 165 | 166 | 167 | if phase == 'train': 168 | train_losses.append(epoch_loss) 169 | train_epoch_loss = epoch_loss 170 |
train_writer.add_scalar("train", epoch_loss, epoch) 171 | elif phase == 'dev': 172 | val_losses.append(epoch_loss) 173 | train_writer.add_scalar("val", epoch_loss, epoch) 174 | 175 | train_writer.add_scalars('loss train/val', { 176 | 'train': train_epoch_loss, 177 | 'val': epoch_loss, 178 | }, epoch) 179 | 180 | print('{} Loss: {:.4f}'.format(phase, epoch_loss)) 181 | 182 | # save a checkpoint whenever the dev loss improves 183 | if phase == 'dev' and epoch_loss < best_loss: 184 | best_loss = epoch_loss 185 | '''save checkpoint''' 186 | torch.save(model.state_dict(), checkpoint_path_best) 187 | print(f'update best on dev checkpoint: {checkpoint_path_best}') 188 | 189 | if phase == 'test' and stepone==False: 190 | print("Evaluation on test") 191 | try: 192 | """ calculate corpus bleu score """ 193 | weights_list = [ 194 | (1.0,), (0.5, 0.5), (1./3., 1./3., 1./3.), (0.25, 0.25, 0.25, 0.25)] 195 | for weight in weights_list: 196 | # print('weight:',weight) 197 | corpus_bleu_score = corpus_bleu( 198 | target_tokens_list, pred_tokens_list, weights=weight) 199 | print( 200 | f'corpus BLEU-{len(list(weight))} score:', corpus_bleu_score) 201 | 202 | """ calculate rouge score """ 203 | rouge = Rouge() 204 | rouge_scores = rouge.get_scores( pred_string_list, target_string_list, avg=True, ignore_empty=True) 205 | print(rouge_scores) 206 | """ calculate bertscore""" 207 | P, R, F1 = score(pred_string_list,target_string_list, lang='en', device="cuda:0", model_type="bert-large-uncased") 208 | print(f"bert_score P: {np.mean(np.array(P))}") 209 | print(f"bert_score R: {np.mean(np.array(R))}") 210 | print(f"bert_score F1: {np.mean(np.array(F1))}") 211 | except Exception as e: 212 | print(f"test metric computation failed: {e}") 213 | 214 | print() 215 | 216 | print(f"Train losses: {train_losses}") 217 | print(f"Val losses: {val_losses}") 218 | 219 | time_elapsed = time.time() - since 220 | print('Training complete in {:.0f}m {:.0f}s'.format( 221 | time_elapsed // 60, time_elapsed % 60)) 222 | print('Best val loss: {:.4f}'.format(best_loss)) 223 |
torch.save(model.state_dict(), checkpoint_path_last) 224 | print(f'update last checkpoint: {checkpoint_path_last}') 225 | 226 | return model 227 | 228 | 229 | def show_require_grad_layers(model): 230 | print() 231 | print(' require_grad layers:') 232 | # sanity check 233 | for name, param in model.named_parameters(): 234 | if param.requires_grad: 235 | print(' ', name) 236 | 237 | 238 | if __name__ == '__main__': 239 | args = get_config('train_decoding') 240 | 241 | ''' config param''' 242 | dataset_setting = 'unique_sent' 243 | 244 | num_epochs_step1 = args['num_epoch_step1'] 245 | num_epochs_step2 = args['num_epoch_step2'] 246 | step1_lr = args['learning_rate_step1'] 247 | step2_lr = args['learning_rate_step2'] 248 | 249 | batch_size = args['batch_size'] 250 | 251 | model_name = args['model_name'] 252 | task_name = args['task_name'] 253 | 254 | save_path = args['save_path'] 255 | 256 | skip_step_one = args['skip_step_one'] 257 | load_step1_checkpoint = args['load_step1_checkpoint'] 258 | use_random_init = False 259 | 260 | if use_random_init and skip_step_one: 261 | step2_lr = 5*1e-4 262 | 263 | print(f'[INFO]using model: {model_name}') 264 | print(f'[INFO]using use_random_init: {use_random_init}') 265 | 266 | if skip_step_one: 267 | save_name = f'{task_name}_finetune_{model_name}_skipstep1_b{batch_size}_{num_epochs_step1}_{num_epochs_step2}_{step1_lr}_{step2_lr}_{dataset_setting}' 268 | else: 269 | save_name = f'{task_name}_finetune_{model_name}_2steptraining_b{batch_size}_{num_epochs_step1}_{num_epochs_step2}_{step1_lr}_{step2_lr}_{dataset_setting}' 270 | 271 | if use_random_init: 272 | save_name = 'randinit_' + save_name 273 | 274 | output_checkpoint_name_best = save_path + f'/best/{save_name}.pt' 275 | output_checkpoint_name_last = save_path + f'/last/{save_name}.pt' 276 | 277 | subject_choice = args['subjects'] 278 | print(f'![Debug]using {subject_choice}') 279 | eeg_type_choice = args['eeg_type'] 280 | print(f'[INFO]eeg type {eeg_type_choice}') 281 | 
bands_choice = args['eeg_bands'] 282 | print(f'[INFO]using bands {bands_choice}') 283 | 284 | ''' set random seeds ''' 285 | seed_val = 312 286 | np.random.seed(seed_val) 287 | torch.manual_seed(seed_val) 288 | torch.cuda.manual_seed_all(seed_val) 289 | 290 | ''' set up device ''' 291 | # use cuda 292 | if torch.cuda.is_available(): 293 | # dev = "cuda:3" 294 | dev = args['cuda'] 295 | else: 296 | dev = "cpu" 297 | device = torch.device(dev) 298 | print(f'[INFO]using device {dev}') 299 | print() 300 | 301 | ''' set up dataloader ''' 302 | whole_dataset_dicts = [] 303 | if 'task1' in task_name: 304 | dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset_wRaw.pickle' 305 | with open(dataset_path_task1, 'rb') as handle: 306 | whole_dataset_dicts.append(pickle.load(handle)) 307 | if 'task2' in task_name: 308 | dataset_path_task2 = './dataset/ZuCo/task2-NR/pickle/task2-NR-dataset_wRaw.pickle' 309 | with open(dataset_path_task2, 'rb') as handle: 310 | whole_dataset_dicts.append(pickle.load(handle)) 311 | if 'task3' in task_name: 312 | dataset_path_task3 = './dataset/ZuCo/task3-TSR/pickle/task3-TSR-dataset_wRaw.pickle' 313 | with open(dataset_path_task3, 'rb') as handle: 314 | whole_dataset_dicts.append(pickle.load(handle)) 315 | if 'taskNRv2' in task_name: 316 | dataset_path_taskNRv2 = './dataset/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset_wRaw.pickle' 317 | with open(dataset_path_taskNRv2, 'rb') as handle: 318 | whole_dataset_dicts.append(pickle.load(handle)) 319 | 320 | print() 321 | 322 | """save config""" 323 | with open(f'./config/decoding_raw/{save_name}.json', 'w') as out_config: 324 | json.dump(args, out_config, indent=4) 325 | 326 | if model_name in ['BrainTranslator', 'BrainTranslatorNaive']: 327 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') 328 | 329 | # train dataset 330 | train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject=subject_choice, 331 | eeg_type=eeg_type_choice, bands=bands_choice, 
setting=dataset_setting, raweeg=True) 332 | # dev dataset 333 | dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject=subject_choice, 334 | eeg_type=eeg_type_choice, bands=bands_choice, setting=dataset_setting, raweeg=True) 335 | # test dataset 336 | test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject=subject_choice, 337 | eeg_type=eeg_type_choice, bands=bands_choice, setting=dataset_setting, raweeg=True) 338 | 339 | dataset_sizes = {'train': len(train_set), 'dev': len( 340 | dev_set), 'test': len(test_set)} 341 | print('[INFO]train_set size: ', len(train_set)) 342 | print('[INFO]dev_set size: ', len(dev_set)) 343 | print('[INFO]test_set size: ', len(test_set)) 344 | 345 | # Collate function: zero-pads the variable-length raw EEG and keeps the true lengths 346 | def pad_and_sort_batch(data_loader_batch): 347 | """ 348 | data_loader_batch is a list of per-sample tuples as returned by ZuCo_dataset. 349 | Returns the batch fields with the raw EEG padded into one tensor, plus the per-word raw EEG lengths (samples are not sorted). 350 | """ 351 | input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG, input_raw_embeddings, word_contents, word_contents_attn, subject = tuple( 352 | zip(*data_loader_batch)) 353 | 354 | raw_eeg = [] 355 | input_raw_embeddings_lenghts = [] 356 | for sentence in input_raw_embeddings: 357 | input_raw_embeddings_lenghts.append( 358 | torch.Tensor([a.size(0) for a in sentence])) 359 | raw_eeg.append(pad_sequence( 360 | sentence, batch_first=True, padding_value=0).permute(1, 0, 2)) 361 | 362 | input_raw_embeddings = pad_sequence( 363 | raw_eeg, batch_first=True, padding_value=0).permute(0, 2, 1, 3) 364 | 365 | return input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG, input_raw_embeddings, input_raw_embeddings_lenghts, word_contents, word_contents_attn, subject # lengths 366 | 367 | 368 | # train dataloader 369 | train_dataloader = DataLoader( 370 | train_set,
batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=pad_and_sort_batch) # 4 371 | # dev dataloader 372 | val_dataloader = DataLoader( 373 | dev_set, batch_size=1, shuffle=False, num_workers=0, collate_fn=pad_and_sort_batch) # 4 374 | # test dataloader 375 | test_dataloader = DataLoader( 376 | test_set, batch_size=1, shuffle=False, num_workers=0, collate_fn=pad_and_sort_batch) # 4 377 | # dataloaders 378 | dataloaders = {'train': train_dataloader, 379 | 'dev': val_dataloader, 'test': test_dataloader} 380 | 381 | ''' set up model ''' 382 | if model_name == 'BrainTranslator': 383 | pretrained = BartForConditionalGeneration.from_pretrained( 384 | 'facebook/bart-large') 385 | 386 | model = BrainTranslator(pretrained, in_feature=1024, decoder_embedding_size=1024, 387 | additional_encoder_nhead=8, additional_encoder_dim_feedforward=4096) 388 | 389 | model.to(device) 390 | 391 | ''' training loop ''' 392 | 393 | ###################################################### 394 | '''step one training''' 395 | ###################################################### 396 | 397 | # closely follow BART paper 398 | if model_name in ['BrainTranslator']: 399 | for name, param in model.named_parameters(): 400 | if param.requires_grad and 'pretrained' in name: 401 | if ('shared' in name) or ('embed_positions' in name) or ('encoder.layers.0' in name): 402 | continue 403 | else: 404 | param.requires_grad = False 405 | 406 | if skip_step_one: 407 | if load_step1_checkpoint: 408 | stepone_checkpoint = 'path_to_step_1_checkpoint.pt' 409 | print(f'skip step one, load checkpoint: {stepone_checkpoint}') 410 | model.load_state_dict(torch.load(stepone_checkpoint)) 411 | else: 412 | print('skip step one, start from scratch at step two') 413 | else: 414 | model.to(device) 415 | 416 | ''' set up optimizer and scheduler''' 417 | optimizer_step1 = optim.SGD(filter( 418 | lambda p: p.requires_grad, model.parameters()), lr=step1_lr, momentum=0.9) 419 | 420 | exp_lr_scheduler_step1 =
lr_scheduler.CyclicLR(optimizer_step1, 421 | base_lr = step1_lr, # Initial learning rate which is the lower boundary in the cycle for each parameter group 422 | max_lr = 5e-3, # Upper learning rate boundaries in the cycle for each parameter group 423 | mode = "triangular2") #triangular2 424 | 425 | ''' set up loss function ''' 426 | criterion = nn.MSELoss() 427 | model.freeze_pretrained_bart() 428 | 429 | print('=== start Step1 training ... ===') 430 | # print training layers 431 | show_require_grad_layers(model) 432 | # run step 1 (the best-dev weights are saved to disk; the returned model holds the last epoch's weights) 433 | model = train_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epochs_step1, 434 | checkpoint_path_best=output_checkpoint_name_best, checkpoint_path_last=output_checkpoint_name_last, stepone=True) 435 | 436 | train_writer.flush() 437 | train_writer.close() 438 | val_writer.flush() 439 | val_writer.close() 440 | dev_writer.flush() 441 | dev_writer.close() 442 | 443 | ###################################################### 444 | '''step two training''' 445 | ###################################################### 446 | 447 | #model.load_state_dict(torch.load("./checkpoints/decoding_raw_104_h/last/task1_task2_taskNRv2_finetune_BrainTranslator_skipstep1_b4_25_100_5e-05_5e-05_unique_sent.pt")) 448 | model.freeze_pretrained_brain() 449 | 450 | ''' set up optimizer and scheduler''' 451 | optimizer_step2 = optim.SGD(filter( 452 | lambda p: p.requires_grad, model.parameters()), lr=step2_lr, momentum=0.9) 453 | 454 | exp_lr_scheduler_step2 = lr_scheduler.CyclicLR(optimizer_step2, 455 | base_lr = 0.0000005, # Initial learning rate which is the lower boundary in the cycle for each parameter group 456 | max_lr = 0.00005, # Upper learning rate boundaries in the cycle for each parameter group 457 | mode = "triangular2") 458 | 459 | ''' set up loss function ''' 460 | criterion = nn.CrossEntropyLoss() 461 | 462 | print() 463 | print('=== start Step2
training ... ===') 464 | # print training layers 465 | show_require_grad_layers(model) 466 | 467 | model.to(device) 468 | 469 | '''main loop''' 470 | trained_model = train_model(dataloaders, device, model, criterion, optimizer_step2, exp_lr_scheduler_step2, num_epochs=num_epochs_step2, 471 | checkpoint_path_best=output_checkpoint_name_best, checkpoint_path_last=output_checkpoint_name_last, stepone=False) 472 | 473 | train_writer.flush() 474 | train_writer.close() 475 | val_writer.flush() 476 | val_writer.close() 477 | dev_writer.flush() 478 | dev_writer.close() 479 | 480 | -------------------------------------------------------------------------------- /util/__pycache__/data_loading_helpers_modified.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hamzaamrani/EEG-to-Text-Decoding/0b8cf03b040e37f9e573fced95ad47667b194557/util/__pycache__/data_loading_helpers_modified.cpython-38.pyc -------------------------------------------------------------------------------- /util/construct_dataset_mat_to_pickle_v1_withRaw.py: -------------------------------------------------------------------------------- 1 | import scipy.io as io 2 | import h5py 3 | import os 4 | import json 5 | from glob import glob 6 | from tqdm import tqdm 7 | import numpy as np 8 | import pickle 9 | import argparse 10 | 11 | parser = argparse.ArgumentParser(description='Specify task name for converting ZuCo v1.0 Mat file to Pickle') 12 | parser.add_argument('-t', '--task_name', help='name of the task in /dataset/ZuCo, choose from {task1-SR,task2-NR,task3-TSR}', required=True) 13 | args = vars(parser.parse_args()) 14 | 15 | 16 | """config""" 17 | version = 'v1' # 'old' 18 | # version = 'v2' # 'new' 19 | 20 | task_name = args['task_name'] 21 | # task_name = 'task1-SR' 22 | # task_name = 'task2-NR' 23 | # task_name = 'task3-TSR' 24 | 25 | 26 | print('##############################') 27 | print(f'start processing ZuCo 
{task_name}...') 28 | 29 | 30 | if version == 'v1': 31 | # old version 32 | input_mat_files_dir = f'./dataset/ZuCo/{task_name}/Matlab_files' 33 | elif version == 'v2': 34 | # new version, mat73 35 | input_mat_files_dir = f'./dataset/ZuCo/{task_name}/Matlab_files' 36 | 37 | output_dir = f'./dataset/ZuCo/{task_name}/pickle' 38 | if not os.path.exists(output_dir): 39 | os.makedirs(output_dir) 40 | 41 | """load files""" 42 | mat_files = glob(os.path.join(input_mat_files_dir,'*.mat')) 43 | mat_files = sorted(mat_files) 44 | 45 | if len(mat_files) == 0: 46 | print(f'No mat files found for {task_name}') 47 | quit() 48 | 49 | dataset_dict = {} 50 | for mat_file in tqdm(mat_files): 51 | subject_name = os.path.basename(mat_file).split('_')[0].replace('results','').strip() 52 | dataset_dict[subject_name] = [] 53 | 54 | if version == 'v1': 55 | matdata = io.loadmat(mat_file, squeeze_me=True, struct_as_record=False)['sentenceData'] 56 | elif version == 'v2': 57 | matdata = h5py.File(mat_file,'r') 58 | print(matdata) 59 | 60 | for e,sent in enumerate(matdata): 61 | word_data = sent.word 62 | if not isinstance(word_data, float): 63 | # sentence level: 64 | sent_obj = {'content':sent.content} 65 | sent_obj['sentence_level_EEG'] = {'mean_t1':sent.mean_t1, 'mean_t2':sent.mean_t2, 'mean_a1':sent.mean_a1, 'mean_a2':sent.mean_a2, 'mean_b1':sent.mean_b1, 'mean_b2':sent.mean_b2, 'mean_g1':sent.mean_g1, 'mean_g2':sent.mean_g2} 66 | 67 | if task_name == 'task1-SR': 68 | sent_obj['answer_EEG'] = {'answer_mean_t1':sent.answer_mean_t1, 'answer_mean_t2':sent.answer_mean_t2, 'answer_mean_a1':sent.answer_mean_a1, 'answer_mean_a2':sent.answer_mean_a2, 'answer_mean_b1':sent.answer_mean_b1, 'answer_mean_b2':sent.answer_mean_b2, 'answer_mean_g1':sent.answer_mean_g1, 'answer_mean_g2':sent.answer_mean_g2} 69 | 70 | # word level: 71 | sent_obj['word'] = [] 72 | 73 | word_tokens_has_fixation = [] 74 | word_tokens_with_mask = [] 75 | word_tokens_all = [] 76 | print(sent.content) 77 | for word in 
word_data: 78 | word_obj = {'content':word.content} 79 | word_tokens_all.append(word.content) 80 | # TODO: add more version of word level eeg: GD, SFD, GPT 81 | word_obj['nFixations'] = word.nFixations 82 | if word.nFixations > 0: 83 | word_obj['word_level_EEG'] = {'FFD':{'FFD_t1':word.FFD_t1, 'FFD_t2':word.FFD_t2, 'FFD_a1':word.FFD_a1, 'FFD_a2':word.FFD_a2, 'FFD_b1':word.FFD_b1, 'FFD_b2':word.FFD_b2, 'FFD_g1':word.FFD_g1, 'FFD_g2':word.FFD_g2}} 84 | word_obj['word_level_EEG']['TRT'] = {'TRT_t1':word.TRT_t1, 'TRT_t2':word.TRT_t2, 'TRT_a1':word.TRT_a1, 'TRT_a2':word.TRT_a2, 'TRT_b1':word.TRT_b1, 'TRT_b2':word.TRT_b2, 'TRT_g1':word.TRT_g1, 'TRT_g2':word.TRT_g2} 85 | word_obj['word_level_EEG']['GD'] = {'GD_t1':word.GD_t1, 'GD_t2':word.GD_t2, 'GD_a1':word.GD_a1, 'GD_a2':word.GD_a2, 'GD_b1':word.GD_b1, 'GD_b2':word.GD_b2, 'GD_g1':word.GD_g1, 'GD_g2':word.GD_g2} 86 | sent_obj['word'].append(word_obj) 87 | word_tokens_has_fixation.append(word.content) 88 | word_tokens_with_mask.append(word.content) 89 | #print(word.rawEEG) 90 | if type(word.rawEEG) == np.ndarray: 91 | """print(e) 92 | print(f"content: {word.content}") 93 | print(f"shape: {word.rawEEG.shape}") 94 | print(f"size: {word.rawEEG.size}") 95 | print(f"ndim: {word.rawEEG.ndim}") 96 | print("nan:",np.isnan(word.rawEEG[0]).any())""" 97 | if not np.isnan(word.rawEEG[0]).any(): 98 | if word.rawEEG[0].ndim == 2: 99 | 100 | word_obj['rawEEG']=[np.float32(arr.transpose()) for arr in word.rawEEG if type(arr)==np.ndarray] 101 | elif word.rawEEG[0].ndim == 1: 102 | 103 | word_obj['rawEEG']=[word.rawEEG.transpose()] 104 | else: 105 | word_obj['rawEEG']=[] 106 | else: 107 | word_obj['rawEEG']=[] 108 | """ print("len:",len(word_obj['rawEEG'])) 109 | print("\n")""" 110 | 111 | #print(word_obj['rawEEG']) 112 | else: 113 | word_tokens_with_mask.append('[MASK]') 114 | # if a word has no fixation, use sentence level feature 115 | # word_obj['word_level_EEG'] = {'FFD':{'FFD_t1':sent.mean_t1, 'FFD_t2':sent.mean_t2, 
'FFD_a1':sent.mean_a1, 'FFD_a2':sent.mean_a2, 'FFD_b1':sent.mean_b1, 'FFD_b2':sent.mean_b2, 'FFD_g1':sent.mean_g1, 'FFD_g2':sent.mean_g2}} 116 | # word_obj['word_level_EEG']['TRT'] = {'TRT_t1':sent.mean_t1, 'TRT_t2':sent.mean_t2, 'TRT_a1':sent.mean_a1, 'TRT_a2':sent.mean_a2, 'TRT_b1':sent.mean_b1, 'TRT_b2':sent.mean_b2, 'TRT_g1':sent.mean_g1, 'TRT_g2':sent.mean_g2} 117 | 118 | # NOTE: if a word has no fixation, simply skip it 119 | continue 120 | 121 | sent_obj['word_tokens_has_fixation'] = word_tokens_has_fixation 122 | sent_obj['word_tokens_with_mask'] = word_tokens_with_mask 123 | sent_obj['word_tokens_all'] = word_tokens_all 124 | 125 | dataset_dict[subject_name].append(sent_obj) 126 | 127 | else: 128 | print(f'missing sent: subj:{subject_name} content:{sent.content}, append None') 129 | dataset_dict[subject_name].append(None) 130 | 131 | continue 132 | # print(dataset_dict.keys()) 133 | # print(dataset_dict[subject_name][0].keys()) 134 | # print(dataset_dict[subject_name][0]['content']) 135 | # print(dataset_dict[subject_name][0]['word'][0].keys()) 136 | # print(dataset_dict[subject_name][0]['word'][0]['word_level_EEG']['FFD']) 137 | 138 | """output""" 139 | output_name = f'{task_name}-dataset_wRaw.pickle' 140 | # with open(os.path.join(output_dir,'task1-SR-dataset.json'), 'w') as out: 141 | # json.dump(dataset_dict,out,indent = 4) 142 | 143 | with open(os.path.join(output_dir,output_name), 'wb') as handle: 144 | pickle.dump(dataset_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) 145 | print('write to:', os.path.join(output_dir,output_name)) 146 | 147 | 148 | """sanity check""" 149 | # check dataset 150 | with open(os.path.join(output_dir,output_name), 'rb') as handle: 151 | whole_dataset = pickle.load(handle) 152 | print('subjects:', whole_dataset.keys()) 153 | 154 | if version == 'v1': 155 | print('num of sent:', len(whole_dataset['ZAB'])) 156 | print() 157 | 158 | 159 | --------------------------------------------------------------------------------
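The v1 converter above builds three token lists per sentence. Every word goes into `word_tokens_all`. Fixated words (`nFixations > 0`) go into both `word_tokens_has_fixation` and `word_tokens_with_mask`. Unfixated words are replaced by `'[MASK]'` in the masked list. The pure-Python sketch below isolates that rule. The `(content, n_fixations)` input pairs and the name `build_token_lists` are hypothetical and chosen only for illustration. For comparison, the v2 helper `extract_word_level_data` applies the same split but tests `nFix is not None` instead of `nFixations > 0`.

```python
def build_token_lists(words):
    """Return (tokens_all, tokens_has_fixation, tokens_with_mask) for one sentence.

    words: list of (content, n_fixations) pairs (hypothetical input format).
    """
    tokens_all, tokens_has_fixation, tokens_with_mask = [], [], []
    for content, n_fixations in words:
        tokens_all.append(content)
        if n_fixations > 0:
            # fixated words keep their text in both filtered lists
            tokens_has_fixation.append(content)
            tokens_with_mask.append(content)
        else:
            # unfixated words have no word-level EEG, so they are masked
            tokens_with_mask.append('[MASK]')
    return tokens_all, tokens_has_fixation, tokens_with_mask


tokens_all, tokens_fixated, tokens_masked = build_token_lists(
    [("The", 1), ("cat", 0), ("sat", 2)])
```

Only the masked list preserves sentence length, so it is the one that stays aligned word-by-word with `word_tokens_all`.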
/util/construct_dataset_mat_to_pickle_v2_withRaw.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import h5py 4 | import data_loading_helpers_modified as dh 5 | from glob import glob 6 | from tqdm import tqdm 7 | import pickle 8 | 9 | print("modified version of construct_dataset_mat_to_pickle_v2.py to include raw EEG data, fixed") 10 | 11 | task = "NR" 12 | 13 | rootdir = "./dataset/ZuCo/task2-NR-2.0/Matlab_files/" 14 | 15 | print('##############################') 16 | print(f'start processing ZuCo task2-NR-2.0...') 17 | 18 | dataset_dict = {} 19 | 20 | for file in tqdm(os.listdir(rootdir)): 21 | if file.endswith(task+".mat"): 22 | 23 | file_name = rootdir + file 24 | 25 | # print('file name:', file_name) 26 | subject = file_name.split("ts")[1].split("_")[0] 27 | # print('subject: ', subject) 28 | 29 | # exclude YMH due to incomplete data because of dyslexia 30 | if subject != 'YMH': 31 | assert subject not in dataset_dict 32 | dataset_dict[subject] = [] 33 | 34 | f = h5py.File(file_name,'r') 35 | # print('keys in f:', list(f.keys())) 36 | sentence_data = f['sentenceData'] 37 | # print('keys in sentence_data:', list(sentence_data.keys())) 38 | 39 | # sent level eeg 40 | # mean_t1 = np.squeeze(f[sentence_data['mean_t1'][0][0]][()]) 41 | mean_t1_objs = sentence_data['mean_t1'] 42 | mean_t2_objs = sentence_data['mean_t2'] 43 | mean_a1_objs = sentence_data['mean_a1'] 44 | mean_a2_objs = sentence_data['mean_a2'] 45 | mean_b1_objs = sentence_data['mean_b1'] 46 | mean_b2_objs = sentence_data['mean_b2'] 47 | mean_g1_objs = sentence_data['mean_g1'] 48 | mean_g2_objs = sentence_data['mean_g2'] 49 | 50 | rawData = sentence_data['rawData'] 51 | contentData = sentence_data['content'] 52 | # print('contentData shape:', contentData.shape, 'dtype:', contentData.dtype) 53 | omissionR = sentence_data['omissionRate'] 54 | wordData = sentence_data['word'] 55 | 56 | 57 | for idx in range(len(rawData)): 58 | # 
get sentence string 59 | obj_reference_content = contentData[idx][0] 60 | sent_string = dh.load_matlab_string(f[obj_reference_content]) 61 | # print('sentence string:', sent_string) 62 | 63 | sent_obj = {'content':sent_string} 64 | 65 | # get sentence level EEG 66 | sent_obj['sentence_level_EEG'] = { 67 | 'mean_t1':np.squeeze(f[mean_t1_objs[idx][0]][()]), 68 | 'mean_t2':np.squeeze(f[mean_t2_objs[idx][0]][()]), 69 | 'mean_a1':np.squeeze(f[mean_a1_objs[idx][0]][()]), 70 | 'mean_a2':np.squeeze(f[mean_a2_objs[idx][0]][()]), 71 | 'mean_b1':np.squeeze(f[mean_b1_objs[idx][0]][()]), 72 | 'mean_b2':np.squeeze(f[mean_b2_objs[idx][0]][()]), 73 | 'mean_g1':np.squeeze(f[mean_g1_objs[idx][0]][()]), 74 | 'mean_g2':np.squeeze(f[mean_g2_objs[idx][0]][()]) 75 | } 76 | # print(sent_obj) 77 | sent_obj['word'] = [] 78 | 79 | # get word level data 80 | word_data, word_tokens_all, word_tokens_has_fixation, word_tokens_with_mask = dh.extract_word_level_data(f, f[wordData[idx][0]]) 81 | 82 | if word_data == {}: 83 | print(f'missing sent: subj:{subject} content:{sent_string}, append None') 84 | dataset_dict[subject].append(None) 85 | continue 86 | elif len(word_tokens_all) == 0: 87 | print(f'no word level features: subj:{subject} content:{sent_string}, append None') 88 | dataset_dict[subject].append(None) 89 | continue 90 | 91 | else: 92 | for widx in range(len(word_data)): 93 | data_dict = word_data[widx] 94 | word_obj = {'content':data_dict['content'], 'nFixations': data_dict['nFix'], 'rawEEG':data_dict['RAW_EEG']} #@zavidos 95 | """if 'rawEEG' in data_dict: 96 | word_obj['rawEEG'] = data_dict['rawEEG'] 97 | else: 98 | word_obj['rawEEG'] = None""" 99 | if 'GD_EEG' in data_dict: 100 | # print('has fixation: ', data_dict['content']) 101 | gd = data_dict["GD_EEG"] 102 | ffd = data_dict["FFD_EEG"] 103 | trt = data_dict["TRT_EEG"] 104 | assert len(gd) == len(trt) == len(ffd) == 8 105 | word_obj['word_level_EEG'] = { 106 | 'GD':{'GD_t1':gd[0], 'GD_t2':gd[1], 'GD_a1':gd[2], 'GD_a2':gd[3], 
'GD_b1':gd[4], 'GD_b2':gd[5], 'GD_g1':gd[6], 'GD_g2':gd[7]}, 107 | 'FFD':{'FFD_t1':ffd[0], 'FFD_t2':ffd[1], 'FFD_a1':ffd[2], 'FFD_a2':ffd[3], 'FFD_b1':ffd[4], 'FFD_b2':ffd[5], 'FFD_g1':ffd[6], 'FFD_g2':ffd[7]}, 108 | 'TRT':{'TRT_t1':trt[0], 'TRT_t2':trt[1], 'TRT_a1':trt[2], 'TRT_a2':trt[3], 'TRT_b1':trt[4], 'TRT_b2':trt[5], 'TRT_g1':trt[6], 'TRT_g2':trt[7]} 109 | } 110 | sent_obj['word'].append(word_obj) 111 | 112 | sent_obj['word_tokens_has_fixation'] = word_tokens_has_fixation 113 | sent_obj['word_tokens_with_mask'] = word_tokens_with_mask 114 | sent_obj['word_tokens_all'] = word_tokens_all 115 | 116 | # print(sent_obj.keys()) 117 | # print(len(sent_obj['word'])) 118 | # print(sent_obj['word'][0]) 119 | 120 | dataset_dict[subject].append(sent_obj) 121 | 122 | """output""" 123 | task_name = 'task2-NR-2.0' 124 | 125 | if dataset_dict == {}: 126 | print(f'No mat file found for {task_name}') 127 | quit() 128 | 129 | output_dir = f'./dataset/ZuCo/{task_name}/pickle' 130 | output_name = f'{task_name}-dataset_wRaw.pickle' 131 | # with open(os.path.join(output_dir,'task1-SR-dataset.json'), 'w') as out: 132 | # json.dump(dataset_dict,out,indent = 4) 133 | 134 | with open(os.path.join(output_dir,output_name), 'wb') as handle: 135 | pickle.dump(dataset_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) 136 | print('write to:', os.path.join(output_dir,output_name)) 137 | 138 | """sanity check""" 139 | print('subjects:', dataset_dict.keys()) 140 | print('num of sent:', len(dataset_dict['YAC'])) -------------------------------------------------------------------------------- /util/data_loading_helpers_modified.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | 4 | eeg_float_resolution=np.float16 5 | 6 | Alpha_ffd_names = ['FFD_a1', 'FFD_a1_diff', 'FFD_a2', 'FFD_a2_diff'] 7 | Beta_ffd_names = ['FFD_b1', 'FFD_b1_diff', 'FFD_b2', 'FFD_b2_diff'] 8 | Gamma_ffd_names = ['FFD_g1', 'FFD_g1_diff', 'FFD_g2', 
'FFD_g2_diff'] 9 | Theta_ffd_names = ['FFD_t1', 'FFD_t1_diff', 'FFD_t2', 'FFD_t2_diff'] 10 | Alpha_gd_names = ['GD_a1', 'GD_a1_diff', 'GD_a2', 'GD_a2_diff'] 11 | Beta_gd_names = ['GD_b1', 'GD_b1_diff', 'GD_b2', 'GD_b2_diff'] 12 | Gamma_gd_names = ['GD_g1', 'GD_g1_diff', 'GD_g2', 'GD_g2_diff'] 13 | Theta_gd_names = ['GD_t1', 'GD_t1_diff', 'GD_t2', 'GD_t2_diff'] 14 | Alpha_gpt_names = ['GPT_a1', 'GPT_a1_diff', 'GPT_a2', 'GPT_a2_diff'] 15 | Beta_gpt_names = ['GPT_b1', 'GPT_b1_diff', 'GPT_b2', 'GPT_b2_diff'] 16 | Gamma_gpt_names = ['GPT_g1', 'GPT_g1_diff', 'GPT_g2', 'GPT_g2_diff'] 17 | Theta_gpt_names = ['GPT_t1', 'GPT_t1_diff', 'GPT_t2', 'GPT_t2_diff'] 18 | Alpha_sfd_names = ['SFD_a1', 'SFD_a1_diff', 'SFD_a2', 'SFD_a2_diff'] 19 | Beta_sfd_names = ['SFD_b1', 'SFD_b1_diff', 'SFD_b2', 'SFD_b2_diff'] 20 | Gamma_sfd_names = ['SFD_g1', 'SFD_g1_diff', 'SFD_g2', 'SFD_g2_diff'] 21 | Theta_sfd_names = ['SFD_t1', 'SFD_t1_diff', 'SFD_t2', 'SFD_t2_diff'] 22 | Alpha_trt_names = ['TRT_a1', 'TRT_a1_diff', 'TRT_a2', 'TRT_a2_diff'] 23 | Beta_trt_names = ['TRT_b1', 'TRT_b1_diff', 'TRT_b2', 'TRT_b2_diff'] 24 | Gamma_trt_names = ['TRT_g1', 'TRT_g1_diff', 'TRT_g2', 'TRT_g2_diff'] 25 | Theta_trt_names = ['TRT_t1', 'TRT_t1_diff', 'TRT_t2', 'TRT_t2_diff'] 26 | 27 | # IF YOU CHANGE THOSE YOU MUST ALSO CHANGE CONSTANTS 28 | Alpha_features = Alpha_ffd_names + Alpha_gd_names + Alpha_gpt_names + Alpha_trt_names# + Alpha_sfd_names 29 | Beta_features = Beta_ffd_names + Beta_gd_names + Beta_gpt_names + Beta_trt_names# + Beta_sfd_names 30 | Gamma_features = Gamma_ffd_names + Gamma_gd_names + Gamma_gpt_names + Gamma_trt_names# + Gamma_sfd_names 31 | Theta_features = Theta_ffd_names + Theta_gd_names + Theta_gpt_names + Theta_trt_names# + Theta_sfd_names 32 | # print(Alpha_features) 33 | 34 | # GD_EEG_features 35 | 36 | 37 | def extract_all_fixations(data_container, word_data_object, float_resolution = np.float16): 38 | """ 39 | Extracts all fixations from a word data object 40 | :param data_container:
(h5py) Container of the whole data, h5py object 41 | :param word_data_object: (h5py) Container of fixation objects, h5py object 42 | :param float_resolution: (type) Resolution to which data are to be converted, used for data compression 43 | :return: 44 | fixations_data (list) Data arrays representing each fixation 45 | """ 46 | word_data = data_container[word_data_object] 47 | fixations_data = [] 48 | if len(word_data.shape) > 1: 49 | for fixation_idx in range(word_data.shape[0]): 50 | fixations_data.append(np.array(data_container[word_data[fixation_idx][0]]).astype(float_resolution)) 51 | return fixations_data 52 | 53 | 54 | def is_real_word(word): 55 | """ 56 | Check whether the word contains at least one letter or digit 57 | :param word: (str) word string 58 | :return: 59 | is_word (bool) True if it is a real word 60 | """ 61 | is_word = bool(re.search('[a-zA-Z0-9]', word)) 62 | return is_word 63 | 64 | 65 | def load_matlab_string(matlab_extracted_object): 66 | """ 67 | Converts a string loaded from h5py into a python string 68 | :param matlab_extracted_object: (h5py) matlab string object 69 | :return: 70 | extracted_string (str) translated string 71 | """ 72 | extracted_string = u''.join(chr(c[0]) for c in matlab_extracted_object) 73 | return extracted_string 74 | 75 | 76 | def extract_word_level_data(data_container, word_objects, eeg_float_resolution = np.float16): 77 | """ 78 | Extracts word level data for a specific sentence 79 | :param data_container: (h5py) Container of the whole data, h5py object 80 | :param word_objects: (h5py) Container of all word data for a specific sentence 81 | :param eeg_float_resolution: (type) Resolution with which to save EEG, used for data compression 82 | :return: 83 | word_level_data (dict) word-level data indexed by word_idx (position among the real words of the sentence), 84 | followed by word_tokens_all, word_tokens_has_fixation and word_tokens_with_mask (lists of str) 85 | """ 86 | available_objects = list(word_objects) 87 | #print(available_objects) 88 | #print(len(available_objects)) 89 | #
print('available_objects:', available_objects) 90 | 91 | if isinstance(available_objects[0], str): 92 | 93 | contentData = word_objects['content'] 94 | #fixations_order_per_word = [] 95 | if "rawEEG" in available_objects: 96 | 97 | rawData = word_objects['rawEEG'] 98 | etData = word_objects['rawET'] 99 | 100 | ffdData = word_objects['FFD'] 101 | gdData = word_objects['GD'] 102 | gptData = word_objects['GPT'] 103 | trtData = word_objects['TRT'] 104 | 105 | try: 106 | sfdData = word_objects['SFD'] 107 | except KeyError: 108 | print("no SFD!") 109 | sfdData = [] 110 | nFixData = word_objects['nFixations'] 111 | fixPositions = word_objects["fixPositions"] 112 | 113 | Alpha_features_data = [word_objects[feature] for feature in Alpha_features] 114 | Beta_features_data = [word_objects[feature] for feature in Beta_features] 115 | Gamma_features_data = [word_objects[feature] for feature in Gamma_features] 116 | Theta_features_data = [word_objects[feature] for feature in Theta_features] 117 | #### 118 | GD_EEG_features = [word_objects[feature] for feature in ['GD_t1','GD_t2','GD_a1','GD_a2','GD_b1','GD_b2','GD_g1','GD_g2']] 119 | FFD_EEG_features = [word_objects[feature] for feature in ['FFD_t1','FFD_t2','FFD_a1','FFD_a2','FFD_b1','FFD_b2','FFD_g1','FFD_g2']] 120 | TRT_EEG_features = [word_objects[feature] for feature in ['TRT_t1','TRT_t2','TRT_a1','TRT_a2','TRT_b1','TRT_b2','TRT_g1','TRT_g2']] 121 | #### 122 | assert len(contentData) == len(etData) == len(rawData), "different amounts of different data!!" 
123 | 124 | zipped_data = zip(rawData, etData, contentData, ffdData, gdData, gptData, trtData, sfdData, nFixData, fixPositions) 125 | 126 | word_level_data = {} 127 | word_idx = 0 128 | 129 | word_tokens_has_fixation = [] 130 | word_tokens_with_mask = [] 131 | word_tokens_all = [] 132 | for raw_eegs_obj, ets_obj, word_obj, ffd, gd, gpt, trt, sfd, nFix, fixPos in zipped_data: 133 | word_string = load_matlab_string(data_container[word_obj[0]]) 134 | if is_real_word(word_string): 135 | data_dict = {} 136 | data_dict["RAW_EEG"] = extract_all_fixations(data_container, raw_eegs_obj[0], eeg_float_resolution) 137 | data_dict["RAW_ET"] = extract_all_fixations(data_container, ets_obj[0], np.float32) 138 | 139 | data_dict["FFD"] = data_container[ffd[0]][()][0, 0] if len(data_container[ffd[0]][()].shape) == 2 else None 140 | data_dict["GD"] = data_container[gd[0]][()][0, 0] if len(data_container[gd[0]][()].shape) == 2 else None 141 | data_dict["GPT"] = data_container[gpt[0]][()][0, 0] if len(data_container[gpt[0]][()].shape) == 2 else None 142 | data_dict["TRT"] = data_container[trt[0]][()][0, 0] if len(data_container[trt[0]][()].shape) == 2 else None 143 | data_dict["SFD"] = data_container[sfd[0]][()][0, 0] if len(data_container[sfd[0]][()].shape) == 2 else None 144 | data_dict["nFix"] = data_container[nFix[0]][()][0, 0] if len(data_container[nFix[0]][()].shape) == 2 else None 145 | 146 | #fixations_order_per_word.append(np.array(data_container[fixPos[0]])) 147 | 148 | #print([data_container[obj[word_idx][0]][()] for obj in Alpha_features_data]) 149 | 150 | 151 | data_dict["ALPHA_EEG"] = np.concatenate([data_container[obj[word_idx][0]][()] 152 | if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] 153 | for obj in Alpha_features_data], 0) 154 | 155 | data_dict["BETA_EEG"] = np.concatenate([data_container[obj[word_idx][0]][()] 156 | if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] 157 | for obj in Beta_features_data], 0) 158 | 159 | 
data_dict["GAMMA_EEG"] = np.concatenate([data_container[obj[word_idx][0]][()] 160 | if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] 161 | for obj in Gamma_features_data], 0) 162 | 163 | data_dict["THETA_EEG"] = np.concatenate([data_container[obj[word_idx][0]][()] 164 | if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] 165 | for obj in Theta_features_data], 0) 166 | 167 | 168 | 169 | 170 | data_dict["word_idx"] = word_idx 171 | data_dict["content"] = word_string 172 | #################################### 173 | word_tokens_all.append(word_string) 174 | if data_dict["nFix"] is not None: 175 | #################################### 176 | data_dict["GD_EEG"] = [np.squeeze(data_container[obj[word_idx][0]][()]) if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] for obj in GD_EEG_features] 177 | data_dict["FFD_EEG"] = [np.squeeze(data_container[obj[word_idx][0]][()]) if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] for obj in FFD_EEG_features] 178 | data_dict["TRT_EEG"] = [np.squeeze(data_container[obj[word_idx][0]][()]) if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] for obj in TRT_EEG_features] 179 | #################################### 180 | word_tokens_has_fixation.append(word_string) 181 | word_tokens_with_mask.append(word_string) 182 | else: 183 | word_tokens_with_mask.append('[MASK]') 184 | 185 | 186 | word_level_data[word_idx] = data_dict 187 | word_idx += 1 188 | else: 189 | print(word_string + " is not a real word.") 190 | else: 191 | # If there are no word-level data it will be word embeddings alone 192 | word_level_data = {} 193 | word_idx = 0 194 | word_tokens_has_fixation = [] 195 | word_tokens_with_mask = [] 196 | word_tokens_all = [] 197 | 198 | for word_obj in contentData: 199 | word_string = load_matlab_string(data_container[word_obj[0]]) 200 | if is_real_word(word_string): 201 | data_dict = {} 202 | data_dict["RAW_EEG"] = [] 203 | data_dict["ICA_EEG"] = [] 204 | data_dict["RAW_ET"] = 
[] 205 | data_dict["FFD"] = None 206 | data_dict["GD"] = None 207 | data_dict["GPT"] = None 208 | data_dict["TRT"] = None 209 | data_dict["SFD"] = None 210 | data_dict["nFix"] = None 211 | data_dict["ALPHA_EEG"] = [] 212 | data_dict["BETA_EEG"] = [] 213 | data_dict["GAMMA_EEG"] = [] 214 | data_dict["THETA_EEG"] = [] 215 | 216 | data_dict["word_idx"] = word_idx 217 | data_dict["content"] = word_string 218 | word_level_data[word_idx] = data_dict 219 | word_idx += 1 220 | else: 221 | print(word_string + " is not a real word.") 222 | 223 | sentence = " ".join([load_matlab_string(data_container[word_obj[0]]) for word_obj in word_objects['content']]) 224 | #print("Only available objects for the sentence '{}' are {}.".format(sentence, available_objects)) 225 | #word_level_data["word_reading_order"] = extract_word_order_from_fixations(fixations_order_per_word) 226 | else: 227 | word_tokens_has_fixation = [] 228 | word_tokens_with_mask = [] 229 | word_tokens_all = [] 230 | word_level_data = {} 231 | return word_level_data, word_tokens_all, word_tokens_has_fixation, word_tokens_with_mask --------------------------------------------------------------------------------
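`load_matlab_string` in `data_loading_helpers_modified.py` appears to rely on h5py returning a MATLAB v7.3 char array as a column of character codes with one code per row, which would explain why it reads `c[0]`. The sketch below reproduces that decoding on a hand-built NumPy array. The `(n_chars, 1)` uint16 layout is an assumption inferred from the indexing and has not been checked against the ZuCo files.

```python
import numpy as np


def load_matlab_string(char_codes):
    """Join a column of character codes (assumed shape (n_chars, 1)) into a str."""
    # each row holds one code unit; c[0] takes it out of the length-1 row
    return ''.join(chr(int(c[0])) for c in char_codes)


# hypothetical stand-in for f[obj_reference] in the real converter
codes = np.array([[ord(ch)] for ch in "reading"], dtype=np.uint16)
word = load_matlab_string(codes)
```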
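`pad_and_sort_batch` in `train_decoding_raw.py` pads raw EEG in two passes:

1. It pads each word's signal along time within a sentence.
2. It pads across the sentences of the batch.

It also records the unpadded per-word lengths. Because `torch.nn.utils.rnn.pad_sequence` only pads the leading dimension, the second call as written appears to require every sentence in a batch to have the same number of word slots (presumably guaranteed by `ZuCo_dataset`).

The NumPy sketch below shows the intended output layout: a `(batch, words, time, channels)` array plus a length table. It pads both words and time explicitly. The input format (a list of per-word `(time, channels)` arrays per sentence) and the name `pad_raw_eeg_batch` are assumptions for illustration.

```python
import numpy as np


def pad_raw_eeg_batch(batch):
    """Zero-pad per-word raw EEG into one (batch, words, time, channels) array.

    batch: list of sentences; each sentence is a list of float arrays of shape
    (time_i, channels), one per word (hypothetical input layout).
    Returns the padded array and an int array lengths[b, w] holding the
    unpadded time length of word w in sentence b (0 for padding slots).
    """
    channels = batch[0][0].shape[1]
    max_words = max(len(sentence) for sentence in batch)
    max_time = max(word.shape[0] for sentence in batch for word in sentence)
    padded = np.zeros((len(batch), max_words, max_time, channels), dtype=np.float32)
    lengths = np.zeros((len(batch), max_words), dtype=np.int64)
    for b, sentence in enumerate(batch):
        for w, word in enumerate(sentence):
            t = word.shape[0]
            padded[b, w, :t, :] = word  # copy the real samples, the rest stays 0
            lengths[b, w] = t
    return padded, lengths


# two sentences with 2 words and 1 word, 3 channels each (toy sizes)
batch = [
    [np.ones((4, 3)), np.ones((2, 3))],
    [np.ones((5, 3))],
]
padded, lengths = pad_raw_eeg_batch(batch)
```

Unlike the original, which returns the lengths as a list of per-sentence float tensors, this sketch returns a single integer table, which can be turned into a padding mask with `lengths > 0`.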
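Both training steps in `train_decoding_raw.py` use `lr_scheduler.CyclicLR` with `mode="triangular2"`, and neither passes `step_size_up`, so PyTorch's default of 2000 scheduler steps per half-cycle applies. The function below re-implements the triangular2 formula described in the PyTorch `CyclicLR` documentation in plain Python, so that the step-2 settings (`base_lr = 0.0000005`, `max_lr = 0.00005`) can be inspected without building an optimizer. The name `triangular2_lr` is hypothetical. The function assumes symmetric up and down phases, which is the default when `step_size_down` is omitted.

```python
import math


def triangular2_lr(step, base_lr, max_lr, step_size_up=2000):
    """Learning rate after `step` scheduler.step() calls (triangular2 policy)."""
    total_size = 2 * step_size_up                # one full up-and-down cycle
    cycle = math.floor(1 + step / total_size)    # 1-based cycle index
    x = 1.0 + step / total_size - cycle          # position inside the cycle, in [0, 1)
    if x <= 0.5:
        scale = x / 0.5                          # rising half of the triangle
    else:
        scale = (x - 1) / (0.5 - 1)              # falling half of the triangle
    amplitude = (max_lr - base_lr) / 2 ** (cycle - 1)  # peak halves every cycle
    return base_lr + amplitude * scale


# start, peak and end of the first two cycles with the step-2 settings
schedule = [triangular2_lr(s, 5e-7, 5e-5) for s in (0, 2000, 4000, 6000, 8000)]
```

`CyclicLR` sets each parameter group's learning rate to `base_lr` when it is constructed. As a result, the `lr=step2_lr` passed to `optim.SGD` in step 2 is effectively replaced by the hard-coded `base_lr`.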