├── README.md
├── architecture.png
├── config.py
├── data_raw.py
├── environment.yml
├── eval_decoding_raw.py
├── model_decoding_raw.py
├── overview.png
├── scripts
│   ├── eval_decoding_raw.sh
│   ├── prepare_dataset_raw.sh
│   └── train_decoding_raw.sh
├── train_decoding_raw.py
└── util
    ├── __pycache__
    │   └── data_loading_helpers_modified.cpython-38.pyc
    ├── construct_dataset_mat_to_pickle_v1_withRaw.py
    ├── construct_dataset_mat_to_pickle_v2_withRaw.py
    └── data_loading_helpers_modified.py
/README.md:
--------------------------------------------------------------------------------
1 | ---
2 |
18 |
19 |
20 |
21 |
22 | [arXiv:2312.09430](https://arxiv.org/abs/2312.09430)
23 |
27 |
28 |
29 |
30 |
31 |
32 | ## Abstract
33 |
34 | Previous research has demonstrated the potential of using pre-trained language models for decoding open vocabulary Electroencephalography (EEG) signals captured through a non-invasive Brain-Computer Interface (BCI). However, the impact of embedding EEG signals in the context of language models and the effect of subjectivity, remain unexplored, leading to uncertainty about the best approach to enhance decoding performance. Additionally, current evaluation metrics used to assess decoding effectiveness are predominantly syntactic and do not provide insights into the comprehensibility of the decoded output for human understanding. We present an end-to-end architecture for non-invasive brain recordings that brings modern representational learning approaches to neuroscience. Our proposal introduces the following innovations: 1) an end-to-end deep learning architecture for open vocabulary EEG decoding, incorporating a subject-dependent representation learning module for raw EEG encoding, a BART language model, and a GPT-4 sentence refinement module; 2) a more comprehensive sentence-level evaluation metric based on the BERTScore; 3) an ablation study that analyses the contributions of each module within our proposal, providing valuable insights for future research. We evaluate our approach on two publicly available datasets, ZuCo v1.0 and v2.0, comprising EEG recordings of 30 subjects engaged in natural reading tasks. Our model achieves a BLEU-1 score of 42.75%, a ROUGE-1-F of 33.28%, and a BERTScore-F of 53.86%, achieving an increment over the previous state-of-the-art by 1.40%, 2.59%, and 3.20%, respectively.
35 |
36 |
37 |
38 | ## Architecture
39 |
40 |

41 |
42 |
43 |
44 |
45 | ## Code
46 |
47 | This repo is based on the [EEG-to-Text](https://github.com/MikeWangWZHL/EEG-To-Text) repository.
48 |
49 | ### Download ZuCo datasets
50 | - Download ZuCo v1.0 'Matlab files' for 'task1-SR','task2-NR','task3-TSR' from https://osf.io/q3zws/files/ under 'OSF Storage' root,
51 | unzip and move all `.mat` files to `/dataset/ZuCo/task1-SR/Matlab_files`,`/dataset/ZuCo/task2-NR/Matlab_files`,`/dataset/ZuCo/task3-TSR/Matlab_files` respectively.
52 | - Download ZuCo v2.0 'Matlab files' for 'task1-NR' from https://osf.io/2urht/files/ under 'OSF Storage' root, unzip and move all `.mat` files to `/dataset/ZuCo/task2-NR-2.0/Matlab_files`.
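
As a quick sanity check after unpacking, the folder layout described above can be verified with a few lines of Python. This is a minimal sketch, not part of the repository: the helper name `count_mat_files` is hypothetical, and it assumes the `dataset` folder sits in the repository root (the code in `data_raw.py` reads from `./dataset/...`).

```python
from pathlib import Path

# Task folder names taken from the download instructions above.
EXPECTED_TASK_DIRS = ["task1-SR", "task2-NR", "task3-TSR", "task2-NR-2.0"]


def count_mat_files(dataset_root):
    """Return {task_name: number of .mat files}; -1 marks a missing folder.

    Hypothetical helper for illustration only.
    """
    root = Path(dataset_root)
    counts = {}
    for task in EXPECTED_TASK_DIRS:
        mat_dir = root / "ZuCo" / task / "Matlab_files"
        if mat_dir.is_dir():
            counts[task] = len(list(mat_dir.glob("*.mat")))
        else:
            counts[task] = -1
    return counts


# Example: report what is present under ./dataset (assumed location).
for task, n in count_mat_files("./dataset").items():
    print(f"{task}: {'missing' if n < 0 else str(n) + ' .mat files'}")
```

A folder reported as `missing` or with `0` files usually means the archive was unzipped into a different directory.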
53 |
54 | ### Preprocess datasets
55 | Run `bash ./scripts/prepare_dataset_raw.sh` to preprocess the `.mat` files and prepare sentiment labels.
56 |
57 | For each task, all `.mat` files will be converted into one `.pickle` file stored in `/dataset/ZuCo/<task_name>/pickle/<task_name>-dataset.pickle`.
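
To confirm that preprocessing produced usable data, a pickle can be opened and summarised per subject. The sketch below is an illustration under assumptions drawn from `data_raw.py`: each pickle is treated as a dict mapping a subject ID to a list of sentence objects, where an entry may be `None` for a skipped sentence. The function name `summarize_zuco_pickle` is hypothetical.

```python
import pickle


def summarize_zuco_pickle(pickle_path):
    """Return {subject: (total_sentences, usable_sentences)} for one task pickle.

    Assumes the structure used by data_raw.py: dict of subject -> list of
    sentence objects, with None marking sentences that could not be parsed.
    """
    with open(pickle_path, "rb") as handle:
        dataset = pickle.load(handle)

    summary = {}
    for subject, sentences in dataset.items():
        usable = sum(1 for sent in sentences if sent is not None)
        summary[subject] = (len(sentences), usable)
    return summary
```

For example, `summarize_zuco_pickle('./dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle')` uses the same path that the sanity test in `data_raw.py` loads.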
58 |
59 | ### Usage Example
60 | To train an EEG-To-Text decoding model, run `bash ./scripts/train_decoding_raw.sh`.
61 |
62 | To evaluate the trained EEG-To-Text decoding model from above, run `bash ./scripts/eval_decoding_raw.sh`.
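
The contents of the shell scripts are not reproduced in this README. As a rough sketch, a training command can also be assembled from Python using the flags declared in `config.py`. Everything specific here is an assumption for illustration: the helper name `build_train_command` is hypothetical, the values are the `config.py` defaults plus one of its listed task names, and it is assumed that `train_decoding_raw.py` reads its arguments through `get_config('train_decoding')`. The settings in `train_decoding_raw.sh` may differ.

```python
import shlex


def build_train_command(
    model_name="BrainTranslator",
    task_name="task1_task2_taskNRv2",
    num_epoch_step1=20,
    num_epoch_step2=30,
    learning_rate_step1=0.00005,
    learning_rate_step2=0.0000005,
    batch_size=32,
    save_path="./checkpoints/decoding",
    skip_step_one=False,
    cuda="cuda:0",
):
    """Build an argv list using the long flag names defined in config.py."""
    cmd = [
        "python3", "train_decoding_raw.py",
        "--model_name", model_name,
        "--task_name", task_name,
        # config.py exposes --one_step / --two_step for the same setting.
        "--one_step" if skip_step_one else "--two_step",
        "--num_epoch_step1", str(num_epoch_step1),
        "--num_epoch_step2", str(num_epoch_step2),
        "--learning_rate_step1", str(learning_rate_step1),
        "--learning_rate_step2", str(learning_rate_step2),
        "--batch_size", str(batch_size),
        "--save_path", save_path,
        "--cuda", cuda,
    ]
    return cmd


# Print a copy-pasteable shell command; nothing is executed here.
print(shlex.join(build_train_command()))
```

Returning a list rather than a single string keeps the command safe to pass to `subprocess.run` without shell quoting issues.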
63 |
64 |
65 |
66 |
67 | ## Citation
68 |
69 | ```
70 | @article{amrani2024deep,
71 | title={Deep Representation Learning for Open Vocabulary Electroencephalography-to-Text Decoding},
72 | author={Amrani, Hamza and Micucci, Daniela and Napoletano, Paolo},
73 | journal={IEEE Journal of Biomedical and Health Informatics},
74 | year={2024},
75 | publisher={IEEE}
76 | }
77 |
78 | ```
79 |
80 |
--------------------------------------------------------------------------------
/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamzaamrani/EEG-to-Text-Decoding/0b8cf03b040e37f9e573fced95ad47667b194557/architecture.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def str2bool(v):
4 |     if isinstance(v, bool):
5 |         return v
6 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
7 |         return True
8 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
9 |         return False
10 |     else:
11 |         raise argparse.ArgumentTypeError('Boolean value expected.')
12 |
13 | def get_config(case):
14 |     if case == 'train_decoding':
15 |         # args config for training EEG-To-Text decoder
16 |         parser = argparse.ArgumentParser(description='Specify config args for training EEG-To-Text decoder')
17 |
18 |         parser.add_argument('-m', '--model_name', help='choose from {BrainTranslator, BrainTranslatorNaive}', default = "BrainTranslator" ,required=True)
19 |         parser.add_argument('-t', '--task_name', help='choose from {task1,task1_task2, task1_task2_task3,task1_task2_taskNRv2}', default = "task1", required=True)
20 |
21 |         parser.add_argument('-1step', '--one_step', dest='skip_step_one', action='store_true')
22 |         parser.add_argument('-2step', '--two_step', dest='skip_step_one', action='store_false')
23 |
24 |         parser.add_argument('-pre', '--pretrained', dest='use_random_init', action='store_false')
25 |         parser.add_argument('-rand', '--rand_init', dest='use_random_init', action='store_true')
26 |
27 |         parser.add_argument('-load1', '--load_step1_checkpoint', dest='load_step1_checkpoint', action='store_true')
28 |         parser.add_argument('-no-load1', '--not_load_step1_checkpoint', dest='load_step1_checkpoint', action='store_false')
29 |
30 |         parser.add_argument('-ne1', '--num_epoch_step1', type = int, help='num_epoch_step1', default = 20, required=True)
31 |         parser.add_argument('-ne2', '--num_epoch_step2', type = int, help='num_epoch_step2', default = 30, required=True)
32 |         parser.add_argument('-lr1', '--learning_rate_step1', type = float, help='learning_rate_step1', default = 0.00005, required=True)
33 |         parser.add_argument('-lr2', '--learning_rate_step2', type = float, help='learning_rate_step2', default = 0.0000005, required=True)
34 |         parser.add_argument('-b', '--batch_size', type = int, help='batch_size', default = 32, required=True)
35 |
36 |         parser.add_argument('-s', '--save_path', help='checkpoint save path', default = './checkpoints/decoding', required=True)
37 |         parser.add_argument('-subj', '--subjects', help='use all subjects or specify a particular one', default = 'ALL', required=False)
38 |         parser.add_argument('-eeg', '--eeg_type', help='choose from {GD, FFD, TRT}', default = 'GD', required=False)
39 |         parser.add_argument('-band', '--eeg_bands', nargs='+', help='specify frequency bands', default = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] , required=False)
40 |         parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0')
41 |
42 |         args = vars(parser.parse_args())
43 |
44 |     elif case == 'eval_decoding':
45 |         # args config for evaluating EEG-To-Text decoder
46 |         parser = argparse.ArgumentParser(description='Specify config args for evaluate EEG-To-Text decoder')
47 |         parser.add_argument('-checkpoint', '--checkpoint_path', help='specify model checkpoint' ,required=True)
48 |         parser.add_argument('-conf', '--config_path', help='specify training config json' ,required=True)
49 |         parser.add_argument('-cuda', '--cuda', help='specify cuda device name, e.g. cuda:0, cuda:1, etc', default = 'cuda:0')
50 |         args = vars(parser.parse_args())
51 |
52 |
53 |     return args
--------------------------------------------------------------------------------
/data_raw.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import pickle
4 | from torch.utils.data import Dataset
5 | import json
6 | import matplotlib.pyplot as plt
7 | from glob import glob
8 | from transformers import BartTokenizer, BertTokenizer
9 | from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
10 |
11 |
12 | ZUCO_SENTIMENT_LABELS = json.load(open('./dataset/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json'))
13 | SST_SENTIMENT_LABELS = json.load(open('./dataset/stanfordsentiment/ternary_dataset.json'))
14 |
15 | from scipy.signal import butter, lfilter
16 | from scipy.signal import freqz
17 | def butter_bandpass_filter(signal, lowcut, highcut, fs=500, order=5):
18 |     nyq = 0.5 * fs
19 |     low = lowcut/nyq
20 |     high = highcut/nyq
21 |     b,a = butter(order, [low, high], btype='band')
22 |     y = lfilter(b, a, signal, axis=-1)
23 |
24 |     return torch.Tensor(y).float()
25 |
26 | def normalize_1d(input_tensor):
27 |     # normalize a 1d tensor
28 |     mean = torch.mean(input_tensor)
29 |     std = torch.std(input_tensor)
30 |     input_tensor = (input_tensor - mean)/std
31 |     return input_tensor
32 |
33 | def get_input_sample(sent_obj, tokenizer, eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'], max_len = 56, add_CLS_token = False, subj='unspecified',raw_eeg=False):
34 |
35 | def get_word_embedding_eeg_tensor(word_obj, eeg_type, bands):
36 | frequency_features = []
37 | content=word_obj['content']
38 | for band in bands:
39 | frequency_features.append(word_obj['word_level_EEG'][eeg_type][eeg_type+band])
40 | word_eeg_embedding = np.concatenate(frequency_features)
41 | if len(word_eeg_embedding) != 105*len(bands):
42 | print(f'expect word: {content} of subj: {subj} eeg embedding dim to be {105*len(bands)}, but got {len(word_eeg_embedding)}, return None')
43 | return None
44 | # assert len(word_eeg_embedding) == 105*len(bands)
45 | return_tensor = torch.from_numpy(word_eeg_embedding)
46 | return normalize_1d(return_tensor)
47 |
48 | def get_word_raweeg_tensor(word_obj):
49 | word_raw_eeg = word_obj['rawEEG'][0] #1000
50 | return_tensor = torch.from_numpy(word_raw_eeg)
51 | return return_tensor
52 |
53 | def get_sent_eeg(sent_obj, bands):
54 | sent_eeg_features = []
55 | for band in bands:
56 | key = 'mean'+band
57 | sent_eeg_features.append(sent_obj['sentence_level_EEG'][key])
58 | sent_eeg_embedding = np.concatenate(sent_eeg_features)
59 | assert len(sent_eeg_embedding) == 105*len(bands)
60 | return_tensor = torch.from_numpy(sent_eeg_embedding)
61 | return normalize_1d(return_tensor)
62 |
63 |
64 | if sent_obj is None:
65 | # print(f' - skip bad sentence')
66 | return None
67 |
68 | input_sample = {}
69 | # get target label
70 | target_string = sent_obj['content']
71 |
72 | target_tokenized = tokenizer(target_string, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt', return_attention_mask = True)
73 | input_sample['target_ids'] = target_tokenized['input_ids'][0]
74 |
75 |
76 | # get sentence level EEG features
77 | sent_level_eeg_tensor = get_sent_eeg(sent_obj, bands)
78 | if torch.isnan(sent_level_eeg_tensor).any():
79 | return None
80 | input_sample['sent_level_EEG'] = sent_level_eeg_tensor
81 |
82 | # get sentiment label
83 | # handle some weird case
84 | if 'emp11111ty' in target_string:
85 | target_string = target_string.replace('emp11111ty','empty')
86 | if 'film.1' in target_string:
87 | target_string = target_string.replace('film.1','film.')
88 |
89 | if target_string in ZUCO_SENTIMENT_LABELS:
90 | input_sample['sentiment_label'] = torch.tensor(ZUCO_SENTIMENT_LABELS[target_string]+1) # 0:Negative, 1:Neutral, 2:Positive
91 | else:
92 | input_sample['sentiment_label'] = torch.tensor(-100)
93 |
94 | # get input embeddings
95 | word_embeddings = []
96 | word_raw_embeddings = []
97 | word_contents = []
98 |
99 | """add CLS token embedding at the front"""
100 | if add_CLS_token:
101 | word_embeddings.append(torch.ones(105*len(bands)))
102 |
103 |
104 | for word in sent_obj['word']:
105 | # add each word's EEG embedding as Tensors
106 | word_level_eeg_tensor = get_word_embedding_eeg_tensor(word, eeg_type, bands = bands)
107 | if raw_eeg:
108 | try:
109 | word_level_raw_eeg_tensor = get_word_raweeg_tensor(word)
110 | except Exception:
111 | print('error in raw eeg')
112 | print(word['content'])
113 | print(sent_obj['content'])
114 | print()
115 | return None
116 | # check none, for v2 dataset
117 | if word_level_eeg_tensor is None:
118 | return None
119 | # check nan:
120 | if torch.isnan(word_level_eeg_tensor).any():
121 | # print()
122 | # print('[NaN ERROR] problem sent:',sent_obj['content'])
123 | # print('[NaN ERROR] problem word:',word['content'])
124 | # print('[NaN ERROR] problem word feature:',word_level_eeg_tensor)
125 | # print()
126 | return None
127 |
128 | word_contents.append(word['content'])
129 | word_embeddings.append(word_level_eeg_tensor)
130 |
131 | if raw_eeg:
132 | word_level_raw_eeg_tensor = word_level_raw_eeg_tensor[:,:104]
133 | word_raw_embeddings.append(word_level_raw_eeg_tensor)
134 |
135 | if len(word_embeddings)<1:
136 | return None
137 |
138 |
139 | # pad to max_len
140 | n_eeg_representations = len(word_embeddings)
141 | while len(word_embeddings) < max_len:
142 | # TODO: FBCSP
143 | word_embeddings.append(torch.zeros(105*len(bands)))
144 | if raw_eeg:
145 | word_raw_embeddings.append(torch.zeros(1,104))
146 |
147 | word_contents_tokenized = tokenizer(' '.join(word_contents), padding='max_length', max_length=max_len, truncation=True, return_tensors='pt', return_attention_mask = True)
148 |
149 | input_sample['word_contents'] = word_contents_tokenized['input_ids'][0]
150 | input_sample['word_contents_attn'] = word_contents_tokenized['attention_mask'][0] #bart
151 |
152 | input_sample['input_embeddings'] = torch.stack(word_embeddings) # max_len * (105*num_bands)
153 |
154 | if raw_eeg:
155 | input_sample['input_raw_embeddings'] = word_raw_embeddings
156 |
157 |
158 | # mask out padding tokens
159 | input_sample['input_attn_mask'] = torch.zeros(max_len) # 0 is masked out
160 |
161 | if add_CLS_token:
162 | input_sample['input_attn_mask'][:len(sent_obj['word'])+1] = torch.ones(len(sent_obj['word'])+1) # 1 is not masked
163 | else:
164 | input_sample['input_attn_mask'][:len(sent_obj['word'])] = torch.ones(len(sent_obj['word'])) # 1 is not masked
165 |
166 | # mask out padding tokens reverted: handle different use case: this is for pytorch transformers
167 | input_sample['input_attn_mask_invert'] = torch.ones(max_len) # 1 is masked out
168 |
169 | if add_CLS_token:
170 | input_sample['input_attn_mask_invert'][:len(sent_obj['word'])+1] = torch.zeros(len(sent_obj['word'])+1) # 0 is not masked
171 | else:
172 | input_sample['input_attn_mask_invert'][:len(sent_obj['word'])] = torch.zeros(len(sent_obj['word'])) # 0 is not masked
173 |
174 | # mask out target padding for computing cross entropy loss
175 | input_sample['target_mask'] = target_tokenized['attention_mask'][0]
176 | input_sample['seq_len'] = len(sent_obj['word'])
177 |
178 | # clean 0 length data
179 | if input_sample['seq_len'] == 0:
180 | print('discard length zero instance: ', target_string)
181 | return None
182 |
183 | # subject
184 | input_sample['subject']= subj
185 |
186 | return input_sample
187 |
188 | class ZuCo_dataset(Dataset):
189 | def __init__(self, input_dataset_dicts, phase, tokenizer, subject = 'ALL', eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'],raweeg=False, setting = 'unique_sent', is_add_CLS_token = False):
190 | self.inputs = []
191 | self.tokenizer = tokenizer
192 |
193 | if not isinstance(input_dataset_dicts,list):
194 | input_dataset_dicts = [input_dataset_dicts]
195 | print(f'[INFO]loading {len(input_dataset_dicts)} task datasets')
196 | for input_dataset_dict in input_dataset_dicts:
197 | if subject == 'ALL':
198 | subjects = list(input_dataset_dict.keys())
199 | print('[INFO]using subjects: ', subjects)
200 | else:
201 | subjects = [subject]
202 |
203 | total_num_sentence = len(input_dataset_dict[subjects[0]])
204 |
205 | train_divider = int(0.8*total_num_sentence)
206 | dev_divider = train_divider + int(0.1*total_num_sentence)
207 |
208 | print(f'train divider = {train_divider}')
209 | print(f'dev divider = {dev_divider}')
210 |
211 | if setting == 'unique_sent':
212 | # take first 80% as trainset, 10% as dev and 10% as test
213 | if phase == 'train':
214 | print('[INFO]initializing a train set...')
215 | for key in subjects:
216 | for i in range(train_divider):
217 | input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, subj=key,raw_eeg=raweeg)
218 | if input_sample is not None:
219 | input_sample['subject']=key
220 | self.inputs.append(input_sample)
221 | elif phase == 'dev':
222 | print('[INFO]initializing a dev set...')
223 | for key in subjects:
224 | for i in range(train_divider,dev_divider):
225 | input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, subj=key,raw_eeg=raweeg)
226 | if input_sample is not None:
227 | input_sample['subject']=key
228 | self.inputs.append(input_sample)
229 | elif phase == 'all':
230 | print('[INFO]initializing all dataset...')
231 | for key in subjects:
232 | for i in range(int(1*total_num_sentence)):
233 | input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, subj=key,raw_eeg=raweeg)
234 | if input_sample is not None:
235 | input_sample['subject']=key
236 | self.inputs.append(input_sample)
237 | elif phase == 'test':
238 | print('[INFO]initializing a test set...')
239 | for key in subjects:
240 | for i in range(dev_divider,total_num_sentence):
241 | input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, subj=key,raw_eeg=raweeg)
242 | if input_sample is not None:
243 | input_sample['subject']=key
244 | self.inputs.append(input_sample)
245 | elif setting == 'unique_subj':
246 | print('WARNING!!! only implemented for SR v1 dataset ')
247 | # subject ['ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW'] for train
248 | # subject ['ZMG'] for dev
249 | # subject ['ZPH'] for test
250 | if phase == 'train':
251 | print(f'[INFO]initializing a train set using {setting} setting...')
252 | for i in range(total_num_sentence):
253 | for key in ['ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH','ZKW']:
254 | input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, subj=key)
255 | if input_sample is not None:
256 | self.inputs.append(input_sample)
257 | if phase == 'dev':
258 | print(f'[INFO]initializing a dev set using {setting} setting...')
259 | for i in range(total_num_sentence):
260 | for key in ['ZMG']:
261 | input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, subj=key)
262 | if input_sample is not None:
263 | self.inputs.append(input_sample)
264 | if phase == 'test':
265 | print(f'[INFO]initializing a test set using {setting} setting...')
266 | for i in range(total_num_sentence):
267 | for key in ['ZPH']:
268 | input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token, subj=key)
269 | if input_sample is not None:
270 | self.inputs.append(input_sample)
271 | print('++ adding task to dataset, now we have:', len(self.inputs))
272 |
273 | #print('[INFO]input tensor size:', self.inputs[0]['input_embeddings'].size())
274 | #print()
275 |
276 | def __len__(self):
277 | return len(self.inputs)
278 |
279 | def __getitem__(self, idx):
280 | input_sample = self.inputs[idx]
281 | return (
282 | input_sample['input_embeddings'],
283 | input_sample['seq_len'],
284 | input_sample['input_attn_mask'],
285 | input_sample['input_attn_mask_invert'],
286 | input_sample['target_ids'],
287 | input_sample['target_mask'],
288 | input_sample['sentiment_label'],
289 | input_sample['sent_level_EEG'],
290 | input_sample['input_raw_embeddings'],
291 | input_sample['word_contents'],
292 | input_sample['word_contents_attn'],
293 | input_sample['subject']
294 | )
295 |
296 |
297 | '''sanity test'''
298 | if __name__ == '__main__':
299 |
300 | check_dataset = 'ZuCo'#'stanford_sentiment'
301 |
302 | if check_dataset == 'ZuCo':
303 | whole_dataset_dicts = []
304 |
305 | dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle'
306 | with open(dataset_path_task1, 'rb') as handle:
307 | whole_dataset_dicts.append(pickle.load(handle))
308 |
309 | dataset_path_task2 = './dataset/ZuCo/task2-NR/pickle/task2-NR-dataset.pickle'
310 | with open(dataset_path_task2, 'rb') as handle:
311 | whole_dataset_dicts.append(pickle.load(handle))
312 |
313 | dataset_path_task2_v2 = './dataset/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset.pickle'
314 | with open(dataset_path_task2_v2, 'rb') as handle:
315 | whole_dataset_dicts.append(pickle.load(handle))
316 |
317 | print()
318 | for key in whole_dataset_dicts[-1]:
319 | print(f'task2_v2, sentence num in {key}:',len(whole_dataset_dicts[-1][key]))
320 | print()
321 |
322 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
323 | dataset_setting = 'unique_sent'
324 | subject_choice = 'ALL'
325 | print(f'![Debug]using {subject_choice}')
326 | eeg_type_choice = 'GD'
327 | print(f'[INFO]eeg type {eeg_type_choice}')
328 | bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2']
329 | print(f'[INFO]using bands {bands_choice}')
330 | train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, raweeg=True)
331 | dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, raweeg=True)
332 | test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, raweeg=True)
333 |
334 | print('trainset size:',len(train_set))
335 | print('devset size:',len(dev_set))
336 | print('testset size:',len(test_set))
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: EEGToTextOpenVoc
2 | channels:
3 | - pytorch
4 | - dglteam
5 | - nvidia
6 | - conda-forge
7 | - anaconda
8 | - intel
9 | - defaults
10 | dependencies:
11 | - _libgcc_mutex=0.1=main
12 | - _openmp_mutex=4.5=1_gnu
13 | - attrs=21.2.0=pyhd3eb1b0_0
14 | - autopep8=1.6.0=pyhd3eb1b0_1
15 | - blas=1.0=mkl
16 | - brotlipy=0.7.0=py38h27cfd23_1003
17 | - bzip2=1.0.8=h7b6447c_0
18 | - c-ares=1.17.1=h27cfd23_0
19 | - ca-certificates=2022.10.11=h06a4308_0
20 | - catalogue=1.0.0=py38_1
21 | - certifi=2022.9.24=py38h06a4308_0
22 | - cffi=1.14.6=py38h400218f_0
23 | - chardet=4.0.0=py38h06a4308_1003
24 | - click=8.0.1=pyhd3eb1b0_0
25 | - conda=4.10.3=py38h578d9bd_0
26 | - conda-package-handling=1.7.3=py38h27cfd23_1
27 | - cryptography=3.4.7=py38hd23ed53_0
28 | - cudatoolkit=11.1.74=h6bb024c_0
29 | - cycler=0.10.0=py38_0
30 | - cymem=2.0.5=py38h2531618_0
31 | - cython-blis=0.4.1=py38h7b6447c_1
32 | - dbus=1.13.18=hb2f20db_0
33 | - decorator=5.0.9=pyhd3eb1b0_0
34 | - dgl-cuda10.2=0.5.2=py38_0
35 | - expat=2.4.1=h2531618_2
36 | - ffmpeg=4.2.2=h20bf706_0
37 | - fontconfig=2.13.1=h6c09931_0
38 | - freetype=2.10.4=h5ab3b9f_0
39 | - glib=2.69.0=h5202010_0
40 | - gmp=6.2.1=h2531618_2
41 | - gnutls=3.6.15=he1e5248_0
42 | - gst-plugins-base=1.14.0=h8213a91_2
43 | - gstreamer=1.14.0=h28cd5cc_2
44 | - icu=58.2=he6710b0_3
45 | - idna=2.10=pyhd3eb1b0_0
46 | - importlib_metadata=3.10.0=hd3eb1b0_0
47 | - intel-openmp=2021.3.0=h06a4308_3350
48 | - jpeg=9b=h024ee3a_2
49 | - kiwisolver=1.3.1=py38h2531618_0
50 | - krb5=1.19.2=hac12032_0
51 | - lame=3.100=h7b6447c_0
52 | - lcms2=2.12=h3be6417_0
53 | - ld_impl_linux-64=2.35.1=h7274673_9
54 | - libcurl=7.78.0=h0b77cf5_0
55 | - libedit=3.1.20210216=h27cfd23_1
56 | - libev=4.33=h7b6447c_0
57 | - libffi=3.3=he6710b0_2
58 | - libgcc-ng=9.3.0=h5101ec6_17
59 | - libgfortran-ng=7.5.0=ha8ba4b0_17
60 | - libgfortran4=7.5.0=ha8ba4b0_17
61 | - libgomp=9.3.0=h5101ec6_17
62 | - libidn2=2.3.2=h7f8727e_0
63 | - libnghttp2=1.41.0=hf8bcb03_2
64 | - libopus=1.3.1=h7b6447c_0
65 | - libpng=1.6.37=hbc83047_0
66 | - libssh2=1.9.0=h1ba5d50_1
67 | - libstdcxx-ng=9.3.0=hd4cf53a_17
68 | - libtasn1=4.16.0=h27cfd23_0
69 | - libtiff=4.2.0=h85742a9_0
70 | - libunistring=0.9.10=h27cfd23_0
71 | - libuuid=1.0.3=h1bed415_2
72 | - libuv=1.40.0=h7b6447c_0
73 | - libvpx=1.7.0=h439df22_0
74 | - libwebp-base=1.2.0=h27cfd23_0
75 | - libxcb=1.14=h7b6447c_0
76 | - libxml2=2.9.12=h03d6c58_0
77 | - lz4-c=1.9.3=h295c915_1
78 | - matplotlib=3.3.2=0
79 | - matplotlib-base=3.3.2=py38h817c723_0
80 | - mkl=2020.2=256
81 | - mkl-service=2.3.0=py38he904b0f_0
82 | - mkl_fft=1.3.0=py38h54f3939_0
83 | - mkl_random=1.1.1=py38h0573a6f_0
84 | - murmurhash=1.0.5=py38h2531618_0
85 | - ncurses=6.2=he6710b0_1
86 | - nettle=3.7.3=hbbd107a_1
87 | - networkx=2.5=py_0
88 | - ninja=1.10.2=hff7bd54_1
89 | - olefile=0.46=py_0
90 | - openh264=2.1.0=hd408876_0
91 | - openjpeg=2.3.0=h05c96fa_1
92 | - openssl=1.1.1q=h7f8727e_0
93 | - pcre=8.45=h295c915_0
94 | - pip=21.0.1=py38h06a4308_0
95 | - plac=0.9.6=py38_1
96 | - preshed=3.0.5=py38h2531618_4
97 | - pycodestyle=2.10.0=py38h06a4308_0
98 | - pycosat=0.6.3=py38h7b6447c_1
99 | - pycparser=2.20=py_2
100 | - pycurl=7.43.0.6=py38h1ba5d50_0
101 | - pymongo=3.11.2=py38h2531618_0
102 | - pyopenssl=20.0.1=pyhd3eb1b0_1
103 | - pyparsing=2.4.7=pyhd3eb1b0_0
104 | - pyqt=5.9.2=py38h05f1152_4
105 | - pyrsistent=0.17.3=py38h7b6447c_0
106 | - pysocks=1.7.1=py38h06a4308_0
107 | - python=3.8.5=h7579374_1
108 | - python-dateutil=2.8.2=pyhd3eb1b0_0
109 | - python_abi=3.8=2_cp38
110 | - qt=5.9.7=h5867ecd_1
111 | - readline=8.1=h27cfd23_0
112 | - requests=2.25.1=pyhd3eb1b0_0
113 | - ruamel.yaml=0.17.10=py38h497a2fe_0
114 | - ruamel.yaml.clib=0.2.2=py38h497a2fe_2
115 | - ruamel_yaml=0.15.100=py38h27cfd23_0
116 | - setuptools=52.0.0=py38h06a4308_0
117 | - sip=4.19.13=py38he6710b0_0
118 | - six=1.16.0=pyhd3eb1b0_0
119 | - spacy=2.3.2=py38hfd86e86_0
120 | - sqlite=3.36.0=hc218d9a_0
121 | - srsly=1.0.5=py38h2531618_0
122 | - thinc=7.4.1=py38hfd86e86_0
123 | - threadpoolctl=2.2.0=pyhbf3da8f_0
124 | - tk=8.6.10=hbc83047_0
125 | - toml=0.10.2=pyhd3eb1b0_0
126 | - torchaudio=0.9.0=py38
127 | - urllib3=1.26.6=pyhd3eb1b0_1
128 | - wasabi=0.8.2=pyhd3eb1b0_0
129 | - wheel=0.37.0=pyhd3eb1b0_0
130 | - x264=1!157.20191217=h7b6447c_0
131 | - xz=5.2.5=h7b6447c_0
132 | - yaml=0.2.5=h7b6447c_0
133 | - zipp=3.5.0=pyhd3eb1b0_0
134 | - zlib=1.2.11=h7b6447c_3
135 | - zstd=1.4.9=haebb681_0
136 | - pip:
137 | - absl-py==1.4.0
138 | - accelerate==0.17.1
139 | - aiohttp==3.8.4
140 | - aiosignal==1.3.1
141 | - antlr4-python3-runtime==4.9.3
142 | - anyio==3.6.2
143 | - appdirs==1.4.4
144 | - argon2-cffi==21.3.0
145 | - argon2-cffi-bindings==21.2.0
146 | - arrow==1.2.3
147 | - asttokens==2.2.1
148 | - async-timeout==4.0.2
149 | - backcall==0.2.0
150 | - beautifulsoup4==4.11.2
151 | - bert-score==0.3.13
152 | - black==23.1.0
153 | - bleach==6.0.0
154 | - boto3==1.26.94
155 | - botocore==1.29.94
156 | - brevitas==0.8.0
157 | - cachetools==5.3.0
158 | - charset-normalizer==3.1.0
159 | - cloudpickle==2.2.1
160 | - cmake==3.26.0
161 | - comm==0.1.2
162 | - dataclasses-json==0.5.7
163 | - datasets==2.10.1
164 | - debugpy==1.6.6
165 | - defusedxml==0.7.1
166 | - dependencies==2.0.1
167 | - detectron2==0.6
168 | - dill==0.3.6
169 | - executing==1.2.0
170 | - fastjsonschema==2.16.2
171 | - fasttext==0.9.2
172 | - filelock==3.9.0
173 | - fqdn==1.5.1
174 | - frozenlist==1.3.3
175 | - fsspec==2023.3.0
176 | - fst-pso==1.8.1
177 | - funcy==1.18
178 | - future==0.18.3
179 | - fuzzy-match==0.0.1
180 | - fuzzytm==2.0.5
181 | - fvcore==0.1.5.post20221221
182 | - gensim==4.3.0
183 | - geojson==2.5.0
184 | - google-auth==2.16.0
185 | - google-auth-oauthlib==0.4.6
186 | - greenlet==2.0.2
187 | - grpcio==1.51.1
188 | - h5py==2.10.0
189 | - huggingface-hub==0.13.2
190 | - hydra-core==1.3.1
191 | - importlib-metadata==6.0.0
192 | - importlib-resources==5.10.2
193 | - inflect==6.0.2
194 | - inquirerpy==0.3.4
195 | - iopath==0.1.9
196 | - ipykernel==6.21.2
197 | - ipython==8.10.0
198 | - ipython-genutils==0.2.0
199 | - ipywidgets==8.0.4
200 | - isoduration==20.11.0
201 | - jedi==0.18.2
202 | - jinja2==3.1.2
203 | - jmespath==1.0.1
204 | - joblib==1.2.0
205 | - jsonpointer==2.3
206 | - jsonschema==4.17.3
207 | - jupyter==1.0.0
208 | - jupyter-client==8.0.2
209 | - jupyter-console==6.5.1
210 | - jupyter-core==5.2.0
211 | - jupyter-events==0.6.3
212 | - jupyter-server==2.3.0
213 | - jupyter-server-terminals==0.4.4
214 | - jupyterlab-pygments==0.2.2
215 | - jupyterlab-widgets==3.0.5
216 | - langchain==0.0.125
217 | - lit==15.0.7
218 | - mako==1.2.4
219 | - markdown==3.4.1
220 | - markupsafe==2.1.2
221 | - marshmallow==3.19.0
222 | - marshmallow-enum==1.5.1
223 | - matplotlib-inline==0.1.6
224 | - miniful==0.0.6
225 | - mistune==2.0.5
226 | - mpmath==1.3.0
227 | - multidict==6.0.4
228 | - multiprocess==0.70.14
229 | - mypy-extensions==1.0.0
230 | - nbclassic==0.5.1
231 | - nbclient==0.7.2
232 | - nbconvert==7.2.9
233 | - nbformat==5.7.3
234 | - nest-asyncio==1.5.6
235 | - nltk==3.8.1
236 | - notebook==6.5.2
237 | - notebook-shim==0.2.2
238 | - numexpr==2.8.4
239 | - numpy==1.23.5
240 | - nvidia-cublas-cu11==11.10.3.66
241 | - nvidia-cuda-cupti-cu11==11.7.101
242 | - nvidia-cuda-nvrtc-cu11==11.7.99
243 | - nvidia-cuda-runtime-cu11==11.7.99
244 | - nvidia-cudnn-cu11==8.5.0.96
245 | - nvidia-cufft-cu11==10.9.0.58
246 | - nvidia-curand-cu11==10.2.10.91
247 | - nvidia-cusolver-cu11==11.4.0.1
248 | - nvidia-cusparse-cu11==11.7.4.91
249 | - nvidia-nccl-cu11==2.14.3
250 | - nvidia-nvtx-cu11==11.7.91
251 | - oauthlib==3.2.2
252 | - omegaconf==2.3.0
253 | - openai==0.27.2
254 | - opt-einsum==3.3.0
255 | - packaging==23.0
256 | - pandas==1.5.3
257 | - pandocfilters==1.5.0
258 | - parso==0.8.3
259 | - pathspec==0.11.0
260 | - pexpect==4.8.0
261 | - pfzy==0.3.4
262 | - pickleshare==0.7.5
263 | - pillow==9.4.0
264 | - pkgutil-resolve-name==1.3.10
265 | - platformdirs==3.0.0
266 | - plotly==5.13.0
267 | - portalocker==2.7.0
268 | - powerlaw==1.5
269 | - prometheus-client==0.16.0
270 | - prompt-toolkit==3.0.36
271 | - protobuf==3.20.3
272 | - psutil==5.9.4
273 | - ptyprocess==0.7.0
274 | - pure-eval==0.2.2
275 | - pyarrow==11.0.0
276 | - pyasn1==0.4.8
277 | - pyasn1-modules==0.2.8
278 | - pybind11==2.10.4
279 | - pycocotools==2.0.6
280 | - pycuda==2022.2.2
281 | - pydantic==1.10.6
282 | - pyfume==0.2.25
283 | - pygments==2.14.0
284 | - pyldavis==3.3.1
285 | - pyowm==3.3.0
286 | - pyro-api==0.1.2
287 | - pyro-ppl==1.8.4
288 | - python-graphviz==0.20.1
289 | - python-json-logger==2.0.6
290 | - pytools==2022.1.14
291 | - pytorch-pretrained-bert==0.6.2
292 | - pytz==2022.7.1
293 | - pyyaml==6.0
294 | - pyzmq==25.0.0
295 | - qtconsole==5.4.0
296 | - qtpy==2.3.0
297 | - quant-cuda==0.0.0
298 | - regex==2022.10.31
299 | - requests-oauthlib==1.3.1
300 | - responses==0.18.0
301 | - rfc3339-validator==0.1.4
302 | - rfc3986-validator==0.1.1
303 | - rouge==1.0.1
304 | - rouge-score==0.1.2
305 | - rsa==4.9
306 | - s3transfer==0.6.0
307 | - sacremoses==0.0.53
308 | - safetensors==0.3.0
309 | - scikit-learn==1.2.1
310 | - scipy==1.10.1
311 | - send2trash==1.8.0
312 | - sentence-transformers==2.2.2
313 | - sentencepiece==0.1.97
314 | - simpful==2.10.0
315 | - sklearn==0.0.post1
316 | - smart-open==6.3.0
317 | - sniffio==1.3.0
318 | - soupsieve==2.4
319 | - sqlalchemy==1.4.47
320 | - stack-data==0.6.2
321 | - sympy==1.11.1
322 | - tabulate==0.9.0
323 | - tenacity==8.2.1
324 | - tensorboard==2.12.0
325 | - tensorboard-data-server==0.7.0
326 | - tensorboard-plugin-wit==1.8.1
327 | - termcolor==2.2.0
328 | - terminado==0.17.1
329 | - tinycss2==1.2.1
330 | - tokenizers==0.13.2
331 | - tomli==2.0.1
332 | - torch==1.13.1
333 | - torch-summary==1.4.5
334 | - torchcontrib==0.0.2
335 | - torchvision==0.15.1
336 | - torchviz==0.0.2
337 | - tornado==6.2
338 | - tqdm==4.65.0
339 | - traitlets==5.9.0
340 | - transformers==4.28.0.dev0
341 | - triton==2.0.0
342 | - typing-extensions==4.5.0
343 | - typing-inspect==0.8.0
344 | - uri-template==1.2.0
345 | - wcwidth==0.2.6
346 | - webcolors==1.12
347 | - webencodings==0.5.1
348 | - websocket-client==1.5.1
349 | - weightwatcher==0.7.1.5
350 | - werkzeug==2.2.3
351 | - widgetsnbextension==4.0.5
352 | - wordcloud==1.8.2.2
353 | - xxhash==3.2.0
354 | - yacs==0.1.8
355 | - yarl==1.8.2
356 | prefix: /home/hamza/miniconda3/envs/EEGToTextOpenVoc
357 |
--------------------------------------------------------------------------------
/eval_decoding_raw.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.data import DataLoader
6 | import pickle
7 | import json
8 | import matplotlib.pyplot as plt
9 | from glob import glob
10 |
11 | from transformers import BartTokenizer, BartForConditionalGeneration
12 | from data_raw import ZuCo_dataset
13 | from model_decoding_raw import BrainTranslator
14 | from nltk.translate.bleu_score import corpus_bleu
15 | from rouge import Rouge
16 | from config import get_config
17 |
18 | from torch.nn.utils.rnn import pad_sequence
19 |
20 |
21 | from langchain.chat_models import ChatOpenAI
22 | from langchain.schema import (
23 | HumanMessage,
24 | SystemMessage
25 | )
26 | os.environ.setdefault("OPENAI_API_KEY", "")  # GPT-4 API key: set it here or export it before running
27 |
28 |
29 | # LLMs: Get predictions from ChatGPT
30 | def chatgpt_refinement(corrupted_text):
31 | llm = ChatOpenAI(temperature=0.2, model_name="gpt-4", max_tokens=256)
32 |
33 | messages = [
34 | SystemMessage(content="As a text reconstructor, your task is to restore corrupted sentences to their original form while making minimum changes. You should adjust the spaces and punctuation marks as necessary. Do not introduce any additional information. If you are unable to reconstruct the text, respond with [False]."),
35 | HumanMessage(content=f"Reconstruct the following text: [{corrupted_text}].")
36 | ]
37 |
38 | output = llm(messages).content
39 | output = output.replace('[','').replace(']','')
40 |
41 | if len(output)<10 and 'False' in output:
42 | return corrupted_text
43 |
44 | return output
45 |
46 | def eval_model(dataloaders, device, tokenizer, criterion, model, output_all_results_path = './results_raw/temp.txt' ):
47 | # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
48 |
49 | gpt=False
50 |
51 | print("Saving to: ", output_all_results_path)
52 | model.eval() # Set model to evaluate mode
53 | running_loss = 0.0
54 |
55 | # Iterate over data.
56 | sample_count = 0
57 |
58 | target_tokens_list = []
59 | target_string_list = []
60 | pred_tokens_list = []
61 | pred_string_list = []
62 | refine_tokens_list = []
63 | refine_string_list = []
64 |
65 | with open(output_all_results_path,'w') as f:
66 | for _, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG, input_raw_embeddings, input_raw_embeddings_lengths, word_contents, word_contents_attn, subject_batch in dataloaders['test']:
67 |
68 | # load in batch
69 | input_embeddings_batch = input_raw_embeddings.to(
70 | device).float()
71 | input_embeddings_lengths_batch = torch.stack([torch.tensor(
72 | a.clone().detach()) for a in input_raw_embeddings_lengths], 0).to(device)
73 | input_masks_batch = torch.stack(input_masks, 0).to(device)
74 | input_mask_invert_batch = torch.stack(
75 | input_mask_invert, 0).to(device)
76 | target_ids_batch = torch.stack(target_ids, 0).to(device)
77 | word_contents_batch = torch.stack(
78 | word_contents, 0).to(device)
79 | word_contents_attn_batch = torch.stack(
80 | word_contents_attn, 0).to(device)
81 |
82 | subject_batch = np.array(subject_batch)
83 |
84 | target_tokens = tokenizer.convert_ids_to_tokens(target_ids_batch[0].tolist(), skip_special_tokens = True)
85 | target_string = tokenizer.decode(target_ids_batch[0], skip_special_tokens = True)
86 |
87 |
88 | f.write(f'target string: {target_string}\n')
89 | target_tokens_string = "["
90 | for el in target_tokens:
91 | target_tokens_string = target_tokens_string + str(el) + " "
92 | target_tokens_string += "]"
93 | f.write(f'target tokens: {target_tokens_string}\n')
94 |
95 | # add to list for later calculate bleu metric
96 | target_tokens_list.append([target_tokens])
97 | target_string_list.append(target_string)
98 |
99 | """replace padding ids in target_ids with -100"""
100 | target_ids_batch[target_ids_batch == tokenizer.pad_token_id] = -100
101 |
102 | # forward
103 | seq2seqLMoutput = model(
104 | input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch, input_embeddings_lengths_batch, word_contents_batch, word_contents_attn_batch, False, subject_batch, device)
105 |
106 | """calculate loss"""
107 | loss = criterion(seq2seqLMoutput.permute(0,2,1), target_ids_batch.long())
108 |
109 | # get predicted tokens
110 | logits = seq2seqLMoutput
111 | probs = logits[0].softmax(dim = 1)
112 | # print('probs size:', probs.size())
113 | values, predictions = probs.topk(1)
114 | # print('predictions before squeeze:',predictions.size())
115 | predictions = torch.squeeze(predictions)
116 |             predicted_string = tokenizer.decode(predictions).split('</s></s>')[0].replace('<s>','')
117 | # print('predicted string:',predicted_string)
118 | f.write(f'predicted string: {predicted_string}\n')
119 |
120 | # convert to int list
121 | predictions = predictions.tolist()
122 | truncated_prediction = []
123 | for t in predictions:
124 | if t != tokenizer.eos_token_id:
125 | truncated_prediction.append(t)
126 | else:
127 | break
128 | pred_tokens = tokenizer.convert_ids_to_tokens(truncated_prediction, skip_special_tokens = True)
129 | # print('predicted tokens:',pred_tokens)
130 | pred_tokens_list.append(pred_tokens)
131 | pred_string_list.append(predicted_string)
132 |
133 | # chatgpt refinement and tokenizer decode
134 | if gpt:
135 | predicted_string_chatgpt = chatgpt_refinement(predicted_string).replace('\n','')
136 | f.write(f'refined string: {predicted_string_chatgpt}\n')
137 | refine_tokens_list.append(tokenizer.convert_ids_to_tokens(tokenizer(predicted_string_chatgpt)['input_ids'], skip_special_tokens=True))
138 | refine_string_list.append(predicted_string_chatgpt)
139 |
140 |
141 | pred_tokens_string = "["
142 | for el in pred_tokens:
143 | pred_tokens_string = pred_tokens_string + str(el) + " "
144 | pred_tokens_string += "]"
145 | f.write(f'predicted tokens (truncated): {pred_tokens_string}\n')
146 | f.write(f'################################################\n\n\n')
147 |
148 | sample_count += 1
149 | # statistics
150 | running_loss += loss.item() * input_embeddings_batch.size()[0] # batch loss
151 |
152 |
153 | epoch_loss = running_loss / dataset_sizes['test_set']
154 |     print('test loss: {:.4f}'.format(epoch_loss))
155 |
156 |
157 | print("Predicted outputs")
158 | """ calculate corpus bleu score """
159 | weights_list = [(1.0,),(0.5,0.5),(1./3.,1./3.,1./3.),(0.25,0.25,0.25,0.25)]
160 | for weight in weights_list:
161 | corpus_bleu_score = corpus_bleu(target_tokens_list, pred_tokens_list, weights = weight)
162 | print(f'corpus BLEU-{len(list(weight))} score:', corpus_bleu_score)
163 | print()
164 | """ calculate rouge score """
165 | rouge = Rouge()
166 | rouge_scores = rouge.get_scores(pred_string_list,target_string_list, avg = True)
167 | print(rouge_scores)
168 | print()
169 | """ calculate bertscore"""
170 | from bert_score import score
171 | P, R, F1 = score(pred_string_list,target_string_list, lang='en', device="cuda:0", model_type="bert-large-uncased")
172 | print(f"bert_score P: {np.mean(np.array(P))}")
173 | print(f"bert_score R: {np.mean(np.array(R))}")
174 | print(f"bert_score F1: {np.mean(np.array(F1))}")
175 | print("*************************************")
176 |
177 | if gpt:
178 | print()
179 | print("Refined outputs with GPT4")
180 | """ calculate corpus bleu score """
181 | weights_list = [(1.0,),(0.5,0.5),(1./3.,1./3.,1./3.),(0.25,0.25,0.25,0.25)]
182 | for weight in weights_list:
183 | # print('weight:',weight)
184 | corpus_bleu_score = corpus_bleu(target_tokens_list, refine_tokens_list, weights = weight)
185 | print(f'corpus BLEU-{len(list(weight))} score:', corpus_bleu_score)
186 | print()
187 | """ calculate rouge score """
188 | rouge = Rouge()
189 | rouge_scores = rouge.get_scores(refine_string_list,target_string_list, avg = True)
190 | print(rouge_scores)
191 | print()
192 | """ calculate bertscore"""
193 | from bert_score import score
194 | P, R, F1 = score(refine_string_list,target_string_list, lang='en', device="cuda:0", model_type="bert-large-uncased")
195 | print(f"bert_score P: {np.mean(np.array(P))}")
196 | print(f"bert_score R: {np.mean(np.array(R))}")
197 | print(f"bert_score F1: {np.mean(np.array(F1))}")
198 | print("*************************************")
199 | print("*************************************")
200 |
201 |
202 | if __name__ == '__main__':
203 | ''' get args'''
204 | args = get_config('eval_decoding')
205 |
206 | ''' load training config'''
207 | training_config = json.load(open(args['config_path']))
208 |
209 | batch_size = 1
210 |
211 | subject_choice = training_config['subjects']
212 | print(f'[INFO]subjects: {subject_choice}')
213 | eeg_type_choice = training_config['eeg_type']
214 | print(f'[INFO]eeg type: {eeg_type_choice}')
215 | bands_choice = training_config['eeg_bands']
216 | print(f'[INFO]using bands: {bands_choice}')
217 |
218 |
219 |
220 | dataset_setting = 'unique_sent'
221 | task_name = training_config['task_name']
222 | model_name = training_config['model_name']
223 |
224 | output_all_results_path = f'./results_raw/{task_name}-{model_name}-all_decoding_results.txt'
225 | ''' set random seeds '''
226 | seed_val = 312
227 | np.random.seed(seed_val)
228 | torch.manual_seed(seed_val)
229 | torch.cuda.manual_seed_all(seed_val)
230 |
231 | ''' set up device '''
232 | # use cuda
233 | if torch.cuda.is_available():
234 | dev = args['cuda']
235 | else:
236 | dev = "cpu"
237 | # CUDA_VISIBLE_DEVICES=0,1
238 | device = torch.device(dev)
239 | print(f'[INFO]using device {dev}')
240 |
241 |
242 | ''' set up dataloader '''
243 | whole_dataset_dicts = []
244 | if 'task1' in task_name:
245 | dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset_wRaw.pickle'
246 | with open(dataset_path_task1, 'rb') as handle:
247 | whole_dataset_dicts.append(pickle.load(handle))
248 | if 'task2' in task_name:
249 | dataset_path_task2 = './dataset/ZuCo/task2-NR/pickle/task2-NR-dataset_wRaw.pickle'
250 | with open(dataset_path_task2, 'rb') as handle:
251 | whole_dataset_dicts.append(pickle.load(handle))
252 | if 'task3' in task_name:
253 | dataset_path_task3 = './dataset/ZuCo/task3-TSR/pickle/task3-TSR-dataset_wRaw.pickle'
254 | with open(dataset_path_task3, 'rb') as handle:
255 | whole_dataset_dicts.append(pickle.load(handle))
256 | if 'taskNRv2' in task_name:
257 | dataset_path_taskNRv2 = './dataset/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset_wRaw.pickle'
258 | with open(dataset_path_taskNRv2, 'rb') as handle:
259 | whole_dataset_dicts.append(pickle.load(handle))
260 | print()
261 |
262 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
263 |
264 | # test dataset
265 | test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting, raweeg=True)
266 |
267 | dataset_sizes = {"test_set":len(test_set)}
268 | print('[INFO]test_set size: ', len(test_set))
269 |
270 |     # Pad the raw EEG sequences and keep their true (unpadded) lengths
271 |     def pad_and_sort_batch(data_loader_batch):
272 |         """
273 |         data_loader_batch is a list of per-sample tuples produced by ZuCo_dataset.
274 |         Returns the same fields, with the raw EEG zero-padded into a single tensor and the per-word lengths added (no sorting is performed).
275 |         """
276 | input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG, input_raw_embeddings, word_contents, word_contents_attn, subject = tuple(
277 | zip(*data_loader_batch))
278 |
279 | raw_eeg = []
280 | input_raw_embeddings_lenghts = []
281 | for sentence in input_raw_embeddings:
282 | input_raw_embeddings_lenghts.append(
283 | torch.Tensor([a.size(0) for a in sentence]))
284 | raw_eeg.append(pad_sequence(
285 | sentence, batch_first=True, padding_value=0).permute(1, 0, 2))
286 |
287 | input_raw_embeddings = pad_sequence(
288 | raw_eeg, batch_first=True, padding_value=0).permute(0, 2, 1, 3)
289 |
290 | return input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG, input_raw_embeddings, input_raw_embeddings_lenghts, word_contents, word_contents_attn, subject # lengths
291 |
292 |
293 | test_dataloader = DataLoader(test_set, batch_size = 1, shuffle=False, num_workers=4,collate_fn=pad_and_sort_batch)
294 |
295 | dataloaders = {'test':test_dataloader}
296 |
297 | ''' set up model '''
298 | checkpoint_path = args['checkpoint_path']
299 | pretrained_bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
300 |
301 | if model_name == 'BrainTranslator':
302 | model = BrainTranslator(pretrained_bart, in_feature=1024, decoder_embedding_size=1024,
303 | additional_encoder_nhead=8, additional_encoder_dim_feedforward=4096)
304 |
305 |     model.load_state_dict(torch.load(checkpoint_path, map_location=device))
306 | model.to(device)
307 |
308 | criterion = nn.CrossEntropyLoss()
309 |
310 | ''' eval '''
311 | eval_model(dataloaders, device, tokenizer, criterion, model, output_all_results_path = output_all_results_path)
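
The post-processing of a greedy prediction in `eval_model` can be sketched without any model, as two small steps: cut the id list at the first end-of-sequence id, and cut the decoded string at the sentence terminator. This is a minimal sketch, not the repository's code. The helper names `truncate_at_eos` and `clean_decoded` are hypothetical, the ids are toy values, and the separator is assumed to be BART's doubled `</s>` marker.

```python
from typing import List

# Assumed BART special-token strings (illustrative constants, not read from a tokenizer).
BOS, EOS = "<s>", "</s>"


def truncate_at_eos(token_ids: List[int], eos_id: int) -> List[int]:
    """Keep the ids that come before the first EOS id (the EOS itself is dropped)."""
    kept = []
    for token_id in token_ids:
        if token_id == eos_id:
            break
        kept.append(token_id)
    return kept


def clean_decoded(text: str) -> str:
    """Cut a decoded string at the first doubled EOS marker and remove BOS markers."""
    return text.split(EOS + EOS)[0].replace(BOS, "")


if __name__ == "__main__":
    # Toy ids: 0 plays the role of BOS and 2 the role of EOS.
    print(truncate_at_eos([0, 713, 16, 2, 2, 99], eos_id=2))
    print(clean_decoded("<s>this is</s></s>noise"))
```

Truncating before scoring matters because everything after the first EOS is padding or noise from the fixed-length decoder output, and would otherwise count against BLEU and ROUGE.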
--------------------------------------------------------------------------------
/model_decoding_raw.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch.utils.data
4 |
5 | from torch.nn.utils.rnn import pack_padded_sequence
6 |
7 | def cross_entropy(preds, targets, reduction='none'):
8 | log_softmax = nn.LogSoftmax(dim=-1)
9 | loss = (-targets * log_softmax(preds)).sum(1)
10 | if reduction == "none":
11 | return loss
12 | elif reduction == "mean":
13 | return loss.mean()
14 |
15 | class ProjectionHead(nn.Module):
16 | def __init__(
17 | self,
18 | embedding_dim,
19 | projection_dim=1024,
20 | dropout=0.1
21 | ):
22 | super().__init__()
23 | self.projection = nn.Linear(embedding_dim, projection_dim)
24 | self.gelu = nn.GELU()
25 | self.fc = nn.Linear(projection_dim, projection_dim)
26 | self.dropout = nn.Dropout(dropout)
27 |
28 | def forward(self, x):
29 | projected = self.projection(x)
30 | x = self.gelu(projected)
31 | x = self.fc(x)
32 | x = self.dropout(x)
33 | x = x + projected
34 | return x
35 |
36 | class BrainTranslator(nn.Module):
37 | def __init__(self, bart, in_feature = 840, decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048):
38 | super(BrainTranslator, self).__init__()
39 |
40 | # Embedded EEG raw features
41 | self.hidden_dim=512
42 | self.feature_embedded = FeatureEmbedded(input_dim=104, hidden_dim=self.hidden_dim)
43 | self.fc = ProjectionHead(embedding_dim=in_feature,projection_dim=in_feature,dropout=0.1) #nn.Linear(in_feature, in_feature)
44 |
45 | # conv1d
46 | self.conv1d_point = nn.Conv1d(1, 64, 1, stride=1)
47 |
48 | SUBJECTS = ['ZAB', 'ZDM', 'ZDN', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW', 'ZMG', 'ZPH',
49 | 'YSD', 'YFS', 'YMD', 'YAC', 'YFR', 'YHS', 'YLS', 'YDG', 'YRH', 'YRK', 'YMS', 'YIS', 'YTL', 'YSL', 'YRP', 'YAG', 'YDR', 'YAK']
50 | self.subjects_map_id = {}
51 | for i in range(len(SUBJECTS)):
52 | self.subjects_map_id[SUBJECTS[i]]=i
53 |
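54 |         # learnable subject matrices (NOTE: kept in a plain Python list, so they are not registered as module parameters: they are absent from parameters() and state_dict(); nn.ParameterList would register them but would also change the checkpoint keys)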
55 | self.subject_matrices = []
56 | for i in range(len(SUBJECTS)):
57 | self.subject_matrices.append( nn.Parameter(torch.randn(64, 1)) )
58 |
59 | # Brain transformer encoder
60 | self.pos_embedding = nn.Parameter(torch.randn(1, 56, in_feature))
61 | self.encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead, dim_feedforward = additional_encoder_dim_feedforward, dropout=0.1, activation="gelu", batch_first=True)
62 | self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=12)
63 | self.layernorm_embedding = nn.LayerNorm(in_feature, eps=1e-05)
64 |
65 | self.brain_projection = ProjectionHead(embedding_dim=in_feature,projection_dim=1024,dropout=0.2)
66 |
67 | # BART
68 | self.bart = bart
69 |
70 | def freeze_pretrained_bart(self):
71 | for name, param in self.named_parameters():
72 | param.requires_grad = True
73 | if ('bart' in name):
74 | param.requires_grad = False
75 |
76 | def freeze_pretrained_brain(self):
77 | for name, param in self.named_parameters():
78 | param.requires_grad = False
79 | if ('bart' in name):
80 | param.requires_grad = True
81 |
82 | def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted, lenghts_words, word_contents_batch, word_contents_attn_batch, stepone, subject_batch, device, features=False):
83 | feature_embedding = self.feature_embedded(input_embeddings_batch, lenghts_words, device)
84 | if len(feature_embedding.shape)==2:
85 | feature_embedding = torch.unsqueeze(feature_embedding,0)
86 | encoded_embedding = self.fc(feature_embedding)
87 |
88 | # subject layer
89 | encoded_embedding_subject=[]
90 | for i in range(encoded_embedding.shape[0]):
91 | tmp = torch.unsqueeze(encoded_embedding[i,:,:],1)
92 | tmp = self.conv1d_point(tmp)
93 | tmp = torch.swapaxes(tmp, 1, 2)
94 | mat_subject = self.subject_matrices[self.subjects_map_id[subject_batch[i]]].to(device)
95 | tmp = torch.matmul(tmp, mat_subject)
96 | tmp = torch.squeeze(tmp)
97 | encoded_embedding_subject.append(tmp)
98 | if len(encoded_embedding_subject) == 1:
99 | encoded_embedding_subject = torch.unsqueeze(encoded_embedding_subject[0],0)
100 | else:
101 | encoded_embedding_subject = torch.stack(encoded_embedding_subject,0).to(device)
102 |
103 | brain_embedding = encoded_embedding_subject + self.pos_embedding
104 | brain_embedding = self.encoder(brain_embedding, src_key_padding_mask = input_masks_invert)
105 | brain_embedding = self.layernorm_embedding(brain_embedding)
106 |
107 | brain_embedding = self.brain_projection(brain_embedding)
108 |
109 | if stepone==True:
110 | words_embedding = self.bart.model.encoder.embed_tokens(word_contents_batch)
111 | loss = nn.MSELoss()
112 | return loss(brain_embedding, words_embedding)
113 | else:
114 | out = self.bart(inputs_embeds = brain_embedding, attention_mask = input_masks_batch, return_dict = True, labels = target_ids_batch_converted)
115 | if features==True:
116 | return out.logits, brain_embedding
117 |
118 | return out.logits
119 |
120 |
121 | class FeatureEmbedded(nn.Module):
122 | def __init__(self, input_dim=105, hidden_dim=512, num_layers=2, is_bidirectional=True):
123 | super(FeatureEmbedded, self).__init__()
124 |
125 | self.input_dim=input_dim
126 | self.hidden_dim=hidden_dim
127 | self.num_layers=num_layers
128 | self.is_bidirectional=is_bidirectional
129 |
130 |         self.lstm = nn.GRU(input_size=self.input_dim,  # note: despite the attribute name, this is a GRU
131 | hidden_size=self.hidden_dim,
132 | num_layers=self.num_layers,
133 | batch_first=True,
134 | dropout=0.2,
135 | bidirectional=self.is_bidirectional
136 | )
137 | for name, param in self.lstm.named_parameters():
138 | if 'bias' in name:
139 | nn.init.constant_(param, 0.0)
140 | elif 'weight_ih' in name:
141 | nn.init.kaiming_normal_(param)
142 | elif 'weight_hh' in name:
143 | nn.init.orthogonal_(param)
144 |
145 | def forward(self, x, lenghts, device):
146 |
147 | sentence_embedding_batch=[]
148 | for x_sentence, lenghts_sentence in zip(x,lenghts):
149 |
150 | lstm_input = pack_padded_sequence(x_sentence, lenghts_sentence.cpu().numpy(), batch_first=True, enforce_sorted=False)
151 | lstm_outs, hidden = self.lstm(lstm_input)
152 | lstm_outs, _ = nn.utils.rnn.pad_packed_sequence(lstm_outs)
153 |
154 | if not self.is_bidirectional:
155 | sentence_embedding = []
156 | for i in range(lenghts_sentence.shape[0]):
157 | sentence_embedding.append(lstm_outs[int(lenghts_sentence[i]-1),i,:])
158 | sentence_embedding=torch.stack(sentence_embedding,0) #lstm_outs[-1]
159 | else:
160 | sentence_embedding = []
161 | for i in range(lenghts_sentence.shape[0]):
162 | sentence_embedding.append(lstm_outs[int(lenghts_sentence[i]-1),i,:])
163 | sentence_embedding=torch.stack(sentence_embedding,0)
164 |
165 | sentence_embedding_batch.append(sentence_embedding)
166 |
167 | return torch.squeeze(torch.stack(sentence_embedding_batch,0)).to(device)
--------------------------------------------------------------------------------
/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamzaamrani/EEG-to-Text-Decoding/0b8cf03b040e37f9e573fced95ad47667b194557/overview.png
--------------------------------------------------------------------------------
/scripts/eval_decoding_raw.sh:
--------------------------------------------------------------------------------
1 | python3 eval_decoding_raw.py \
2 | --checkpoint_path ./checkpoints/decoding_raw_104_h/last/task1_task2_taskNRv2_finetune_BrainTranslator_skipstep1_b4_25_25_5e-05_5e-05_unique_sent_best.pt \
3 | --config_path ./config/decoding_raw/task1_task2_taskNRv2_finetune_BrainTranslator_skipstep1_b4_25_25_5e-05_5e-05_unique_sent.json \
4 | -cuda cuda:0
5 |
6 |
--------------------------------------------------------------------------------
/scripts/prepare_dataset_raw.sh:
--------------------------------------------------------------------------------
1 | echo "This script constructs .pickle files from the .mat files of the ZuCo dataset, including raw EEG signals."
2 | echo "This script also generates the ternary sentiment_labels.json file for ZuCo task1-SR v1.0 and ternary_dataset.json from the filtered StanfordSentimentTreebank"
3 | echo "Note: the sentences in ZuCo task1-SR do not overlap with sentences in filtered StanfordSentimentTreebank "
4 | echo "Note: This process can take time, please be patient..."
5 |
6 | python3 ./util/construct_dataset_mat_to_pickle_v1_withRaw.py -t task1-SR
7 | python3 ./util/construct_dataset_mat_to_pickle_v1_withRaw.py -t task2-NR
8 | python3 ./util/construct_dataset_mat_to_pickle_v1_withRaw.py -t task3-TSR
9 | python3 ./util/construct_dataset_mat_to_pickle_v2_withRaw.py
10 |
11 | python3 ./util/get_sentiment_labels.py
12 | python3 ./util/get_SST_ternary_dataset.py
13 |
--------------------------------------------------------------------------------
/scripts/train_decoding_raw.sh:
--------------------------------------------------------------------------------
1 | python3 train_decoding_raw.py --model_name BrainTranslator \
2 | --task_name task1_task2_taskNRv2 \
3 | --two_step \
4 | --pretrained \
5 | --not_load_step1_checkpoint \
6 | --num_epoch_step1 25 \
7 | --num_epoch_step2 25 \
8 | -lr1 0.00005 \
9 | -lr2 0.00005 \
10 | -b 1 \
11 | -s ./checkpoints/decoding_raw \
12 | -cuda cuda:0
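
The checkpoint file names used by these scripts appear to follow the `save_name` pattern built in `train_decoding_raw.py` from the command-line arguments. The sketch below reproduces that pattern so a checkpoint path can be predicted before training. The helper `build_save_name` is hypothetical (it is not part of the repository), and it assumes the learning rates are parsed as Python floats, which render as `5e-05` inside an f-string.

```python
def build_save_name(task_name, model_name, batch_size, epochs_step1, epochs_step2,
                    lr_step1, lr_step2, skip_step_one, dataset_setting="unique_sent"):
    """Mirror the save_name f-string pattern (hypothetical helper, for illustration only)."""
    mode = "skipstep1" if skip_step_one else "2steptraining"
    return (f"{task_name}_finetune_{model_name}_{mode}_b{batch_size}_"
            f"{epochs_step1}_{epochs_step2}_{lr_step1}_{lr_step2}_{dataset_setting}")


if __name__ == "__main__":
    name = build_save_name("task1_task2_taskNRv2", "BrainTranslator",
                           batch_size=4, epochs_step1=25, epochs_step2=25,
                           lr_step1=0.00005, lr_step2=0.00005, skip_step_one=True)
    print(name)
    # The training script writes '<save_path>/best/<name>.pt' and '<save_path>/last/<name>.pt'.
    print(f"./checkpoints/decoding_raw/best/{name}.pt")
```

Because the batch size is part of the name, a run started with `-b 1` produces a `_b1_` file, while the path in `eval_decoding_raw.sh` refers to a `_b4_` file with a `_best` suffix in a different directory. The evaluation path will therefore likely need to be adjusted to match the checkpoint actually produced.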
--------------------------------------------------------------------------------
/train_decoding_raw.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 | from torch.optim import lr_scheduler
8 | from torch.utils.data import DataLoader
9 | import pickle
10 | import json
11 | from glob import glob
12 | import time
13 | from tqdm import tqdm
14 | from transformers import BartTokenizer, BartForConditionalGeneration, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification
15 | from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
16 |
17 | from data_raw import ZuCo_dataset
18 | from model_decoding_raw import BrainTranslator
19 | from config import get_config
20 |
21 | from torch.nn.utils.rnn import pad_sequence
22 |
23 | from nltk.translate.bleu_score import corpus_bleu
24 | from rouge import Rouge
25 | from bert_score import score
26 |
27 | import warnings
28 | warnings.filterwarnings('ignore')
29 | from transformers import logging
30 | logging.set_verbosity_error()
31 | torch.autograd.set_detect_anomaly(True)
32 |
33 | from torch.utils.tensorboard import SummaryWriter
34 | LOG_DIR = "runs_h"
35 | train_writer = SummaryWriter(os.path.join(LOG_DIR, "train"))
36 | val_writer = SummaryWriter(os.path.join(LOG_DIR, "train_full"))
37 | dev_writer = SummaryWriter(os.path.join(LOG_DIR, "dev_full"))
38 |
39 |
40 | SUBJECTS = ['ZAB', 'ZDM', 'ZDN', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW', 'ZMG', 'ZPH',
41 | 'YSD', 'YFS', 'YMD', 'YAC', 'YFR', 'YHS', 'YLS', 'YDG', 'YRH', 'YRK', 'YMS', 'YIS', 'YTL', 'YSL', 'YRP', 'YAG', 'YDR', 'YAK']
42 |
43 | def train_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best='./checkpoints/decoding_raw/best/temp_decoding.pt', checkpoint_path_last='./checkpoints/decoding_raw/last/temp_decoding.pt', stepone=False):
44 | since = time.time()
45 |
46 | best_loss = 100000000000
47 |
48 | train_losses = []
49 | val_losses = []
50 |
51 | index_plot= 0
52 | index_plot_dev=0
53 |
54 | for epoch in range(num_epochs):
55 | print('Epoch {}/{}'.format(epoch, num_epochs - 1))
56 |         print(f"lr: {scheduler.get_last_lr()}")
57 | print('-' * 10)
58 |
59 | # Each epoch has a training and validation phase
60 | for phase in ['train', 'dev', 'test']:
61 | if phase == 'train':
62 | model.train() # Set model to training mode
63 | else:
64 | model.eval() # Set model to evaluate mode
65 |
66 | running_loss = 0.0
67 |
68 | if phase == 'test':
69 | target_tokens_list = []
70 | target_string_list = []
71 | pred_tokens_list = []
72 | pred_string_list = []
73 |
74 | # Iterate over data.
75 | with tqdm(dataloaders[phase], unit="batch") as tepoch:
76 | for batch_idx, (_, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG, input_raw_embeddings, input_raw_embeddings_lengths, word_contents, word_contents_attn, subject_batch) in enumerate(tepoch):
77 |
78 | # load in batch
79 | #input_embeddings_batch = torch.stack(input_embeddings_fbcsp).to(device).float()
80 | input_embeddings_batch = input_raw_embeddings.float().to(device)
81 | input_embeddings_lengths_batch = torch.stack([torch.tensor(a.clone().detach()) for a in input_raw_embeddings_lengths], 0).to(device)
82 | input_masks_batch = torch.stack(input_masks, 0).to(device)
83 | input_mask_invert_batch = torch.stack(input_mask_invert, 0).to(device)
84 | target_ids_batch = torch.stack(target_ids, 0).to(device)
85 | word_contents_batch = torch.stack(word_contents, 0).to(device)
86 | word_contents_attn_batch = torch.stack(word_contents_attn, 0).to(device)
87 |
88 | subject_batch = np.array(subject_batch)
89 |
90 | target_string_list_bertscore = []
91 |
92 | if phase == 'test' and stepone==False:
93 | target_tokens = tokenizer.convert_ids_to_tokens(
94 | target_ids_batch[0].tolist(), skip_special_tokens=True)
95 | target_string = tokenizer.decode(
96 | target_ids_batch[0], skip_special_tokens=True)
97 | # add to list for later calculate metrics
98 | target_tokens_list.append([target_tokens])
99 | target_string_list.append(target_string)
100 |
101 | # zero the parameter gradients
102 | optimizer.zero_grad()
103 |
104 | seq2seqLMoutput = model(
105 | input_embeddings_batch, input_masks_batch, input_mask_invert_batch, target_ids_batch, input_embeddings_lengths_batch, word_contents_batch, word_contents_attn_batch, stepone, subject_batch, device)
106 |
107 | """replace padding ids in target_ids with -100"""
108 | target_ids_batch[target_ids_batch ==
109 | tokenizer.pad_token_id] = -100
110 |
111 | """calculate loss"""
112 | if stepone==True:
113 | loss = seq2seqLMoutput
114 | else:
115 | loss = criterion(seq2seqLMoutput.permute(0, 2, 1), target_ids_batch.long())
116 |
117 | if phase == 'test' and stepone==False:
118 | logits = seq2seqLMoutput
119 | probs = logits[0].softmax(dim=1)
120 | values, predictions = probs.topk(1)
121 | predictions = torch.squeeze(predictions)
122 |                         predicted_string = tokenizer.decode(predictions).split(
123 |                             '</s></s>')[0].replace('<s>', '')
124 |
125 | # convert to int list
126 | predictions = predictions.tolist()
127 | truncated_prediction = []
128 | for t in predictions:
129 | if t != tokenizer.eos_token_id:
130 | truncated_prediction.append(t)
131 | else:
132 | break
133 | pred_tokens = tokenizer.convert_ids_to_tokens(
134 | truncated_prediction, skip_special_tokens=True)
135 | pred_tokens_list.append(pred_tokens)
136 | pred_string_list.append(predicted_string)
137 |
138 | # backward + optimize only if in training phase
139 | if phase == 'train':
140 | loss.backward()
141 | optimizer.step()
142 |
143 | # statistics
144 | running_loss += loss.item() * \
145 | input_embeddings_batch.size()[0] # batch loss
146 |
147 |                     tepoch.set_postfix(loss=loss.item(), lr=scheduler.get_last_lr())
148 |
149 | if phase == 'train':
150 | val_writer.add_scalar("train_full", loss.item(), index_plot) #(epoch+1)*batch_idx)
151 | index_plot=index_plot+1
152 | if phase == 'dev':
153 | dev_writer.add_scalar("dev_full", loss.item(), index_plot_dev) #(epoch+1)*batch_idx)
154 | index_plot_dev=index_plot_dev+1
155 |
156 | if phase == 'train':
157 | scheduler.step()
158 |
159 | epoch_loss = running_loss / dataset_sizes[phase]
160 |
161 |             if phase == 'train':
162 |                 # save the latest weights after every training epoch
163 |                 torch.save(model.state_dict(), checkpoint_path_last)
164 |                 # (epoch losses are appended once, in the block below;
165 |                 # appending them here as well would record every loss twice)
166 |
167 | if phase == 'train':
168 | train_losses.append(epoch_loss)
169 | train_epoch_loss = epoch_loss
170 | train_writer.add_scalar("train", epoch_loss, epoch)
171 | elif phase == 'dev':
172 | val_losses.append(epoch_loss)
173 | train_writer.add_scalar("val", epoch_loss, epoch)
174 |
175 | train_writer.add_scalars('loss train/val', {
176 | 'train': train_epoch_loss,
177 | 'val': epoch_loss,
178 | }, epoch)
179 |
180 | print('{} Loss: {:.4f}'.format(phase, epoch_loss))
181 |
182 | # deep copy the model
183 | if phase == 'dev' and epoch_loss < best_loss:
184 | best_loss = epoch_loss
185 | '''save checkpoint'''
186 | torch.save(model.state_dict(), checkpoint_path_best)
187 | print(f'update best on dev checkpoint: {checkpoint_path_best}')
188 |
189 | if phase == 'test' and stepone==False:
190 | print("Evaluation on test")
191 | try:
192 | """ calculate corpus bleu score """
193 | weights_list = [
194 | (1.0,), (0.5, 0.5), (1./3., 1./3., 1./3.), (0.25, 0.25, 0.25, 0.25)]
195 | for weight in weights_list:
196 | # print('weight:',weight)
197 | corpus_bleu_score = corpus_bleu(
198 | target_tokens_list, pred_tokens_list, weights=weight)
199 | print(
200 | f'corpus BLEU-{len(list(weight))} score:', corpus_bleu_score)
201 |
202 | """ calculate rouge score """
203 | rouge = Rouge()
204 | rouge_scores = rouge.get_scores( pred_string_list, target_string_list, avg=True, ignore_empty=True)
205 | print(rouge_scores)
206 | """ calculate bertscore"""
207 | P, R, F1 = score(pred_string_list,target_string_list, lang='en', device="cuda:0", model_type="bert-large-uncased")
208 | print(f"bert_score P: {np.mean(np.array(P))}")
209 | print(f"bert_score R: {np.mean(np.array(R))}")
210 | print(f"bert_score F1: {np.mean(np.array(F1))}")
211 |                 except Exception as e:
212 |                     print(f"failed: {e}")
213 |
214 | print()
215 |
216 | print(f"Train losses: {train_losses}")
217 | print(f"Val losses: {val_losses}")
218 |
219 | time_elapsed = time.time() - since
220 | print('Training complete in {:.0f}m {:.0f}s'.format(
221 | time_elapsed // 60, time_elapsed % 60))
222 | print('Best val loss: {:.4f}'.format(best_loss))
223 | torch.save(model.state_dict(), checkpoint_path_last)
224 | print(f'update last checkpoint: {checkpoint_path_last}')
225 |
226 | return model
227 |
228 |
229 | def show_require_grad_layers(model):
230 | print()
231 | print(' require_grad layers:')
232 | # sanity check
233 | for name, param in model.named_parameters():
234 | if param.requires_grad:
235 | print(' ', name)
236 |
237 |
238 | if __name__ == '__main__':
239 | args = get_config('train_decoding')
240 |
241 | ''' config param'''
242 | dataset_setting = 'unique_sent'
243 |
244 | num_epochs_step1 = args['num_epoch_step1']
245 | num_epochs_step2 = args['num_epoch_step2']
246 | step1_lr = args['learning_rate_step1']
247 | step2_lr = args['learning_rate_step2']
248 |
249 | batch_size = args['batch_size']
250 |
251 | model_name = args['model_name']
252 | task_name = args['task_name']
253 |
254 | save_path = args['save_path']
255 |
256 | skip_step_one = args['skip_step_one']
257 | load_step1_checkpoint = args['load_step1_checkpoint']
258 | use_random_init = False
259 |
260 | if use_random_init and skip_step_one:
261 | step2_lr = 5*1e-4
262 |
263 | print(f'[INFO]using model: {model_name}')
264 | print(f'[INFO]using use_random_init: {use_random_init}')
265 |
266 | if skip_step_one:
267 | save_name = f'{task_name}_finetune_{model_name}_skipstep1_b{batch_size}_{num_epochs_step1}_{num_epochs_step2}_{step1_lr}_{step2_lr}_{dataset_setting}'
268 | else:
269 | save_name = f'{task_name}_finetune_{model_name}_2steptraining_b{batch_size}_{num_epochs_step1}_{num_epochs_step2}_{step1_lr}_{step2_lr}_{dataset_setting}'
270 |
271 | if use_random_init:
272 | save_name = 'randinit_' + save_name
273 |
274 | output_checkpoint_name_best = save_path + f'/best/{save_name}.pt'
275 | output_checkpoint_name_last = save_path + f'/last/{save_name}.pt'
276 |
277 | subject_choice = args['subjects']
278 | print(f'[INFO]using subjects {subject_choice}')
279 | eeg_type_choice = args['eeg_type']
280 | print(f'[INFO]eeg type {eeg_type_choice}')
281 | bands_choice = args['eeg_bands']
282 | print(f'[INFO]using bands {bands_choice}')
283 |
284 | ''' set random seeds '''
285 | seed_val = 312
286 | np.random.seed(seed_val)
287 | torch.manual_seed(seed_val)
288 | torch.cuda.manual_seed_all(seed_val)
289 |
290 | ''' set up device '''
291 | # use cuda
292 | if torch.cuda.is_available():
293 | # dev = "cuda:3"
294 | dev = args['cuda']
295 | else:
296 | dev = "cpu"
297 | device = torch.device(dev)
298 | print(f'[INFO]using device {dev}')
299 | print()
300 |
301 | ''' set up dataloader '''
302 | whole_dataset_dicts = []
303 | if 'task1' in task_name:
304 | dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset_wRaw.pickle'
305 | with open(dataset_path_task1, 'rb') as handle:
306 | whole_dataset_dicts.append(pickle.load(handle))
307 | if 'task2' in task_name:
308 | dataset_path_task2 = './dataset/ZuCo/task2-NR/pickle/task2-NR-dataset_wRaw.pickle'
309 | with open(dataset_path_task2, 'rb') as handle:
310 | whole_dataset_dicts.append(pickle.load(handle))
311 | if 'task3' in task_name:
312 | dataset_path_task3 = './dataset/ZuCo/task3-TSR/pickle/task3-TSR-dataset_wRaw.pickle'
313 | with open(dataset_path_task3, 'rb') as handle:
314 | whole_dataset_dicts.append(pickle.load(handle))
315 | if 'taskNRv2' in task_name:
316 | dataset_path_taskNRv2 = './dataset/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset_wRaw.pickle'
317 | with open(dataset_path_taskNRv2, 'rb') as handle:
318 | whole_dataset_dicts.append(pickle.load(handle))
319 |
320 | print()
321 |
322 | """save config"""
323 | with open(f'./config/decoding_raw/{save_name}.json', 'w') as out_config:
324 | json.dump(args, out_config, indent=4)
325 |
326 | if model_name in ['BrainTranslator', 'BrainTranslatorNaive']:
327 | tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
328 |
329 | # train dataset
330 | train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject=subject_choice,
331 | eeg_type=eeg_type_choice, bands=bands_choice, setting=dataset_setting, raweeg=True)
332 | # dev dataset
333 | dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject=subject_choice,
334 | eeg_type=eeg_type_choice, bands=bands_choice, setting=dataset_setting, raweeg=True)
335 | # test dataset
336 | test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject=subject_choice,
337 | eeg_type=eeg_type_choice, bands=bands_choice, setting=dataset_setting, raweeg=True)
338 |
339 | dataset_sizes = {'train': len(train_set), 'dev': len(
340 | dev_set), 'test': len(test_set)}
341 | print('[INFO]train_set size: ', len(train_set))
342 | print('[INFO]dev_set size: ', len(dev_set))
343 | print('[INFO]test_set size: ', len(test_set))
344 |
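345 | # Collate function: pads the raw EEG of every word to a common length and records the original lengths
346 | def pad_and_sort_batch(data_loader_batch):
347 | """
348 | data_loader_batch is a list of per-sample tuples as returned by ZuCo_dataset.
349 | Returns the batch fields unchanged, except that the raw EEG is zero-padded into a single tensor and the unpadded lengths are added (no sorting is applied).
350 | """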
351 | input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG, input_raw_embeddings, word_contents, word_contents_attn, subject = tuple(
352 | zip(*data_loader_batch))
353 |
354 | raw_eeg = []
355 | input_raw_embeddings_lenghts = []
356 | for sentence in input_raw_embeddings:
357 | input_raw_embeddings_lenghts.append(
358 | torch.Tensor([a.size(0) for a in sentence]))
359 | raw_eeg.append(pad_sequence(
360 | sentence, batch_first=True, padding_value=0).permute(1, 0, 2))
361 |
362 | input_raw_embeddings = pad_sequence(
363 | raw_eeg, batch_first=True, padding_value=0).permute(0, 2, 1, 3)
364 |
365 | return input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG, input_raw_embeddings, input_raw_embeddings_lenghts, word_contents, word_contents_attn, subject # lengths
366 |
367 |
368 | # train dataloader
369 | train_dataloader = DataLoader(
370 | train_set, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=pad_and_sort_batch) # 4
371 | # dev dataloader
372 | val_dataloader = DataLoader(
373 | dev_set, batch_size=1, shuffle=False, num_workers=0, collate_fn=pad_and_sort_batch) # 4
374 | # test dataloader
375 | test_dataloader = DataLoader(
376 | test_set, batch_size=1, shuffle=False, num_workers=0, collate_fn=pad_and_sort_batch) # 4
377 | # dataloaders
378 | dataloaders = {'train': train_dataloader,
379 | 'dev': val_dataloader, 'test': test_dataloader}
380 |
381 | ''' set up model '''
382 | if model_name == 'BrainTranslator':
383 | pretrained = BartForConditionalGeneration.from_pretrained(
384 | 'facebook/bart-large')
385 |
386 | model = BrainTranslator(pretrained, in_feature=1024, decoder_embedding_size=1024,
387 | additional_encoder_nhead=8, additional_encoder_dim_feedforward=4096)
388 |
389 | model.to(device)
390 |
391 | ''' training loop '''
392 |
393 | ######################################################
394 | '''step one training'''
395 | ######################################################
396 |
397 | # closely follow BART paper
398 | if model_name in ['BrainTranslator']:
399 | for name, param in model.named_parameters():
400 | if param.requires_grad and 'pretrained' in name:
401 | if ('shared' in name) or ('embed_positions' in name) or ('encoder.layers.0' in name):
402 | continue
403 | else:
404 | param.requires_grad = False
405 |
406 | if skip_step_one:
407 | if load_step1_checkpoint:
408 | stepone_checkpoint = 'path_to_step_1_checkpoint.pt'
409 | print(f'skip step one, load checkpoint: {stepone_checkpoint}')
410 | model.load_state_dict(torch.load(stepone_checkpoint))
411 | else:
412 | print('skip step one, start from scratch at step two')
413 | else:
414 | model.to(device)
415 |
416 | ''' set up optimizer and scheduler'''
417 | optimizer_step1 = optim.SGD(filter(
418 | lambda p: p.requires_grad, model.parameters()), lr=step1_lr, momentum=0.9)
419 |
420 | exp_lr_scheduler_step1 = lr_scheduler.CyclicLR(optimizer_step1,
421 | base_lr = step1_lr, # Initial learning rate which is the lower boundary in the cycle for each parameter group
422 | max_lr = 5e-3, # Upper learning rate boundaries in the cycle for each parameter group
423 | mode = "triangular2") #triangular2
424 |
425 | ''' set up loss function '''
426 | criterion = nn.MSELoss()
427 | model.freeze_pretrained_bart()
428 |
429 | print('=== start Step1 training ... ===')
430 | # print training layers
431 | show_require_grad_layers(model)
432 | # return best loss model from step1 training
433 | model = train_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epochs_step1,
434 | checkpoint_path_best=output_checkpoint_name_best, checkpoint_path_last=output_checkpoint_name_last, stepone=True)
435 |
436 | train_writer.flush()
437 | train_writer.close()
438 | val_writer.flush()
439 | val_writer.close()
440 | dev_writer.flush()
441 | dev_writer.close()
442 |
443 | ######################################################
444 | '''step two training'''
445 | ######################################################
446 |
447 | #model.load_state_dict(torch.load("./checkpoints/decoding_raw_104_h/last/task1_task2_taskNRv2_finetune_BrainTranslator_skipstep1_b4_25_100_5e-05_5e-05_unique_sent.pt"))
448 | model.freeze_pretrained_brain()
449 |
450 | ''' set up optimizer and scheduler'''
451 | optimizer_step2 = optim.SGD(filter(
452 | lambda p: p.requires_grad, model.parameters()), lr=step2_lr, momentum=0.9)
453 |
454 | exp_lr_scheduler_step2 = lr_scheduler.CyclicLR(optimizer_step2,
455 | base_lr = 0.0000005, # Initial learning rate which is the lower boundary in the cycle for each parameter group
456 | max_lr = 0.00005, # Upper learning rate boundaries in the cycle for each parameter group
457 | mode = "triangular2") #triangular2
458 |
459 | ''' set up loss function '''
460 | criterion = nn.CrossEntropyLoss()
461 |
462 | print()
463 | print('=== start Step2 training ... ===')
464 | # print training layers
465 | show_require_grad_layers(model)
466 |
467 | model.to(device)
468 |
469 | '''main loop'''
470 | trained_model = train_model(dataloaders, device, model, criterion, optimizer_step2, exp_lr_scheduler_step2, num_epochs=num_epochs_step2,
471 | checkpoint_path_best=output_checkpoint_name_best, checkpoint_path_last=output_checkpoint_name_last, stepone=False)
472 |
473 | train_writer.flush()
474 | train_writer.close()
475 | val_writer.flush()
476 | val_writer.close()
477 | dev_writer.flush()
478 | dev_writer.close()
479 |
480 |
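The step-two scheduler above uses `CyclicLR` in `triangular2` mode with `base_lr = 0.0000005` and `max_lr = 0.00005`. The following is a minimal pure-Python sketch of how such a schedule evolves, not the library's implementation. It assumes the scheduler is stepped once per batch and that the half-cycle length is 2000 iterations (PyTorch's default `step_size_up`, since the script does not override it). The function name `triangular2_lr` is hypothetical.

```python
import math


def triangular2_lr(iteration, base_lr=5e-7, max_lr=5e-5, step_size=2000):
    """Learning rate of a 'triangular2' cyclic schedule at a given iteration.

    step_size is the number of iterations in half a cycle (assumed 2000 here).
    """
    # Index of the current cycle, starting at 1.
    cycle = math.floor(1 + iteration / (2 * step_size))
    # Distance from the peak of the current cycle, in units of step_size (0 = peak, 1 = trough).
    x = abs(iteration / step_size - 2 * cycle + 1)
    # 'triangular2' halves the amplitude after every completed cycle.
    scale = 1.0 / (2 ** (cycle - 1))
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x) * scale


if __name__ == "__main__":
    for it in (0, 1000, 2000, 4000, 6000):
        print(it, triangular2_lr(it))
```

Under these assumptions the rate climbs from `base_lr` to `max_lr` over the first 2000 batches, returns to `base_lr` by batch 4000, and the next peak only reaches the midpoint between the two bounds.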
--------------------------------------------------------------------------------
/util/__pycache__/data_loading_helpers_modified.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamzaamrani/EEG-to-Text-Decoding/0b8cf03b040e37f9e573fced95ad47667b194557/util/__pycache__/data_loading_helpers_modified.cpython-38.pyc
--------------------------------------------------------------------------------
/util/construct_dataset_mat_to_pickle_v1_withRaw.py:
--------------------------------------------------------------------------------
1 | import scipy.io as io
2 | import h5py
3 | import os
4 | import json
5 | from glob import glob
6 | from tqdm import tqdm
7 | import numpy as np
8 | import pickle
9 | import argparse
10 |
11 | parser = argparse.ArgumentParser(description='Specify task name for converting ZuCo v1.0 Mat file to Pickle')
12 | parser.add_argument('-t', '--task_name', help='name of the task in /dataset/ZuCo, choose from {task1-SR,task2-NR,task3-TSR}', required=True)
13 | args = vars(parser.parse_args())
14 |
15 |
16 | """config"""
17 | version = 'v1' # 'old'
18 | # version = 'v2' # 'new'
19 |
20 | task_name = args['task_name']
21 | # task_name = 'task1-SR'
22 | # task_name = 'task2-NR'
23 | # task_name = 'task3-TSR'
24 |
25 |
26 | print('##############################')
27 | print(f'start processing ZuCo {task_name}...')
28 |
29 |
30 | if version == 'v1':
31 | # old version
32 | input_mat_files_dir = f'./dataset/ZuCo/{task_name}/Matlab_files'
33 | elif version == 'v2':
34 | # new version, mat73
35 | input_mat_files_dir = f'./dataset/ZuCo/{task_name}/Matlab_files'
36 |
37 | output_dir = f'./dataset/ZuCo/{task_name}/pickle'
38 | if not os.path.exists(output_dir):
39 | os.makedirs(output_dir)
40 |
41 | """load files"""
42 | mat_files = glob(os.path.join(input_mat_files_dir,'*.mat'))
43 | mat_files = sorted(mat_files)
44 |
45 | if len(mat_files) == 0:
46 | print(f'No mat files found for {task_name}')
47 | quit()
48 |
49 | dataset_dict = {}
50 | for mat_file in tqdm(mat_files):
51 | subject_name = os.path.basename(mat_file).split('_')[0].replace('results','').strip()
52 | dataset_dict[subject_name] = []
53 |
54 | if version == 'v1':
55 | matdata = io.loadmat(mat_file, squeeze_me=True, struct_as_record=False)['sentenceData']
56 | elif version == 'v2':
57 | matdata = h5py.File(mat_file,'r')
58 | print(matdata)
59 |
60 | for e,sent in enumerate(matdata):
61 | word_data = sent.word
62 | if not isinstance(word_data, float):
63 | # sentence level:
64 | sent_obj = {'content':sent.content}
65 | sent_obj['sentence_level_EEG'] = {'mean_t1':sent.mean_t1, 'mean_t2':sent.mean_t2, 'mean_a1':sent.mean_a1, 'mean_a2':sent.mean_a2, 'mean_b1':sent.mean_b1, 'mean_b2':sent.mean_b2, 'mean_g1':sent.mean_g1, 'mean_g2':sent.mean_g2}
66 |
67 | if task_name == 'task1-SR':
68 | sent_obj['answer_EEG'] = {'answer_mean_t1':sent.answer_mean_t1, 'answer_mean_t2':sent.answer_mean_t2, 'answer_mean_a1':sent.answer_mean_a1, 'answer_mean_a2':sent.answer_mean_a2, 'answer_mean_b1':sent.answer_mean_b1, 'answer_mean_b2':sent.answer_mean_b2, 'answer_mean_g1':sent.answer_mean_g1, 'answer_mean_g2':sent.answer_mean_g2}
69 |
70 | # word level:
71 | sent_obj['word'] = []
72 |
73 | word_tokens_has_fixation = []
74 | word_tokens_with_mask = []
75 | word_tokens_all = []
76 | print(sent.content)
77 | for word in word_data:
78 | word_obj = {'content':word.content}
79 | word_tokens_all.append(word.content)
80 | # TODO: add more version of word level eeg: SFD, GPT
81 | word_obj['nFixations'] = word.nFixations
82 | if word.nFixations > 0:
83 | word_obj['word_level_EEG'] = {'FFD':{'FFD_t1':word.FFD_t1, 'FFD_t2':word.FFD_t2, 'FFD_a1':word.FFD_a1, 'FFD_a2':word.FFD_a2, 'FFD_b1':word.FFD_b1, 'FFD_b2':word.FFD_b2, 'FFD_g1':word.FFD_g1, 'FFD_g2':word.FFD_g2}}
84 | word_obj['word_level_EEG']['TRT'] = {'TRT_t1':word.TRT_t1, 'TRT_t2':word.TRT_t2, 'TRT_a1':word.TRT_a1, 'TRT_a2':word.TRT_a2, 'TRT_b1':word.TRT_b1, 'TRT_b2':word.TRT_b2, 'TRT_g1':word.TRT_g1, 'TRT_g2':word.TRT_g2}
85 | word_obj['word_level_EEG']['GD'] = {'GD_t1':word.GD_t1, 'GD_t2':word.GD_t2, 'GD_a1':word.GD_a1, 'GD_a2':word.GD_a2, 'GD_b1':word.GD_b1, 'GD_b2':word.GD_b2, 'GD_g1':word.GD_g1, 'GD_g2':word.GD_g2}
86 | sent_obj['word'].append(word_obj)
87 | word_tokens_has_fixation.append(word.content)
88 | word_tokens_with_mask.append(word.content)
89 | #print(word.rawEEG)
90 | if isinstance(word.rawEEG, np.ndarray):
91 | """print(e)
92 | print(f"content: {word.content}")
93 | print(f"shape: {word.rawEEG.shape}")
94 | print(f"size: {word.rawEEG.size}")
95 | print(f"ndim: {word.rawEEG.ndim}")
96 | print("nan:",np.isnan(word.rawEEG[0]).any())"""
97 | if not np.isnan(word.rawEEG[0]).any():
98 | if word.rawEEG[0].ndim == 2:
99 |
100 | word_obj['rawEEG']=[np.float32(arr.transpose()) for arr in word.rawEEG if type(arr)==np.ndarray]
101 | elif word.rawEEG[0].ndim == 1:
102 |
103 | word_obj['rawEEG']=[word.rawEEG.transpose()]
104 | else:
105 | word_obj['rawEEG']=[]
106 | else:
107 | word_obj['rawEEG']=[]
108 | """ print("len:",len(word_obj['rawEEG']))
109 | print("\n")"""
110 |
111 | #print(word_obj['rawEEG'])
112 | else:
113 | word_tokens_with_mask.append('[MASK]')
114 | # if a word has no fixation, use sentence level feature
115 | # word_obj['word_level_EEG'] = {'FFD':{'FFD_t1':sent.mean_t1, 'FFD_t2':sent.mean_t2, 'FFD_a1':sent.mean_a1, 'FFD_a2':sent.mean_a2, 'FFD_b1':sent.mean_b1, 'FFD_b2':sent.mean_b2, 'FFD_g1':sent.mean_g1, 'FFD_g2':sent.mean_g2}}
116 | # word_obj['word_level_EEG']['TRT'] = {'TRT_t1':sent.mean_t1, 'TRT_t2':sent.mean_t2, 'TRT_a1':sent.mean_a1, 'TRT_a2':sent.mean_a2, 'TRT_b1':sent.mean_b1, 'TRT_b2':sent.mean_b2, 'TRT_g1':sent.mean_g1, 'TRT_g2':sent.mean_g2}
117 |
118 | # NOTE:if a word has no fixation, simply skip it
119 | continue
120 |
121 | sent_obj['word_tokens_has_fixation'] = word_tokens_has_fixation
122 | sent_obj['word_tokens_with_mask'] = word_tokens_with_mask
123 | sent_obj['word_tokens_all'] = word_tokens_all
124 |
125 | dataset_dict[subject_name].append(sent_obj)
126 |
127 | else:
128 | print(f'missing sent: subj:{subject_name} content:{sent.content}, append None')
129 | dataset_dict[subject_name].append(None)
130 |
131 | continue
132 | # print(dataset_dict.keys())
133 | # print(dataset_dict[subject_name][0].keys())
134 | # print(dataset_dict[subject_name][0]['content'])
135 | # print(dataset_dict[subject_name][0]['word'][0].keys())
136 | # print(dataset_dict[subject_name][0]['word'][0]['word_level_EEG']['FFD'])
137 |
138 | """output"""
139 | output_name = f'{task_name}-dataset_wRaw.pickle'
140 | # with open(os.path.join(output_dir,'task1-SR-dataset.json'), 'w') as out:
141 | # json.dump(dataset_dict,out,indent = 4)
142 |
143 | with open(os.path.join(output_dir,output_name), 'wb') as handle:
144 | pickle.dump(dataset_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
145 | print('write to:', os.path.join(output_dir,output_name))
146 |
147 |
148 | """sanity check"""
149 | # check dataset
150 | with open(os.path.join(output_dir,output_name), 'rb') as handle:
151 | whole_dataset = pickle.load(handle)
152 | print('subjects:', whole_dataset.keys())
153 |
154 | if version == 'v1':
155 | print('num of sent:', len(whole_dataset['ZAB']))
156 | print()
157 |
158 |
159 |
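The three token lists built per sentence in the script above differ only in how words without fixations are treated. The sketch below reproduces that bookkeeping on plain Python data, assuming each word is reduced to a `(content, n_fixations)` pair. The helper name `build_token_lists` and the sample sentence are illustrative and not part of the repository.

```python
def build_token_lists(words):
    """Build the three token lists for one sentence.

    words: list of (content, n_fixations) pairs, in reading order.
    Returns (has_fixation, with_mask, all_tokens).
    """
    has_fixation, with_mask, all_tokens = [], [], []
    for content, n_fixations in words:
        # Every word is recorded, fixated or not.
        all_tokens.append(content)
        if n_fixations > 0:
            has_fixation.append(content)
            with_mask.append(content)
        else:
            # Words that were never fixated carry no EEG and are masked.
            with_mask.append('[MASK]')
    return has_fixation, with_mask, all_tokens


if __name__ == "__main__":
    sentence = [('The', 1), ('film', 2), ('was', 0), ('great', 3)]
    print(build_token_lists(sentence))
```

Only the words in `has_fixation` end up with a `word_level_EEG` entry in the pickle, so the lengths of `sent_obj['word']` and `word_tokens_has_fixation` match.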
--------------------------------------------------------------------------------
/util/construct_dataset_mat_to_pickle_v2_withRaw.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import h5py
4 | import data_loading_helpers_modified as dh
5 | from glob import glob
6 | from tqdm import tqdm
7 | import pickle
8 |
9 | print("modified version of construct_dataset_mat_to_pickle_v2.py to include raw EEG data, fixed")
10 |
11 | task = "NR"
12 |
13 | rootdir = "./dataset/ZuCo/task2-NR-2.0/Matlab_files/"
14 |
15 | print('##############################')
16 | print(f'start processing ZuCo task2-NR-2.0...')
17 |
18 | dataset_dict = {}
19 |
20 | for file in tqdm(os.listdir(rootdir)):
21 | if file.endswith(task+".mat"):
22 |
23 | file_name = rootdir + file
24 |
25 | # print('file name:', file_name)
26 | subject = file_name.split("ts")[1].split("_")[0]
27 | # print('subject: ', subject)
28 |
29 | # exclude YMH due to incomplete data because of dyslexia
30 | if subject != 'YMH':
31 | assert subject not in dataset_dict
32 | dataset_dict[subject] = []
33 |
34 | f = h5py.File(file_name,'r')
35 | # print('keys in f:', list(f.keys()))
36 | sentence_data = f['sentenceData']
37 | # print('keys in sentence_data:', list(sentence_data.keys()))
38 |
39 | # sent level eeg
40 | # mean_t1 = np.squeeze(f[sentence_data['mean_t1'][0][0]][()])
41 | mean_t1_objs = sentence_data['mean_t1']
42 | mean_t2_objs = sentence_data['mean_t2']
43 | mean_a1_objs = sentence_data['mean_a1']
44 | mean_a2_objs = sentence_data['mean_a2']
45 | mean_b1_objs = sentence_data['mean_b1']
46 | mean_b2_objs = sentence_data['mean_b2']
47 | mean_g1_objs = sentence_data['mean_g1']
48 | mean_g2_objs = sentence_data['mean_g2']
49 |
50 | rawData = sentence_data['rawData']
51 | contentData = sentence_data['content']
52 | # print('contentData shape:', contentData.shape, 'dtype:', contentData.dtype)
53 | omissionR = sentence_data['omissionRate']
54 | wordData = sentence_data['word']
55 |
56 |
57 | for idx in range(len(rawData)):
58 | # get sentence string
59 | obj_reference_content = contentData[idx][0]
60 | sent_string = dh.load_matlab_string(f[obj_reference_content])
61 | # print('sentence string:', sent_string)
62 |
63 | sent_obj = {'content':sent_string}
64 |
65 | # get sentence level EEG
66 | sent_obj['sentence_level_EEG'] = {
67 | 'mean_t1':np.squeeze(f[mean_t1_objs[idx][0]][()]),
68 | 'mean_t2':np.squeeze(f[mean_t2_objs[idx][0]][()]),
69 | 'mean_a1':np.squeeze(f[mean_a1_objs[idx][0]][()]),
70 | 'mean_a2':np.squeeze(f[mean_a2_objs[idx][0]][()]),
71 | 'mean_b1':np.squeeze(f[mean_b1_objs[idx][0]][()]),
72 | 'mean_b2':np.squeeze(f[mean_b2_objs[idx][0]][()]),
73 | 'mean_g1':np.squeeze(f[mean_g1_objs[idx][0]][()]),
74 | 'mean_g2':np.squeeze(f[mean_g2_objs[idx][0]][()])
75 | }
76 | # print(sent_obj)
77 | sent_obj['word'] = []
78 |
79 | # get word level data
80 | word_data, word_tokens_all, word_tokens_has_fixation, word_tokens_with_mask = dh.extract_word_level_data(f, f[wordData[idx][0]])
81 |
82 | if word_data == {}:
83 | print(f'missing sent: subj:{subject} content:{sent_string}, append None')
84 | dataset_dict[subject].append(None)
85 | continue
86 | elif len(word_tokens_all) == 0:
87 | print(f'no word level features: subj:{subject} content:{sent_string}, append None')
88 | dataset_dict[subject].append(None)
89 | continue
90 |
91 | else:
92 | for widx in range(len(word_data)):
93 | data_dict = word_data[widx]
94 | word_obj = {'content':data_dict['content'], 'nFixations': data_dict['nFix'], 'rawEEG':data_dict['RAW_EEG']} #@zavidos
95 | """if 'rawEEG' in data_dict:
96 | word_obj['rawEEG'] = data_dict['rawEEG']
97 | else:
98 | word_obj['rawEEG'] = None"""
99 | if 'GD_EEG' in data_dict:
100 | # print('has fixation: ', data_dict['content'])
101 | gd = data_dict["GD_EEG"]
102 | ffd = data_dict["FFD_EEG"]
103 | trt = data_dict["TRT_EEG"]
104 | assert len(gd) == len(trt) == len(ffd) == 8
105 | word_obj['word_level_EEG'] = {
106 | 'GD':{'GD_t1':gd[0], 'GD_t2':gd[1], 'GD_a1':gd[2], 'GD_a2':gd[3], 'GD_b1':gd[4], 'GD_b2':gd[5], 'GD_g1':gd[6], 'GD_g2':gd[7]},
107 | 'FFD':{'FFD_t1':ffd[0], 'FFD_t2':ffd[1], 'FFD_a1':ffd[2], 'FFD_a2':ffd[3], 'FFD_b1':ffd[4], 'FFD_b2':ffd[5], 'FFD_g1':ffd[6], 'FFD_g2':ffd[7]},
108 | 'TRT':{'TRT_t1':trt[0], 'TRT_t2':trt[1], 'TRT_a1':trt[2], 'TRT_a2':trt[3], 'TRT_b1':trt[4], 'TRT_b2':trt[5], 'TRT_g1':trt[6], 'TRT_g2':trt[7]}
109 | }
110 | sent_obj['word'].append(word_obj)
111 |
112 | sent_obj['word_tokens_has_fixation'] = word_tokens_has_fixation
113 | sent_obj['word_tokens_with_mask'] = word_tokens_with_mask
114 | sent_obj['word_tokens_all'] = word_tokens_all
115 |
116 | # print(sent_obj.keys())
117 | # print(len(sent_obj['word']))
118 | # print(sent_obj['word'][0])
119 |
120 | dataset_dict[subject].append(sent_obj)
121 |
122 | """output"""
123 | task_name = 'task2-NR-2.0'
124 |
125 | if dataset_dict == {}:
126 | print(f'No mat file found for {task_name}')
127 | quit()
128 |
129 | output_dir = f'./dataset/ZuCo/{task_name}/pickle'
130 | output_name = f'{task_name}-dataset_wRaw.pickle'
131 | # make sure the output directory exists before writing the pickle
132 | os.makedirs(output_dir, exist_ok=True)
133 |
134 | with open(os.path.join(output_dir,output_name), 'wb') as handle:
135 | pickle.dump(dataset_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
136 | print('write to:', os.path.join(output_dir,output_name))
137 |
138 | """sanity check"""
139 | print('subjects:', dataset_dict.keys())
140 | print('num of sent:', len(dataset_dict['YAC']))
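The script above derives the subject ID with `file_name.split("ts")[1].split("_")[0]`, which only works while the first occurrence of `ts` in the full path is the one inside `results`. A stricter parse could look like the sketch below. It assumes the files are named `results<SUBJECT>_<TASK>.mat`, which is inferred from the `endswith(task+".mat")` check here and the `replace('results','')` call in the v1 script. The helper `subject_from_filename` is hypothetical.

```python
import os


def subject_from_filename(path, task="NR"):
    """Return the subject ID from a path like '.../resultsYAC_NR.mat', or None.

    Assumes the naming scheme 'results<SUBJECT>_<TASK>.mat'.
    """
    name = os.path.basename(path)  # ignore directory names entirely
    prefix, suffix = "results", f"_{task}.mat"
    if name.startswith(prefix) and name.endswith(suffix):
        return name[len(prefix):-len(suffix)]
    return None


if __name__ == "__main__":
    print(subject_from_filename("./dataset/ZuCo/task2-NR-2.0/Matlab_files/resultsYAC_NR.mat"))
    print(subject_from_filename("/data/datasets/ZuCo/resultsYMH_NR.mat"))
    print(subject_from_filename("resultsYAC_TSR.mat"))
```

Because only the base name is inspected, a parent directory such as `datasets/` cannot change the result, and files for other tasks are rejected instead of being parsed incorrectly.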
--------------------------------------------------------------------------------
/util/data_loading_helpers_modified.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 |
4 | eeg_float_resolution=np.float16
5 |
6 | Alpha_ffd_names = ['FFD_a1', 'FFD_a1_diff', 'FFD_a2', 'FFD_a2_diff']
7 | Beta_ffd_names = ['FFD_b1', 'FFD_b1_diff', 'FFD_b2', 'FFD_b2_diff']
8 | Gamma_ffd_names = ['FFD_g1', 'FFD_g1_diff', 'FFD_g2', 'FFD_g2_diff']
9 | Theta_ffd_names = ['FFD_t1', 'FFD_t1_diff', 'FFD_t2', 'FFD_t2_diff']
10 | Alpha_gd_names = ['GD_a1', 'GD_a1_diff', 'GD_a2', 'GD_a2_diff']
11 | Beta_gd_names = ['GD_b1', 'GD_b1_diff', 'GD_b2', 'GD_b2_diff']
12 | Gamma_gd_names = ['GD_g1', 'GD_g1_diff', 'GD_g2', 'GD_g2_diff']
13 | Theta_gd_names = ['GD_t1', 'GD_t1_diff', 'GD_t2', 'GD_t2_diff']
14 | Alpha_gpt_names = ['GPT_a1', 'GPT_a1_diff', 'GPT_a2', 'GPT_a2_diff']
15 | Beta_gpt_names = ['GPT_b1', 'GPT_b1_diff', 'GPT_b2', 'GPT_b2_diff']
16 | Gamma_gpt_names = ['GPT_g1', 'GPT_g1_diff', 'GPT_g2', 'GPT_g2_diff']
17 | Theta_gpt_names = ['GPT_t1', 'GPT_t1_diff', 'GPT_t2', 'GPT_t2_diff']
18 | Alpha_sfd_names = ['SFD_a1', 'SFD_a1_diff', 'SFD_a2', 'SFD_a2_diff']
19 | Beta_sfd_names = ['SFD_b1', 'SFD_b1_diff', 'SFD_b2', 'SFD_b2_diff']
20 | Gamma_sfd_names = ['SFD_g1', 'SFD_g1_diff', 'SFD_g2', 'SFD_g2_diff']
21 | Theta_sfd_names = ['SFD_t1', 'SFD_t1_diff', 'SFD_t2', 'SFD_t2_diff']
22 | Alpha_trt_names = ['TRT_a1', 'TRT_a1_diff', 'TRT_a2', 'TRT_a2_diff']
23 | Beta_trt_names = ['TRT_b1', 'TRT_b1_diff', 'TRT_b2', 'TRT_b2_diff']
24 | Gamma_trt_names = ['TRT_g1', 'TRT_g1_diff', 'TRT_g2', 'TRT_g2_diff']
25 | Theta_trt_names = ['TRT_t1', 'TRT_t1_diff', 'TRT_t2', 'TRT_t2_diff']
26 |
27 | # IF YOU CHANGE THOSE YOU MUST ALSO CHANGE CONSTANTS
28 | Alpha_features = Alpha_ffd_names + Alpha_gd_names + Alpha_gpt_names + Alpha_trt_names# + Alpha_sfd_names
29 | Beta_features = Beta_ffd_names + Beta_gd_names + Beta_gpt_names + Beta_trt_names# + Beta_sfd_names
30 | Gamma_features = Gamma_ffd_names + Gamma_gd_names + Gamma_gpt_names + Gamma_trt_names# + Gamma_sfd_names
31 | Theta_features = Theta_ffd_names + Theta_gd_names + Theta_gpt_names + Theta_trt_names# + Theta_sfd_names
32 | # print(Alpha_features)
33 |
34 | # GD_EEG_features
35 |
36 |
37 | def extract_all_fixations(data_container, word_data_object, float_resolution = np.float16):
38 | """
39 | Extracts all fixations from a word data object
40 | :param data_container: (h5py) Container of the whole data, h5py object
41 | :param word_data_object: (h5py) Container of fixation objects, h5py object
42 | :param float_resolution: (type) Resolution to which data are to be converted, used for data compression
43 | :return:
44 | fixations_data (list) Data arrays representing each fixation
45 | """
46 | word_data = data_container[word_data_object]
47 | fixations_data = []
48 | if len(word_data.shape) > 1:
49 | for fixation_idx in range(word_data.shape[0]):
50 | fixations_data.append(np.array(data_container[word_data[fixation_idx][0]]).astype(float_resolution))
51 | return fixations_data
52 |
53 |
54 | def is_real_word(word):
55 | """
56 | Check if the word is a real word
57 | :param word: (str) word string
58 | :return:
59 | is_word (bool) True if it is a real word
60 | """
61 | is_word = re.search('[a-zA-Z0-9]', word) is not None
62 | return is_word
63 |
64 |
65 | def load_matlab_string(matlab_extracted_object):
66 | """
67 | Converts a string loaded from h5py into a python string
68 | :param matlab_extracted_object: (h5py) matlab string object
69 | :return:
70 | extracted_string (str) translated string
71 | """
72 | extracted_string = u''.join(chr(c[0]) for c in matlab_extracted_object)
73 | return extracted_string
74 |
75 |
76 | def extract_word_level_data(data_container, word_objects, eeg_float_resolution = np.float16):
77 | """
78 | Extracts word level data for a specific sentence
79 | :param data_container: (h5py) Container of the whole data, h5py object
80 | :param word_objects: (h5py) Container of all word data for a specific sentence
81 | :param eeg_float_resolution: (type) Resolution with which to save EEG, used for data compression
82 | :return:
83 | word_level_data (dict) Contains all word level data indexed by their index number in the sentence,
84 | together with the reading order, indexed by "word_reading_order"
85 | """
86 | available_objects = list(word_objects)
87 | #print(available_objects)
88 | #print(len(available_objects))
89 | # print('available_objects:', available_objects)
90 |
91 | if isinstance(available_objects[0], str):
92 |
93 | contentData = word_objects['content']
94 | #fixations_order_per_word = []
95 | if "rawEEG" in available_objects:
96 |
97 | rawData = word_objects['rawEEG']
98 | etData = word_objects['rawET']
99 |
100 | ffdData = word_objects['FFD']
101 | gdData = word_objects['GD']
102 | gptData = word_objects['GPT']
103 | trtData = word_objects['TRT']
104 |
105 | try:
106 | sfdData = word_objects['SFD']
107 | except KeyError:
108 | print("no SFD!")
109 | sfdData = []
110 | nFixData = word_objects['nFixations']
111 | fixPositions = word_objects["fixPositions"]
112 |
113 | Alpha_features_data = [word_objects[feature] for feature in Alpha_features]
114 | Beta_features_data = [word_objects[feature] for feature in Beta_features]
115 | Gamma_features_data = [word_objects[feature] for feature in Gamma_features]
116 | Theta_features_data = [word_objects[feature] for feature in Theta_features]
117 | ####
118 | GD_EEG_features = [word_objects[feature] for feature in ['GD_t1','GD_t2','GD_a1','GD_a2','GD_b1','GD_b2','GD_g1','GD_g2']]
119 | FFD_EEG_features = [word_objects[feature] for feature in ['FFD_t1','FFD_t2','FFD_a1','FFD_a2','FFD_b1','FFD_b2','FFD_g1','FFD_g2']]
120 | TRT_EEG_features = [word_objects[feature] for feature in ['TRT_t1','TRT_t2','TRT_a1','TRT_a2','TRT_b1','TRT_b2','TRT_g1','TRT_g2']]
121 | ####
122 | assert len(contentData) == len(etData) == len(rawData), "different amounts of different data!!"
123 |
124 | zipped_data = zip(rawData, etData, contentData, ffdData, gdData, gptData, trtData, sfdData, nFixData, fixPositions)
125 |
126 | word_level_data = {}
127 | word_idx = 0
128 |
129 | word_tokens_has_fixation = []
130 | word_tokens_with_mask = []
131 | word_tokens_all = []
132 | for raw_eegs_obj, ets_obj, word_obj, ffd, gd, gpt, trt, sfd, nFix, fixPos in zipped_data:
133 | word_string = load_matlab_string(data_container[word_obj[0]])
134 | if is_real_word(word_string):
135 | data_dict = {}
136 | data_dict["RAW_EEG"] = extract_all_fixations(data_container, raw_eegs_obj[0], eeg_float_resolution)
137 | data_dict["RAW_ET"] = extract_all_fixations(data_container, ets_obj[0], np.float32)
138 |
139 | data_dict["FFD"] = data_container[ffd[0]][()][0, 0] if len(data_container[ffd[0]][()].shape) == 2 else None
140 | data_dict["GD"] = data_container[gd[0]][()][0, 0] if len(data_container[gd[0]][()].shape) == 2 else None
141 | data_dict["GPT"] = data_container[gpt[0]][()][0, 0] if len(data_container[gpt[0]][()].shape) == 2 else None
142 | data_dict["TRT"] = data_container[trt[0]][()][0, 0] if len(data_container[trt[0]][()].shape) == 2 else None
143 | data_dict["SFD"] = data_container[sfd[0]][()][0, 0] if len(data_container[sfd[0]][()].shape) == 2 else None
144 | data_dict["nFix"] = data_container[nFix[0]][()][0, 0] if len(data_container[nFix[0]][()].shape) == 2 else None
145 |
146 | #fixations_order_per_word.append(np.array(data_container[fixPos[0]]))
147 |
148 | #print([data_container[obj[word_idx][0]][()] for obj in Alpha_features_data])
149 |
150 |
151 | data_dict["ALPHA_EEG"] = np.concatenate([data_container[obj[word_idx][0]][()]
152 | if len(data_container[obj[word_idx][0]][()].shape) == 2 else []
153 | for obj in Alpha_features_data], 0)
154 |
155 | data_dict["BETA_EEG"] = np.concatenate([data_container[obj[word_idx][0]][()]
156 | if len(data_container[obj[word_idx][0]][()].shape) == 2 else []
157 | for obj in Beta_features_data], 0)
158 |
159 | data_dict["GAMMA_EEG"] = np.concatenate([data_container[obj[word_idx][0]][()]
160 | if len(data_container[obj[word_idx][0]][()].shape) == 2 else []
161 | for obj in Gamma_features_data], 0)
162 |
163 | data_dict["THETA_EEG"] = np.concatenate([data_container[obj[word_idx][0]][()]
164 | if len(data_container[obj[word_idx][0]][()].shape) == 2 else []
165 | for obj in Theta_features_data], 0)
166 |
167 |
168 |
169 |
170 | data_dict["word_idx"] = word_idx
171 | data_dict["content"] = word_string
172 | ####################################
173 | word_tokens_all.append(word_string)
174 | if data_dict["nFix"] is not None:
175 | ####################################
176 | data_dict["GD_EEG"] = [np.squeeze(data_container[obj[word_idx][0]][()]) if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] for obj in GD_EEG_features]
177 | data_dict["FFD_EEG"] = [np.squeeze(data_container[obj[word_idx][0]][()]) if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] for obj in FFD_EEG_features]
178 | data_dict["TRT_EEG"] = [np.squeeze(data_container[obj[word_idx][0]][()]) if len(data_container[obj[word_idx][0]][()].shape) == 2 else [] for obj in TRT_EEG_features]
179 | ####################################
180 | word_tokens_has_fixation.append(word_string)
181 | word_tokens_with_mask.append(word_string)
182 | else:
183 | word_tokens_with_mask.append('[MASK]')
184 |
185 |
186 | word_level_data[word_idx] = data_dict
187 | word_idx += 1
188 | else:
189 | print(word_string + " is not a real word.")
190 | else:
191 | # No word-level data for this sentence: store each word with empty EEG/eye-tracking fields, so only word embeddings are available downstream.
192 | word_level_data = {}
193 | word_idx = 0
194 | word_tokens_has_fixation = []
195 | word_tokens_with_mask = []
196 | word_tokens_all = []
197 |
198 | for word_obj in contentData:
199 | word_string = load_matlab_string(data_container[word_obj[0]])
200 | if is_real_word(word_string):
201 | data_dict = {}
202 | data_dict["RAW_EEG"] = []
203 | data_dict["ICA_EEG"] = []
204 | data_dict["RAW_ET"] = []
205 | data_dict["FFD"] = None
206 | data_dict["GD"] = None
207 | data_dict["GPT"] = None
208 | data_dict["TRT"] = None
209 | data_dict["SFD"] = None
210 | data_dict["nFix"] = None
211 | data_dict["ALPHA_EEG"] = []
212 | data_dict["BETA_EEG"] = []
213 | data_dict["GAMMA_EEG"] = []
214 | data_dict["THETA_EEG"] = []
215 |
216 | data_dict["word_idx"] = word_idx
217 | data_dict["content"] = word_string
218 | word_level_data[word_idx] = data_dict
219 | word_idx += 1
220 | else:
221 | print(word_string + " is not a real word.")
222 |
223 | sentence = " ".join([load_matlab_string(data_container[word_obj[0]]) for word_obj in word_objects['content']])
224 | #print("Only available objects for the sentence '{}' are {}.".format(sentence, available_objects))
225 | #word_level_data["word_reading_order"] = extract_word_order_from_fixations(fixations_order_per_word)
226 | else:
227 | word_tokens_has_fixation = []
228 | word_tokens_with_mask = []
229 | word_tokens_all = []
230 | word_level_data = {}
231 | return word_level_data, word_tokens_all, word_tokens_has_fixation, word_tokens_with_mask
--------------------------------------------------------------------------------