├── README.md
└── decoding
    ├── Decoder.py
    ├── EncodingModel.py
    ├── GPT.py
    ├── LanguageModel.py
    ├── StimulusModel.py
    ├── config.py
    ├── evaluate_predictions.py
    ├── run_decoder.py
    ├── train_EM.py
    ├── train_WR.py
    ├── utils_eval.py
    ├── utils_resp.py
    ├── utils_ridge
    │   ├── DataSequence.py
    │   ├── dsutils.py
    │   ├── interpdata.py
    │   ├── ridge.py
    │   ├── stimulus_utils.py
    │   ├── textgrid.py
    │   ├── util.py
    │   └── utils.py
    └── utils_stim.py

/README.md:
--------------------------------------------------------------------------------
1 | # Semantic Decoding
2 | 
3 | This repository contains code used in the paper "Semantic reconstruction of continuous language from non-invasive brain recordings" by Jerry Tang, Amanda LeBel, Shailee Jain, and Alexander G. Huth.
4 | 
5 | ## Usage
6 | 
7 | 1. Download the [language model data](https://utexas.box.com/shared/static/7ab8qm5e3i0vfsku0ee4dc6hzgeg7nyh.zip) and extract the contents into a new `data_lm/` directory.
8 | 
9 | 2. Download the [training data](https://utexas.box.com/shared/static/3go1g4gcdar2cntjit2knz5jwr3mvxwe.zip) and extract the contents into a new `data_train/` directory. Stimulus data for `train_stimulus/` and response data for `train_response/[SUBJECT_ID]` can be downloaded from [OpenNeuro](https://openneuro.org/datasets/ds003020/).
10 | 
11 | 3. Download the [test data](https://utexas.box.com/shared/static/ae5u0t3sh4f46nvmrd3skniq0kk2t5uh.zip) and extract the contents into a new `data_test/` directory. Stimulus data for `test_stimulus/[EXPERIMENT]` and response data for `test_response/[SUBJECT_ID]` can be downloaded from [OpenNeuro](https://openneuro.org/datasets/ds004510/).
12 | 
13 | 4. Estimate the encoding model. The encoding model predicts brain responses from contextual features of the stimulus extracted using GPT. The `--gpt` parameter determines the GPT checkpoint used. Use `--gpt imagined` when estimating models for imagined speech data, as this will extract features using a GPT checkpoint that was not trained on the imagined speech stories. Use `--gpt perceived` when estimating models for other data. The encoding model will be saved in `MODEL_DIR/[SUBJECT_ID]`. Alternatively, download [pre-fit encoding models](https://utexas.box.com/s/ri13t06iwpkyk17h8tfk0dtyva7qtqlz).
14 | 
15 | ```bash
16 | python3 decoding/train_EM.py --subject [SUBJECT_ID] --gpt perceived
17 | ```
18 | 
19 | 5. Estimate the word rate model. The word rate model predicts word times from brain responses. Two word rate models will be saved in `MODEL_DIR/[SUBJECT_ID]`. The `word_rate_model_speech` model uses brain responses in speech regions, and should be used when decoding imagined speech and perceived movie data. The `word_rate_model_auditory` model uses brain responses in auditory cortex, and should be used when decoding perceived speech data. Alternatively, download [pre-fit word rate models](https://utexas.box.com/s/ri13t06iwpkyk17h8tfk0dtyva7qtqlz).
20 | 
21 | ```bash
22 | python3 decoding/train_WR.py --subject [SUBJECT_ID]
23 | ```
24 | 
25 | 6. Test the decoder on brain responses not used in model estimation. The decoder predictions will be saved in `RESULTS_DIR/[SUBJECT_ID]/[EXPERIMENT_NAME]`.
26 | 
27 | ```bash
28 | python3 decoding/run_decoder.py --subject [SUBJECT_ID] --experiment [EXPERIMENT_NAME] --task [TASK_NAME]
29 | ```
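The prediction file is a NumPy `.npz` archive containing the decoded `words` and their estimated `times` (see `Decoder.save`). A minimal sketch for inspecting a result; the path components are placeholders to replace with your own subject, experiment, and task:

```python
import numpy as np

# placeholder path: substitute your own [SUBJECT_ID], [EXPERIMENT_NAME], and [TASK_NAME]
pred = np.load("results/[SUBJECT_ID]/[EXPERIMENT_NAME]/[TASK_NAME].npz")
words, times = pred["words"], pred["times"]
print(" ".join(words[:20]))  # first 20 decoded words
print(times[:20])            # corresponding word times (in seconds)
```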
30 | 
31 | 7. Evaluate the decoder predictions against reference transcripts. The evaluation results will be saved in `SCORE_DIR/[SUBJECT_ID]/[EXPERIMENT_NAME]`.
32 | 
33 | ```bash
34 | python3 decoding/evaluate_predictions.py --subject [SUBJECT_ID] --experiment [EXPERIMENT_NAME] --task [TASK_NAME]
35 | ```
--------------------------------------------------------------------------------
/decoding/Decoder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.stats as ss
3 | 
4 | class Decoder(object):
5 |     """class for beam search decoding
6 |     """
7 |     def __init__(self, word_times, beam_width, extensions = 5):
8 |         self.word_times = word_times
9 |         self.beam_width, self.extensions = beam_width, extensions
10 |         self.beam = [Hypothesis()] # initialize with empty hypothesis
11 |         self.scored_extensions = [] # global extension pool
12 | 
13 |     def first_difference(self):
14 |         """get first index where hypotheses on the beam differ
15 |         """
16 |         words_arr = np.array([hypothesis.words for hypothesis in self.beam])
17 |         if words_arr.shape[0] == 1: return words_arr.shape[1]
18 |         for index in range(words_arr.shape[1]):
19 |             if len(set(words_arr[:, index])) > 1: return index
20 |         return 0
21 | 
22 |     def time_window(self, sample_index, seconds, floor = 0):
23 |         """number of prior words within [seconds] of the currently sampled time point"""
24 |         window = [time for time in self.word_times if time < self.word_times[sample_index]
25 |             and time > self.word_times[sample_index] - seconds]
26 |         return max(len(window), floor)
27 | 
28 |     def get_hypotheses(self):
29 |         """get the number of permitted extensions for each hypothesis on the beam
30 |         """
31 |         if len(self.beam[0].words) == 0:
32 |             return zip(self.beam, [self.extensions for hypothesis in self.beam])
33 |         logprobs = [sum(hypothesis.logprobs) for hypothesis in self.beam]
34 |         num_extensions = [int(np.ceil(self.extensions * rank / len(logprobs))) for
35 |             rank in ss.rankdata(logprobs)]
36 |         return zip(self.beam, num_extensions)
37 | 
38 |     def add_extensions(self, extensions, likelihoods, num_extensions):
39 |         """add extensions for each hypothesis to global extension pool
40 |         """
41 |         scored_extensions = sorted(zip(extensions, likelihoods), key = lambda x : -x[1])
42 |         self.scored_extensions.extend(scored_extensions[:num_extensions])
43 | 
44 |     def extend(self, verbose = False):
45 |         """update beam based on global extension pool
46 |         """
47 |         self.beam = [x[0] for x in sorted(self.scored_extensions, key = lambda x : -x[1])[:self.beam_width]]
48 |         self.scored_extensions = []
49 |         if verbose: print(self.beam[0].words)
50 | 
51 |     def save(self, path):
52 |         """save decoder results
53 |         """
54 |         np.savez(path, words = np.array(self.beam[0].words), times = np.array(self.word_times))
55 | 
56 | class Hypothesis(object):
57 |     """a class for representing word sequence hypotheses
58 |     """
59 |     def __init__(self, parent = None, extension = None):
60 |         if parent is None:
61 |             self.words, self.logprobs, self.embs = [], [], []
62 |         else:
63 |             word, logprob, emb = extension
64 |             self.words = parent.words + [word]
65 |             self.logprobs = parent.logprobs + [logprob]
66 |             self.embs = parent.embs + [emb]
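The `Decoder`/`Hypothesis` pair above implements a generic beam search: propose extensions for each hypothesis, score them externally, pool the best-scored ones, and keep the top `beam_width`. Below is a minimal, self-contained sketch of that cycle, mirroring the loops in `run_decoder.py` and `utils_eval.generate_null`; the candidate words, log-probabilities, embeddings, and likelihood scores are made up purely for illustration:

```python
import numpy as np
from Decoder import Decoder, Hypothesis

# hypothetical word times (seconds) and a small beam, for illustration only
word_times = np.arange(0.0, 10.0, 0.5)
decoder = Decoder(word_times, beam_width = 3, extensions = 5)

for sample_index in range(len(word_times)):
    for hyp, nextensions in decoder.get_hypotheses():
        # stand-ins for LanguageModel.beam_propose output (words + log-probabilities)
        candidates = ["and", "then", "she", "said", "it"]
        proposals = [Hypothesis(parent = hyp, extension = (w, np.log(0.2), np.zeros(4)))
                     for w in candidates]
        # stand-in for EncodingModel.prs likelihoods; random here, as in generate_null
        likelihoods = np.random.random(len(proposals))
        decoder.add_extensions(proposals, likelihoods, nextensions)
    decoder.extend(verbose = False)

print(decoder.beam[0].words[:10])  # best word sequence found by the (toy) search
```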
"cpu"): 9 | self.device = device 10 | self.weights = torch.from_numpy(weights[:, voxels]).float().to(self.device) 11 | self.resp = torch.from_numpy(resp[:, voxels]).float().to(self.device) 12 | self.sigma = sigma 13 | 14 | def set_shrinkage(self, alpha): 15 | """compute precision from empirical covariance with shrinkage factor alpha 16 | """ 17 | precision = np.linalg.inv(self.sigma * (1 - alpha) + np.eye(len(self.sigma)) * alpha) 18 | self.precision = torch.from_numpy(precision).float().to(self.device) 19 | 20 | def prs(self, stim, trs): 21 | """compute P(R | S) on affected TRs for each hypothesis 22 | """ 23 | with torch.no_grad(): 24 | stim = stim.float().to(self.device) 25 | diff = torch.matmul(stim, self.weights) - self.resp[trs] # encoding model residuals 26 | multi = torch.matmul(torch.matmul(diff, self.precision), diff.permute(0, 2, 1)) 27 | return -0.5 * multi.diagonal(dim1 = -2, dim2 = -1).sum(dim = 1).detach().cpu().numpy() -------------------------------------------------------------------------------- /decoding/GPT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from transformers import AutoModelForCausalLM 4 | from torch.nn.functional import softmax 5 | 6 | class GPT(): 7 | """wrapper for https://huggingface.co/openai-gpt 8 | """ 9 | def __init__(self, path, vocab, device = 'cpu'): 10 | self.device = device 11 | self.model = AutoModelForCausalLM.from_pretrained(path).eval().to(self.device) 12 | self.vocab = vocab 13 | self.word2id = {w : i for i, w in enumerate(self.vocab)} 14 | self.UNK_ID = self.word2id[''] 15 | 16 | def encode(self, words): 17 | """map from words to ids 18 | """ 19 | return [self.word2id[x] if x in self.word2id else self.UNK_ID for x in words] 20 | 21 | def get_story_array(self, words, context_words): 22 | """get word ids for each phrase in a stimulus story 23 | """ 24 | nctx = context_words + 1 25 | story_ids = self.encode(words) 26 | story_array = np.zeros([len(story_ids), nctx]) + self.UNK_ID 27 | for i in range(len(story_array)): 28 | segment = story_ids[i:i+nctx] 29 | story_array[i, :len(segment)] = segment 30 | return torch.tensor(story_array).long() 31 | 32 | def get_context_array(self, contexts): 33 | """get word ids for each context 34 | """ 35 | context_array = np.array([self.encode(words) for words in contexts]) 36 | return torch.tensor(context_array).long() 37 | 38 | def get_hidden(self, ids, layer): 39 | """get hidden layer representations 40 | """ 41 | mask = torch.ones(ids.shape).int() 42 | with torch.no_grad(): 43 | outputs = self.model(input_ids = ids.to(self.device), 44 | attention_mask = mask.to(self.device), output_hidden_states = True) 45 | return outputs.hidden_states[layer].detach().cpu().numpy() 46 | 47 | def get_probs(self, ids): 48 | """get next word probability distributions 49 | """ 50 | mask = torch.ones(ids.shape).int() 51 | with torch.no_grad(): 52 | outputs = self.model(input_ids = ids.to(self.device), attention_mask = mask.to(self.device)) 53 | probs = softmax(outputs.logits, dim = 2).detach().cpu().numpy() 54 | return probs -------------------------------------------------------------------------------- /decoding/LanguageModel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from nltk.stem.snowball import SnowballStemmer 3 | stemmer = SnowballStemmer("english") 4 | 5 | INIT = ['i', 'we', 'she', 'he', 'they', 'it'] 6 | STOPWORDS = {'is', 'does', 's', 'having', 'doing', 'these', 
'shan', 'yourself', 'other', 'are', 'hasn', 'at', 'for', 'while', 'down', "hadn't", 'until', 'above', 'during', 'each', 'now', 'have', "won't", 'once', 'why', 'here', 'ourselves', 'to', 'over', 'into', 'who', 'that', 'myself', 'he', 'themselves', 'were', 'against', 'about', 'some', 'has', 'but', 'ma', 'their', 'this', 'there', 'with', "that'll", "shan't", "wouldn't", 'a', 'those', "you'll", 'll', 'few', 'couldn', 'an', 'd', "weren't", 'doesn', 'own', 'won', 'didn', 'what', 'when', 'in', 'below', 'where', "it's", 'most', 'just', "you're", 'yourselves', 'too', "don't", "she's", "didn't", "hasn't", 'isn', "mustn't", 'of', 'did', 'how', 'himself', 'aren', 'if', 'very', 'or', 'weren', 'it', 'be', 'itself', "doesn't", 'my', 'o', 'no', "isn't", 'before', 'after', 'off', 'was', 'can', 'the', 'been', 'her', 'him', "wasn't", 've', 'through', "needn't", 'because', 'nor', 'will', 'm', 't', 'out', 'on', 'she', 'all', 'then', 'than', "mightn't", 'hers', 'herself', 'only', 'should', 're', 'ain', 'wasn', "aren't", "couldn't", 'they', 'hadn', 'had', 'more', 'and', 'under', "shouldn't", 'any', 'y', 'don', 'from', 'so', 'whom', 'as', 'mustn', 'between', 'up', 'do', 'both', 'such', 'our', 'its', 'which', 'not', "haven't", 'needn', 'by', "should've", 'again', 'shouldn', 'his', 'me', 'further', 'yours', 'am', 'your', 'haven', 'wouldn', 'being', 'ours', 'you', 'i', 'theirs', 'mightn', 'same', 'we', "you've", 'them', "you'd"} 7 | 8 | def get_nucleus(probs, nuc_mass, nuc_ratio): 9 | """identify words that constitute a given fraction of the probability mass 10 | """ 11 | nuc_ids = np.where(probs >= np.max(probs) * nuc_ratio)[0] 12 | nuc_pairs = sorted(zip(nuc_ids, probs[nuc_ids]), key = lambda x : -x[1]) 13 | sum_mass = np.cumsum([x[1] for x in nuc_pairs]) 14 | cutoffs = np.where(sum_mass >= nuc_mass)[0] 15 | if len(cutoffs) > 0: nuc_pairs = nuc_pairs[:cutoffs[0]+1] 16 | nuc_ids = [x[0] for x in nuc_pairs] 17 | return nuc_ids 18 | 19 | def in_context(word, context): 20 | """test whether [word] or a stem of [word] is in [context] 21 | """ 22 | stem_context = [stemmer.stem(x) for x in context] 23 | stem_word = stemmer.stem(word) 24 | return (stem_word in stem_context or stem_word in context) 25 | 26 | def context_filter(proposals, context): 27 | """filter out words that occur in a context to prevent repetitions 28 | """ 29 | cut_words = [] 30 | cut_words.extend([context[i+1] for i, word in enumerate(context[:-1]) if word == context[-1]]) # bigrams 31 | cut_words.extend([x for x in proposals if x not in STOPWORDS and in_context(x, context)]) # unigrams 32 | return [x for x in proposals if x not in cut_words] 33 | 34 | class LanguageModel(): 35 | """class for generating word sequences using a language model 36 | """ 37 | def __init__(self, model, vocab, nuc_mass = 1.0, nuc_ratio = 0.0): 38 | self.model = model 39 | self.ids = {i for word, i in self.model.word2id.items() if word in set(vocab)} 40 | self.nuc_mass, self.nuc_ratio = nuc_mass, nuc_ratio 41 | 42 | def ps(self, contexts): 43 | """get probability distributions over the next words for each context 44 | """ 45 | context_arr = self.model.get_context_array(contexts) 46 | probs = self.model.get_probs(context_arr) 47 | return probs[:, len(contexts[0]) - 1] 48 | 49 | def beam_propose(self, beam, context_words): 50 | """get possible extension words for each hypothesis in the decoder beam 51 | """ 52 | if len(beam) == 1: 53 | nuc_words = [w for w in INIT if self.model.word2id[w] in self.ids] 54 | nuc_logprobs = np.log(np.ones(len(nuc_words)) / len(nuc_words)) 55 | 
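# when the beam still holds only the single initial (empty) hypothesis, the proposals
# are just the INIT pronouns, each assigned equal probability (uniform log-probabilities)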
return [(nuc_words, nuc_logprobs)] 56 | else: 57 | contexts = [hyp.words[-context_words:] for hyp in beam] 58 | beam_probs = self.ps(contexts) 59 | beam_nucs = [] 60 | for context, probs in zip(contexts, beam_probs): 61 | nuc_ids = get_nucleus(probs, nuc_mass = self.nuc_mass, nuc_ratio = self.nuc_ratio) 62 | nuc_words = [self.model.vocab[i] for i in nuc_ids if i in self.ids] 63 | nuc_words = context_filter(nuc_words, context) 64 | nuc_logprobs = np.log([probs[self.model.word2id[w]] for w in nuc_words]) 65 | beam_nucs.append((nuc_words, nuc_logprobs)) 66 | return beam_nucs -------------------------------------------------------------------------------- /decoding/StimulusModel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | torch.set_default_tensor_type(torch.FloatTensor) 4 | 5 | import config 6 | from utils_ridge.interpdata import lanczosfun 7 | 8 | def get_lanczos_mat(oldtime, newtime, window = 3, cutoff_mult = 1.0, rectify = False): 9 | """get matrix for downsampling from TR times to word times 10 | """ 11 | cutoff = 1 / np.mean(np.diff(newtime)) * cutoff_mult 12 | sincmat = np.zeros((len(newtime), len(oldtime))) 13 | for ndi in range(len(newtime)): 14 | sincmat[ndi,:] = lanczosfun(cutoff, newtime[ndi] - oldtime, window) 15 | return sincmat 16 | 17 | def affected_trs(start_index, end_index, lanczos_mat, delay = True): 18 | """identify TRs influenced by words in the range [start_index, end_index] 19 | """ 20 | start_tr, end_tr = np.where(lanczos_mat[:, start_index])[0][0], np.where(lanczos_mat[:, end_index])[0][-1] 21 | start_tr, end_tr = start_tr + min(config.STIM_DELAYS), end_tr + max(config.STIM_DELAYS) 22 | start_tr, end_tr = max(start_tr, 0), min(end_tr, lanczos_mat.shape[0] - 1) 23 | return np.arange(start_tr, end_tr + 1) 24 | 25 | class StimulusModel(): 26 | """class for constructing stimulus features 27 | """ 28 | def __init__(self, lanczos_mat, tr_stats, word_mean, device = 'cpu'): 29 | self.device = device 30 | self.lanczos_mat = torch.from_numpy(lanczos_mat).float().to(self.device) 31 | self.tr_mean = torch.from_numpy(tr_stats[0]).float().to(device) 32 | self.tr_std_inv = torch.from_numpy(np.diag(1 / tr_stats[1])).float().to(device) 33 | self.blank = torch.from_numpy(word_mean).float().to(self.device) 34 | 35 | def _downsample(self, variants): 36 | """downsamples word embeddings to TR embeddings for each hypothesis 37 | """ 38 | return torch.matmul(self.lanczos_mat.unsqueeze(0), variants) 39 | 40 | def _normalize(self, tr_variants): 41 | """normalize TR embeddings for each hypothesis 42 | """ 43 | centered = tr_variants - self.tr_mean 44 | return torch.matmul(centered, self.tr_std_inv) 45 | 46 | def _delay(self, tr_variants, n_vars, n_feats): 47 | """apply finite impulse response delays to TR embeddings 48 | """ 49 | delays = config.STIM_DELAYS 50 | n_trs = tr_variants.shape[1] 51 | del_tr_variants = torch.zeros(n_vars, n_trs, len(delays)*n_feats).to(self.device) 52 | for c, d in enumerate(delays): 53 | feat_ind_start = c * n_feats 54 | feat_ind_end = (c + 1) * n_feats 55 | del_tr_variants[:, d:, feat_ind_start:feat_ind_end] = tr_variants[:, :n_trs - d, :] 56 | return del_tr_variants 57 | 58 | def make_variants(self, sample_index, hypothesis_embs, var_embs, affected_trs): 59 | """create stimulus features for each hypothesis 60 | """ 61 | n_variants, n_feats = len(var_embs), self.blank.shape[0] 62 | with torch.no_grad(): 63 | full = self.blank.repeat(self.lanczos_mat.shape[1], 1) # word times x 
features 64 | full[:sample_index] = torch.tensor(np.array(hypothesis_embs)).float().reshape(-1, n_feats).to(self.device) 65 | variants = full.repeat(n_variants, 1, 1) # variants x word times x features 66 | variants[:, sample_index, :] = torch.tensor(np.array(var_embs)).float().to(self.device) 67 | tr_variants = self._normalize(self._downsample(variants)) 68 | del_tr_variants = self._delay(tr_variants, n_variants, n_feats) 69 | return del_tr_variants[:, affected_trs, :].to('cpu') 70 | 71 | class LMFeatures(): 72 | """class for extracting contextualized features of stimulus words 73 | """ 74 | def __init__(self, model, layer, context_words): 75 | self.model, self.layer, self.context_words = model, layer, context_words 76 | 77 | def extend(self, extensions, verbose = False): 78 | """outputs array of vectors corresponding to the last words of each extension 79 | """ 80 | contexts = [extension[-(self.context_words+1):] for extension in extensions] 81 | if verbose: print(contexts) 82 | context_array = self.model.get_context_array(contexts) 83 | embs = self.model.get_hidden(context_array, layer = self.layer) 84 | return embs[:, len(contexts[0]) - 1] 85 | 86 | def make_stim(self, words): 87 | """outputs matrix of features corresponding to the stimulus words 88 | """ 89 | context_array = self.model.get_story_array(words, self.context_words) 90 | embs = self.model.get_hidden(context_array, layer = self.layer) 91 | return np.vstack([embs[0, :self.context_words], 92 | embs[:context_array.shape[0] - self.context_words, self.context_words]]) -------------------------------------------------------------------------------- /decoding/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | # paths 5 | 6 | REPO_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | DATA_LM_DIR = os.path.join(REPO_DIR, "data_lm") 8 | DATA_TRAIN_DIR = os.path.join(REPO_DIR, "data_train") 9 | DATA_TEST_DIR = os.path.join(REPO_DIR, "data_test") 10 | MODEL_DIR = os.path.join(REPO_DIR, "models") 11 | RESULT_DIR = os.path.join(REPO_DIR, "results") 12 | SCORE_DIR = os.path.join(REPO_DIR, "scores") 13 | 14 | # GPT encoding model parameters 15 | 16 | TRIM = 5 17 | STIM_DELAYS = [1, 2, 3, 4] 18 | RESP_DELAYS = [-4, -3, -2, -1] 19 | ALPHAS = np.logspace(1, 3, 10) 20 | NBOOTS = 50 21 | VOXELS = 10000 22 | CHUNKLEN = 40 23 | GPT_LAYER = 9 24 | GPT_WORDS = 5 25 | 26 | # decoder parameters 27 | 28 | RANKED = True 29 | WIDTH = 200 30 | NM_ALPHA = 2/3 31 | LM_TIME = 8 32 | LM_MASS = 0.9 33 | LM_RATIO = 0.1 34 | EXTENSIONS = 5 35 | 36 | # evaluation parameters 37 | 38 | WINDOW = 20 39 | 40 | # devices 41 | 42 | GPT_DEVICE = "cuda" 43 | EM_DEVICE = "cuda" 44 | SM_DEVICE = "cuda" -------------------------------------------------------------------------------- /decoding/evaluate_predictions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import argparse 5 | 6 | import config 7 | from utils_eval import generate_null, load_transcript, windows, segment_data, WER, BLEU, METEOR, BERTSCORE 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--subject", type = str, required = True) 12 | parser.add_argument("--experiment", type = str, required = True) 13 | parser.add_argument("--task", type = str, required = True) 14 | parser.add_argument("--metrics", nargs = "+", type = str, default = ["WER", "BLEU", 
"METEOR", "BERT"]) 15 | parser.add_argument("--references", nargs = "+", type = str, default = []) 16 | parser.add_argument("--null", type = int, default = 10) 17 | args = parser.parse_args() 18 | 19 | if len(args.references) == 0: 20 | args.references.append(args.task) 21 | 22 | with open(os.path.join(config.DATA_TEST_DIR, "eval_segments.json"), "r") as f: 23 | eval_segments = json.load(f) 24 | 25 | # load language similarity metrics 26 | metrics = {} 27 | if "WER" in args.metrics: metrics["WER"] = WER(use_score = True) 28 | if "BLEU" in args.metrics: metrics["BLEU"] = BLEU(n = 1) 29 | if "METEOR" in args.metrics: metrics["METEOR"] = METEOR() 30 | if "BERT" in args.metrics: metrics["BERT"] = BERTSCORE( 31 | idf_sents = np.load(os.path.join(config.DATA_TEST_DIR, "idf_segments.npy")), 32 | rescale = False, 33 | score = "recall") 34 | 35 | # load prediction transcript 36 | pred_path = os.path.join(config.RESULT_DIR, args.subject, args.experiment, args.task + ".npz") 37 | pred_data = np.load(pred_path) 38 | pred_words, pred_times = pred_data["words"], pred_data["times"] 39 | 40 | # generate null sequences 41 | if args.experiment in ["imagined_speech"]: gpt_checkpoint = "imagined" 42 | else: gpt_checkpoint = "perceived" 43 | null_word_list = generate_null(pred_times, gpt_checkpoint, args.null) 44 | 45 | window_scores, window_zscores = {}, {} 46 | story_scores, story_zscores = {}, {} 47 | for reference in args.references: 48 | 49 | # load reference transcript 50 | ref_data = load_transcript(args.experiment, reference) 51 | ref_words, ref_times = ref_data["words"], ref_data["times"] 52 | 53 | # segment prediction and reference words into windows 54 | window_cutoffs = windows(*eval_segments[args.task], config.WINDOW) 55 | ref_windows = segment_data(ref_words, ref_times, window_cutoffs) 56 | pred_windows = segment_data(pred_words, pred_times, window_cutoffs) 57 | null_window_list = [segment_data(null_words, pred_times, window_cutoffs) for null_words in null_word_list] 58 | 59 | for mname, metric in metrics.items(): 60 | 61 | # get null score for each window and the entire story 62 | window_null_scores = np.array([metric.score(ref = ref_windows, pred = null_windows) 63 | for null_windows in null_window_list]) 64 | story_null_scores = window_null_scores.mean(1) 65 | 66 | # get raw score and normalized score for each window 67 | window_scores[(reference, mname)] = metric.score(ref = ref_windows, pred = pred_windows) 68 | window_zscores[(reference, mname)] = (window_scores[(reference, mname)] 69 | - window_null_scores.mean(0)) / window_null_scores.std(0) 70 | 71 | # get raw score and normalized score for the entire story 72 | story_scores[(reference, mname)] = metric.score(ref = ref_windows, pred = pred_windows) 73 | story_zscores[(reference, mname)] = (story_scores[(reference, mname)].mean() 74 | - story_null_scores.mean()) / story_null_scores.std() 75 | 76 | save_location = os.path.join(config.REPO_DIR, "scores", args.subject, args.experiment) 77 | os.makedirs(save_location, exist_ok = True) 78 | np.savez(os.path.join(save_location, args.task), 79 | window_scores = window_scores, window_zscores = window_zscores, 80 | story_scores = story_scores, story_zscores = story_zscores) -------------------------------------------------------------------------------- /decoding/run_decoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import argparse 5 | import h5py 6 | from pathlib import Path 7 | 8 | import config 9 
| from GPT import GPT 10 | from Decoder import Decoder, Hypothesis 11 | from LanguageModel import LanguageModel 12 | from EncodingModel import EncodingModel 13 | from StimulusModel import StimulusModel, get_lanczos_mat, affected_trs, LMFeatures 14 | from utils_stim import predict_word_rate, predict_word_times 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--subject", type = str, required = True) 19 | parser.add_argument("--experiment", type = str, required = True) 20 | parser.add_argument("--task", type = str, required = True) 21 | args = parser.parse_args() 22 | 23 | # determine GPT checkpoint based on experiment 24 | if args.experiment in ["imagined_speech"]: gpt_checkpoint = "imagined" 25 | else: gpt_checkpoint = "perceived" 26 | 27 | # determine word rate model voxels based on experiment 28 | if args.experiment in ["imagined_speech", "perceived_movies"]: word_rate_voxels = "speech" 29 | else: word_rate_voxels = "auditory" 30 | 31 | # load responses 32 | hf = h5py.File(os.path.join(config.DATA_TEST_DIR, "test_response", args.subject, args.experiment, args.task + ".hf5"), "r") 33 | resp = np.nan_to_num(hf["data"][:]) 34 | hf.close() 35 | 36 | # load gpt 37 | with open(os.path.join(config.DATA_LM_DIR, gpt_checkpoint, "vocab.json"), "r") as f: 38 | gpt_vocab = json.load(f) 39 | with open(os.path.join(config.DATA_LM_DIR, "decoder_vocab.json"), "r") as f: 40 | decoder_vocab = json.load(f) 41 | gpt = GPT(path = os.path.join(config.DATA_LM_DIR, gpt_checkpoint, "model"), vocab = gpt_vocab, device = config.GPT_DEVICE) 42 | features = LMFeatures(model = gpt, layer = config.GPT_LAYER, context_words = config.GPT_WORDS) 43 | lm = LanguageModel(gpt, decoder_vocab, nuc_mass = config.LM_MASS, nuc_ratio = config.LM_RATIO) 44 | 45 | # load models 46 | load_location = os.path.join(config.MODEL_DIR, args.subject) 47 | word_rate_model = np.load(os.path.join(load_location, "word_rate_model_%s.npz" % word_rate_voxels), allow_pickle = True) 48 | encoding_model = np.load(os.path.join(load_location, "encoding_model_%s.npz" % gpt_checkpoint)) 49 | weights = encoding_model["weights"] 50 | noise_model = encoding_model["noise_model"] 51 | tr_stats = encoding_model["tr_stats"] 52 | word_stats = encoding_model["word_stats"] 53 | em = EncodingModel(resp, weights, encoding_model["voxels"], noise_model, device = config.EM_DEVICE) 54 | em.set_shrinkage(config.NM_ALPHA) 55 | assert args.task not in encoding_model["stories"] 56 | 57 | # predict word times 58 | word_rate = predict_word_rate(resp, word_rate_model["weights"], word_rate_model["voxels"], word_rate_model["mean_rate"]) 59 | if args.experiment == "perceived_speech": word_times, tr_times = predict_word_times(word_rate, resp, starttime = -10) 60 | else: word_times, tr_times = predict_word_times(word_rate, resp, starttime = 0) 61 | lanczos_mat = get_lanczos_mat(word_times, tr_times) 62 | 63 | # decode responses 64 | decoder = Decoder(word_times, config.WIDTH) 65 | sm = StimulusModel(lanczos_mat, tr_stats, word_stats[0], device = config.SM_DEVICE) 66 | for sample_index in range(len(word_times)): 67 | trs = affected_trs(decoder.first_difference(), sample_index, lanczos_mat) 68 | ncontext = decoder.time_window(sample_index, config.LM_TIME, floor = 5) 69 | beam_nucs = lm.beam_propose(decoder.beam, ncontext) 70 | for c, (hyp, nextensions) in enumerate(decoder.get_hypotheses()): 71 | nuc, logprobs = beam_nucs[c] 72 | if len(nuc) < 1: continue 73 | extend_words = [hyp.words + [x] for x in nuc] 74 | extend_embs = 
list(features.extend(extend_words)) 75 | stim = sm.make_variants(sample_index, hyp.embs, extend_embs, trs) 76 | likelihoods = em.prs(stim, trs) 77 | local_extensions = [Hypothesis(parent = hyp, extension = x) for x in zip(nuc, logprobs, extend_embs)] 78 | decoder.add_extensions(local_extensions, likelihoods, nextensions) 79 | decoder.extend(verbose = False) 80 | 81 | if args.experiment in ["perceived_movie", "perceived_multispeaker"]: decoder.word_times += 10 82 | save_location = os.path.join(config.RESULT_DIR, args.subject, args.experiment) 83 | os.makedirs(save_location, exist_ok = True) 84 | decoder.save(os.path.join(save_location, args.task)) -------------------------------------------------------------------------------- /decoding/train_EM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import argparse 5 | 6 | import config 7 | from GPT import GPT 8 | from StimulusModel import LMFeatures 9 | from utils_stim import get_stim 10 | from utils_resp import get_resp 11 | from utils_ridge.ridge import ridge, bootstrap_ridge 12 | np.random.seed(42) 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--subject", type = str, required = True) 17 | parser.add_argument("--gpt", type = str, default = "perceived") 18 | parser.add_argument("--sessions", nargs = "+", type = int, 19 | default = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 20]) 20 | args = parser.parse_args() 21 | 22 | # training stories 23 | stories = [] 24 | with open(os.path.join(config.DATA_TRAIN_DIR, "sess_to_story.json"), "r") as f: 25 | sess_to_story = json.load(f) 26 | for sess in args.sessions: 27 | stories.extend(sess_to_story[str(sess)]) 28 | 29 | # load gpt 30 | with open(os.path.join(config.DATA_LM_DIR, args.gpt, "vocab.json"), "r") as f: 31 | gpt_vocab = json.load(f) 32 | gpt = GPT(path = os.path.join(config.DATA_LM_DIR, args.gpt, "model"), vocab = gpt_vocab, device = config.GPT_DEVICE) 33 | features = LMFeatures(model = gpt, layer = config.GPT_LAYER, context_words = config.GPT_WORDS) 34 | 35 | # estimate encoding model 36 | rstim, tr_stats, word_stats = get_stim(stories, features) 37 | rresp = get_resp(args.subject, stories, stack = True) 38 | nchunks = int(np.ceil(rresp.shape[0] / 5 / config.CHUNKLEN)) 39 | weights, alphas, bscorrs = bootstrap_ridge(rstim, rresp, use_corr = False, alphas = config.ALPHAS, 40 | nboots = config.NBOOTS, chunklen = config.CHUNKLEN, nchunks = nchunks) 41 | bscorrs = bscorrs.mean(2).max(0) 42 | vox = np.sort(np.argsort(bscorrs)[-config.VOXELS:]) 43 | del rstim, rresp 44 | 45 | # estimate noise model 46 | stim_dict = {story : get_stim([story], features, tr_stats = tr_stats) for story in stories} 47 | resp_dict = get_resp(args.subject, stories, stack = False, vox = vox) 48 | noise_model = np.zeros([len(vox), len(vox)]) 49 | for hstory in stories: 50 | tstim, hstim = np.vstack([stim_dict[tstory] for tstory in stories if tstory != hstory]), stim_dict[hstory] 51 | tresp, hresp = np.vstack([resp_dict[tstory] for tstory in stories if tstory != hstory]), resp_dict[hstory] 52 | bs_weights = ridge(tstim, tresp, alphas[vox]) 53 | resids = hresp - hstim.dot(bs_weights) 54 | bs_noise_model = resids.T.dot(resids) 55 | noise_model += bs_noise_model / np.diag(bs_noise_model).mean() / len(stories) 56 | del stim_dict, resp_dict 57 | 58 | # save 59 | save_location = os.path.join(config.MODEL_DIR, args.subject) 60 | os.makedirs(save_location, exist_ok = True) 61 | 
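# the bundle saved below is what run_decoder.py loads at decoding time: ridge weights,
# the voxel noise covariance, the selected voxels, the training stories, and the
# TR / word normalization statistics (the per-voxel alphas are kept as well)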
np.savez(os.path.join(save_location, "encoding_model_%s" % args.gpt), 62 | weights = weights, noise_model = noise_model, alphas = alphas, voxels = vox, stories = stories, 63 | tr_stats = np.array(tr_stats), word_stats = np.array(word_stats)) -------------------------------------------------------------------------------- /decoding/train_WR.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import argparse 5 | 6 | import config 7 | from utils_stim import get_story_wordseqs 8 | from utils_resp import get_resp 9 | from utils_ridge.DataSequence import DataSequence 10 | from utils_ridge.util import make_delayed 11 | from utils_ridge.ridge import bootstrap_ridge 12 | np.random.seed(42) 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--subject", type = str, required = True) 17 | parser.add_argument("--sessions", nargs = "+", type = int, 18 | default = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 20]) 19 | args = parser.parse_args() 20 | 21 | # training stories 22 | stories = [] 23 | with open(os.path.join(config.DATA_TRAIN_DIR, "sess_to_story.json"), "r") as f: 24 | sess_to_story = json.load(f) 25 | for sess in args.sessions: 26 | stories.extend(sess_to_story[str(sess)]) 27 | 28 | # ROI voxels 29 | with open(os.path.join(config.DATA_TRAIN_DIR, "ROIs", "%s.json" % args.subject), "r") as f: 30 | vox = json.load(f) 31 | 32 | # estimate word rate model 33 | save_location = os.path.join(config.MODEL_DIR, args.subject) 34 | os.makedirs(save_location, exist_ok = True) 35 | 36 | wordseqs = get_story_wordseqs(stories) 37 | rates = {} 38 | for story in stories: 39 | ds = wordseqs[story] 40 | words = DataSequence(np.ones(len(ds.data_times)), ds.split_inds, ds.data_times, ds.tr_times) 41 | rates[story] = words.chunksums("lanczos", window = 3) 42 | nz_rate = np.concatenate([rates[story][5+config.TRIM:-config.TRIM] for story in stories], axis = 0) 43 | nz_rate = np.nan_to_num(nz_rate).reshape([-1, 1]) 44 | mean_rate = np.mean(nz_rate) 45 | rate = nz_rate - mean_rate 46 | 47 | for roi in ["speech", "auditory"]: 48 | resp = get_resp(args.subject, stories, stack = True, vox = vox[roi]) 49 | delresp = make_delayed(resp, config.RESP_DELAYS) 50 | nchunks = int(np.ceil(delresp.shape[0] / 5 / config.CHUNKLEN)) 51 | weights, _, _ = bootstrap_ridge(delresp, rate, use_corr = False, 52 | alphas = config.ALPHAS, nboots = config.NBOOTS, chunklen = config.CHUNKLEN, nchunks = nchunks) 53 | np.savez(os.path.join(save_location, "word_rate_model_%s" % roi), 54 | weights = weights, mean_rate = mean_rate, voxels = vox[roi]) -------------------------------------------------------------------------------- /decoding/utils_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | 5 | import config 6 | from GPT import GPT 7 | from Decoder import Decoder, Hypothesis 8 | from LanguageModel import LanguageModel 9 | 10 | from jiwer import wer 11 | from datasets import load_metric 12 | from bert_score import BERTScorer 13 | 14 | BAD_WORDS_PERCEIVED_SPEECH = frozenset(["sentence_start", "sentence_end", "br", "lg", "ls", "ns", "sp"]) 15 | BAD_WORDS_OTHER_TASKS = frozenset(["", "sp", "uh"]) 16 | 17 | from utils_ridge.textgrid import TextGrid 18 | def load_transcript(experiment, task): 19 | if experiment in ["perceived_speech", "perceived_multispeaker"]: skip_words = BAD_WORDS_PERCEIVED_SPEECH 20 | else: 
skip_words = BAD_WORDS_OTHER_TASKS 21 | grid_path = os.path.join(config.DATA_TEST_DIR, "test_stimulus", experiment, task.split("_")[0] + ".TextGrid") 22 | transcript_data = {} 23 | with open(grid_path) as f: 24 | grid = TextGrid(f.read()) 25 | if experiment == "perceived_speech": transcript = grid.tiers[1].make_simple_transcript() 26 | else: transcript = grid.tiers[0].make_simple_transcript() 27 | transcript = [(float(s), float(e), w.lower()) for s, e, w in transcript if w.lower().strip("{}").strip() not in skip_words] 28 | transcript_data["words"] = np.array([x[2] for x in transcript]) 29 | transcript_data["times"] = np.array([(x[0] + x[1]) / 2 for x in transcript]) 30 | return transcript_data 31 | 32 | """windows of [duration] seconds at each time point""" 33 | def windows(start_time, end_time, duration, step = 1): 34 | start_time, end_time = int(start_time), int(end_time) 35 | half = int(duration / 2) 36 | return [(center - half, center + half) for center in range(start_time + half, end_time - half + 1) if center % step == 0] 37 | 38 | """divide [data] into list of segments defined by [cutoffs]""" 39 | def segment_data(data, times, cutoffs): 40 | return [[x for c, x in zip(times, data) if c >= start and c < end] for start, end in cutoffs] 41 | 42 | """generate null sequences with same times as predicted sequence""" 43 | def generate_null(pred_times, gpt_checkpoint, n): 44 | 45 | # load language model 46 | with open(os.path.join(config.DATA_LM_DIR, gpt_checkpoint, "vocab.json"), "r") as f: 47 | gpt_vocab = json.load(f) 48 | with open(os.path.join(config.DATA_LM_DIR, "decoder_vocab.json"), "r") as f: 49 | decoder_vocab = json.load(f) 50 | gpt = GPT(path = os.path.join(config.DATA_LM_DIR, gpt_checkpoint, "model"), vocab = gpt_vocab, device = config.GPT_DEVICE) 51 | lm = LanguageModel(gpt, decoder_vocab, nuc_mass = config.LM_MASS, nuc_ratio = config.LM_RATIO) 52 | 53 | # generate null sequences 54 | null_words = [] 55 | for _count in range(n): 56 | decoder = Decoder(pred_times, 2 * config.EXTENSIONS) 57 | for sample_index in range(len(pred_times)): 58 | ncontext = decoder.time_window(sample_index, config.LM_TIME, floor = 5) 59 | beam_nucs = lm.beam_propose(decoder.beam, ncontext) 60 | for c, (hyp, nextensions) in enumerate(decoder.get_hypotheses()): 61 | nuc, logprobs = beam_nucs[c] 62 | if len(nuc) < 1: continue 63 | extend_words = [hyp.words + [x] for x in nuc] 64 | likelihoods = np.random.random(len(nuc)) 65 | local_extensions = [Hypothesis(parent = hyp, extension = x) 66 | for x in zip(nuc, logprobs, [np.zeros(1) for _ in nuc])] 67 | decoder.add_extensions(local_extensions, likelihoods, nextensions) 68 | decoder.extend(verbose = False) 69 | null_words.append(decoder.beam[0].words) 70 | return null_words 71 | 72 | """ 73 | WER 74 | """ 75 | class WER(object): 76 | def __init__(self, use_score = True): 77 | self.use_score = use_score 78 | 79 | def score(self, ref, pred): 80 | scores = [] 81 | for ref_seg, pred_seg in zip(ref, pred): 82 | if len(ref_seg) == 0 : error = 1.0 83 | else: error = wer(ref_seg, pred_seg) 84 | if self.use_score: scores.append(1 - error) 85 | else: use_score.append(error) 86 | return np.array(scores) 87 | 88 | """ 89 | BLEU (https://aclanthology.org/P02-1040.pdf) 90 | """ 91 | class BLEU(object): 92 | def __init__(self, n = 4): 93 | self.metric = load_metric("bleu", keep_in_memory=True) 94 | self.n = n 95 | 96 | def score(self, ref, pred): 97 | results = [] 98 | for r, p in zip(ref, pred): 99 | self.metric.add_batch(predictions=[p], references=[[r]]) 100 | 
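# one BLEU value is computed per window; evaluate_predictions.py constructs BLEU(n = 1),
# so the reported score is unigram BLEU by default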
results.append(self.metric.compute(max_order = self.n)["bleu"]) 101 | return np.array(results) 102 | 103 | """ 104 | METEOR (https://aclanthology.org/W05-0909.pdf) 105 | """ 106 | class METEOR(object): 107 | def __init__(self): 108 | self.metric = load_metric("meteor", keep_in_memory=True) 109 | 110 | def score(self, ref, pred): 111 | results = [] 112 | ref_strings = [" ".join(x) for x in ref] 113 | pred_strings = [" ".join(x) for x in pred] 114 | for r, p in zip(ref_strings, pred_strings): 115 | self.metric.add_batch(predictions=[p], references=[r]) 116 | results.append(self.metric.compute()["meteor"]) 117 | return np.array(results) 118 | 119 | """ 120 | BERTScore (https://arxiv.org/abs/1904.09675) 121 | """ 122 | class BERTSCORE(object): 123 | def __init__(self, idf_sents=None, rescale = True, score = "f"): 124 | self.metric = BERTScorer(lang = "en", rescale_with_baseline = rescale, idf = (idf_sents is not None), idf_sents = idf_sents) 125 | if score == "precision": self.score_id = 0 126 | elif score == "recall": self.score_id = 1 127 | else: self.score_id = 2 128 | 129 | def score(self, ref, pred): 130 | ref_strings = [" ".join(x) for x in ref] 131 | pred_strings = [" ".join(x) for x in pred] 132 | return self.metric.score(cands = pred_strings, refs = ref_strings)[self.score_id].numpy() -------------------------------------------------------------------------------- /decoding/utils_resp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import h5py 4 | 5 | import config 6 | 7 | def get_resp(subject, stories, stack = True, vox = None): 8 | """loads response data 9 | """ 10 | subject_dir = os.path.join(config.DATA_TRAIN_DIR, "train_response", subject) 11 | resp = {} 12 | for story in stories: 13 | resp_path = os.path.join(subject_dir, "%s.hf5" % story) 14 | hf = h5py.File(resp_path, "r") 15 | resp[story] = np.nan_to_num(hf["data"][:]) 16 | if vox is not None: 17 | resp[story] = resp[story][:, vox] 18 | hf.close() 19 | if stack: return np.vstack([resp[story] for story in stories]) 20 | else: return resp -------------------------------------------------------------------------------- /decoding/utils_ridge/DataSequence.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools as itools 3 | from utils_ridge.interpdata import sincinterp2D, gabor_xfm2D, lanczosinterp2D 4 | 5 | class DataSequence(object): 6 | """DataSequence class provides a nice interface for handling data that is both continuous 7 | and discretely chunked. For example, semantic projections of speech stimuli must be 8 | considered both at the level of single words (which are continuous throughout the stimulus) 9 | and at the level of TRs (which contain discrete chunks of words). 10 | """ 11 | def __init__(self, data, split_inds, data_times=None, tr_times=None): 12 | """Initializes the DataSequence with the given [data] object (which can be any iterable) 13 | and a collection of [split_inds], which should be the indices where the data is split into 14 | separate TR chunks. 15 | """ 16 | self.data = data 17 | self.split_inds = split_inds 18 | self.data_times = data_times 19 | self.tr_times = tr_times 20 | 21 | def mapdata(self, fun): 22 | """Creates a new DataSequence where each element of [data] is produced by mapping the 23 | function [fun] onto this DataSequence's [data]. 24 | The [split_inds] are preserved exactly. 
25 | """ 26 | return DataSequence(self, map(fun, self.data), self.split_inds) 27 | 28 | def chunks(self): 29 | """Splits the stored [data] into the discrete chunks and returns them. 30 | """ 31 | return np.split(self.data, self.split_inds) 32 | 33 | def data_to_chunk_ind(self, dataind): 34 | """Returns the index of the chunk containing the data with the given index. 35 | """ 36 | zc = np.zeros((len(self.data),)) 37 | zc[dataind] = 1.0 38 | ch = np.array([ch.sum() for ch in np.split(zc, self.split_inds)]) 39 | return np.nonzero(ch)[0][0] 40 | 41 | def chunk_to_data_ind(self, chunkind): 42 | """Returns the indexes of the data contained in the chunk with the given index. 43 | """ 44 | return list(np.split(np.arange(len(self.data)), self.split_inds)[chunkind]) 45 | 46 | def chunkmeans(self): 47 | """Splits the stored [data] into the discrete chunks, then takes the mean of each chunk 48 | (this is assuming that [data] is a numpy array) and returns the resulting matrix with 49 | one row per chunk. 50 | """ 51 | dsize = self.data.shape[1] 52 | outmat = np.zeros((len(self.split_inds)+1, dsize)) 53 | for ci, c in enumerate(self.chunks()): 54 | if len(c): 55 | outmat[ci] = np.vstack(c).mean(0) 56 | 57 | return outmat 58 | 59 | def chunksums(self, interp="rect", **kwargs): 60 | """Splits the stored [data] into the discrete chunks, then takes the sum of each chunk 61 | (this is assuming that [data] is a numpy array) and returns the resulting matrix with 62 | one row per chunk. 63 | If [interp] is "sinc", the signal will be downsampled using a truncated sinc filter 64 | instead of a rectangular filter. 65 | if [interp] is "lanczos", the signal will be downsampled using a Lanczos filter. 66 | [kwargs] are passed to the interpolation function. 67 | """ 68 | if interp=="sinc": 69 | ## downsample using sinc filter 70 | return sincinterp2D(self.data, self.data_times, self.tr_times, **kwargs) 71 | elif interp=="lanczos": 72 | ## downsample using Lanczos filter 73 | return lanczosinterp2D(self.data, self.data_times, self.tr_times, **kwargs) 74 | elif interp=="gabor": 75 | ## downsample using Gabor filter 76 | return np.abs(gabor_xfm2D(self.data.T, self.data_times, self.tr_times, **kwargs)).T 77 | else: 78 | dsize = self.data.shape[1] 79 | outmat = np.zeros((len(self.split_inds)+1, dsize)) 80 | for ci, c in enumerate(self.chunks()): 81 | if len(c): 82 | outmat[ci] = np.vstack(c).sum(0) 83 | 84 | return outmat 85 | 86 | def copy(self): 87 | """Returns a copy of this DataSequence. 88 | """ 89 | return DataSequence(list(self.data), self.split_inds.copy(), self.data_times, self.tr_times) 90 | 91 | @classmethod 92 | def from_grid(cls, grid_transcript, trfile): 93 | """Creates a new DataSequence from a [grid_transript] and a [trfile]. 94 | grid_transcript should be the product of the 'make_simple_transcript' method of TextGrid. 
95 | """ 96 | data_entries = list(zip(*grid_transcript))[2] 97 | if isinstance(data_entries[0], str): 98 | data = list(map(str.lower, list(zip(*grid_transcript))[2])) 99 | else: 100 | data = data_entries 101 | word_starts = np.array(list(map(float, list(zip(*grid_transcript))[0]))) 102 | word_ends = np.array(list(map(float, list(zip(*grid_transcript))[1]))) 103 | word_avgtimes = (word_starts + word_ends)/2.0 104 | 105 | tr = trfile.avgtr 106 | trtimes = trfile.get_reltriggertimes() 107 | 108 | split_inds = [(word_starts<(t+tr)).sum() for t in trtimes][:-1] 109 | return cls(data, split_inds, word_avgtimes, trtimes+tr/2.0) 110 | 111 | @classmethod 112 | def from_chunks(cls, chunks): 113 | """The inverse operation of DataSequence.chunks(), this function concatenates 114 | the [chunks] and infers split_inds. 115 | """ 116 | lens = map(len, chunks) 117 | split_inds = np.cumsum(lens)[:-1] 118 | #data = reduce(list.__add__, map(list, chunks)) ## 2.26s for 10k 6-w chunks 119 | data = list(itools.chain(*map(list, chunks))) ## 19.6ms for 10k 6-w chunks 120 | return cls(data, split_inds) -------------------------------------------------------------------------------- /decoding/utils_ridge/dsutils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools as itools 3 | from utils_ridge.DataSequence import DataSequence 4 | 5 | DEFAULT_BAD_WORDS = frozenset(["sentence_start", "sentence_end", "br", "lg", "ls", "ns", "sp"]) 6 | 7 | def make_word_ds(grids, trfiles, bad_words=DEFAULT_BAD_WORDS): 8 | """Creates DataSequence objects containing the words from each grid, with any words appearing 9 | in the [bad_words] set removed. 10 | """ 11 | ds = dict() 12 | stories = grids.keys() 13 | for st in stories: 14 | grtranscript = grids[st].tiers[1].make_simple_transcript() 15 | ## Filter out bad words 16 | goodtranscript = [x for x in grtranscript 17 | if x[2].lower().strip("{}").strip() not in bad_words] 18 | d = DataSequence.from_grid(goodtranscript, trfiles[st][0]) 19 | ds[st] = d 20 | 21 | return ds 22 | 23 | def make_phoneme_ds(grids, trfiles): 24 | """Creates DataSequence objects containing the phonemes from each grid. 25 | """ 26 | ds = dict() 27 | stories = grids.keys() 28 | for st in stories: 29 | grtranscript = grids[st].tiers[0].make_simple_transcript() 30 | d = DataSequence.from_grid(grtranscript, trfiles[st][0]) 31 | ds[st] = d 32 | 33 | return ds 34 | 35 | phonemes = ['AA', 'AE','AH','AO','AW','AY','B','CH','D', 36 | 'DH', 'EH', 'ER', 'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 37 | 'K', 'L', 'M', 'N', 'NG', 'OW', 'OY', 'P', 'R', 'S', 'SH', 38 | 'T', 'TH', 'UH', 'UW', 'V', 'W', 'Y', 'Z', 'ZH'] 39 | 40 | def make_character_ds(grids, trfiles): 41 | ds = dict() 42 | stories = grids.keys() 43 | for st in stories: 44 | grtranscript = grids[st].tiers[2].make_simple_transcript() 45 | fixed_grtranscript = [(s,e,map(int, c.split(","))) for s,e,c in grtranscript if c] 46 | d = DataSequence.from_grid(fixed_grtranscript, trfiles[st][0]) 47 | ds[st] = d 48 | return ds 49 | 50 | def make_dialogue_ds(grids, trfiles): 51 | ds = dict() 52 | for st, gr in grids.iteritems(): 53 | grtranscript = gr.tiers[3].make_simple_transcript() 54 | fixed_grtranscript = [(s,e,c) for s,e,c in grtranscript if c] 55 | ds[st] = DataSequence.from_grid(fixed_grtranscript, trfiles[st][0]) 56 | return ds 57 | 58 | def histogram_phonemes(ds, phonemeset=phonemes): 59 | """Histograms the phonemes in the DataSequence [ds]. 
60 | """ 61 | olddata = ds.data 62 | N = len(ds.data) 63 | newdata = np.zeros((N, len(phonemeset))) 64 | phind = dict(enumerate(phonemeset)) 65 | for ii,ph in enumerate(olddata): 66 | try: 67 | #ind = phonemeset.index(ph.upper().strip("0123456789")) 68 | ind = phind[ph.upper().strip("0123456789")] 69 | newdata[ii][ind] = 1 70 | except Exception as e: 71 | pass 72 | 73 | return DataSequence(newdata, ds.split_inds, ds.data_times, ds.tr_times) 74 | 75 | def histogram_phonemes2(ds, phonemeset=phonemes): 76 | """Histograms the phonemes in the DataSequence [ds]. 77 | """ 78 | olddata = np.array([ph.upper().strip("0123456789") for ph in ds.data]) 79 | newdata = np.vstack([olddata==ph for ph in phonemeset]).T 80 | return DataSequence(newdata, ds.split_inds, ds.data_times, ds.tr_times) 81 | 82 | def make_semantic_model(ds, lsasm): 83 | newdata = [] 84 | for w in ds.data: 85 | try: 86 | v = lsasm[w] 87 | except KeyError as e: 88 | v = np.zeros((lsasm.data.shape[0],)) 89 | newdata.append(v) 90 | return DataSequence(np.array(newdata), ds.split_inds, ds.data_times, ds.tr_times) 91 | 92 | def make_character_model(dss): 93 | """Make character indicator model for a dict of datasequences. 94 | """ 95 | stories = dss.keys() 96 | storychars = dict([(st,np.unique(np.hstack(ds.data))) for st,ds in dss.iteritems()]) 97 | total_chars = sum(map(len, storychars.values())) 98 | char_inds = dict() 99 | ncharsdone = 0 100 | for st in stories: 101 | char_inds[st] = dict(zip(storychars[st], range(ncharsdone, ncharsdone+len(storychars[st])))) 102 | ncharsdone += len(storychars[st]) 103 | 104 | charmodels = dict() 105 | for st,ds in dss.iteritems(): 106 | charmat = np.zeros((len(ds.data), total_chars)) 107 | for ti,charlist in enumerate(ds.data): 108 | for char in charlist: 109 | charmat[ti, char_inds[st][char]] = 1 110 | charmodels[st] = DataSequence(charmat, ds.split_inds, ds.data_times, ds.tr_times) 111 | 112 | return charmodels, char_inds 113 | 114 | def make_dialogue_model(ds): 115 | return DataSequence(np.ones((len(ds.data),1)), ds.split_inds, ds.data_times, ds.tr_times) 116 | 117 | def modulate(ds, vec): 118 | """Multiplies each row (each word/phoneme) by the corresponding value in [vec]. 119 | """ 120 | return DataSequence((ds.data.T*vec).T, ds.split_inds, ds.data_times, ds.tr_times) 121 | 122 | def catmats(*seqs): 123 | keys = seqs[0].keys() 124 | return dict([(k, DataSequence(np.hstack([s[k].data for s in seqs]), seqs[0][k].split_inds)) for k in keys]) -------------------------------------------------------------------------------- /decoding/utils_ridge/interpdata.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | 4 | logger = logging.getLogger("text.regression.interpdata") 5 | 6 | def interpdata(data, oldtime, newtime): 7 | """Interpolates the columns of [data] to find the values at [newtime], given that the current 8 | values are at [oldtime]. [oldtime] must have the same number of elements as [data] has rows. 9 | """ 10 | ## Check input sizes ## 11 | if not len(oldtime) == data.shape[0]: 12 | raise IndexError("oldtime must have same number of elements as data has rows.") 13 | 14 | ## Set up matrix to hold output ## 15 | newdata = np.empty((len(newtime), data.shape[1])) 16 | 17 | ## Interpolate each column of data ## 18 | for ci in range(data.shape[1]): 19 | if (ci%100) == 0: 20 | logger.info("Interpolating column %d/%d.." 
% (ci+1, data.shape[1])) 21 | 22 | newdata[:,ci] = np.interp(newtime, oldtime, data[:,ci]) 23 | 24 | ## Return interpolated data ## 25 | return newdata 26 | 27 | def sincinterp1D(data, oldtime, newtime, cutoff_mult=1.0, window=1): 28 | """Interpolates the one-dimensional signal [data] at the times given by [newtime], assuming 29 | that each sample in [data] was collected at the corresponding time in [oldtime]. Clearly, 30 | [oldtime] and [data] must have the same length, but [newtime] can have any length. 31 | 32 | This function will assume that the time points in [newtime] are evenly spaced and will use 33 | that frequency multipled by [cutoff_mult] as the cutoff frequency of the sinc filter. 34 | 35 | The sinc function will be computed with [window] lobes. With [window]=1, this will 36 | effectively compute the Lanczos filter. 37 | 38 | This is a very simplistic filtering algorithm, so will take O(N*M) time, where N is the 39 | length of [oldtime] and M is the length of [newtime]. 40 | 41 | This filter is non-causal. 42 | """ 43 | ## Find the cutoff frequency ## 44 | cutoff = 1/np.mean(np.diff(newtime)) * cutoff_mult 45 | print("Doing sinc interpolation with cutoff=%0.3f and %d lobes." % (cutoff, window)) 46 | 47 | ## Construct new signal ## 48 | newdata = np.zeros((len(newtime),1)) 49 | for ndi in range(len(newtime)): 50 | for di in range(len(oldtime)): 51 | newdata[ndi] += sincfun(cutoff, newtime[ndi]-oldtime[di], window) * data[di] 52 | return newdata 53 | 54 | def sincinterp2D(data, oldtime, newtime, cutoff_mult=1.0, window=1, causal=False, renorm=True): 55 | """Interpolates the columns of [data], assuming that the i'th row of data corresponds to 56 | oldtime(i). A new matrix with the same number of columns and a number of rows given 57 | by the length of [newtime] is returned. If [causal], only past time points will be used 58 | to computed the present value, and future time points will be ignored. 59 | 60 | The time points in [newtime] are assumed to be evenly spaced, and their frequency will 61 | be used to calculate the low-pass cutoff of the sinc interpolation filter. 62 | 63 | [window] lobes of the sinc function will be used. [window] should be an integer. 64 | """ 65 | ## Find the cutoff frequency ## 66 | cutoff = 1/np.mean(np.diff(newtime)) * cutoff_mult 67 | print("Doing sinc interpolation with cutoff=%0.3f and %d lobes." % (cutoff, window)) 68 | 69 | ## Construct new signal ## 70 | # newdata = np.zeros((len(newtime), data.shape[1])) 71 | # for ndi in range(len(newtime)): 72 | # for di in range(len(oldtime)): 73 | # newdata[ndi,:] += sincfun(cutoff, newtime[ndi]-oldtime[di], window, causal) * data[di,:] 74 | 75 | ## Build up sinc matrix ## 76 | sincmat = np.zeros((len(newtime), len(oldtime))) 77 | for ndi in range(len(newtime)): 78 | sincmat[ndi,:] = sincfun(cutoff, newtime[ndi]-oldtime, window, causal, renorm) 79 | 80 | ## Construct new signal by multiplying the sinc matrix by the data ## 81 | newdata = np.dot(sincmat, data) 82 | 83 | return newdata 84 | 85 | def lanczosinterp2D(data, oldtime, newtime, window=3, cutoff_mult=1.0, rectify=False): 86 | """Interpolates the columns of [data], assuming that the i'th row of data corresponds to 87 | oldtime(i). A new matrix with the same number of columns and a number of rows given 88 | by the length of [newtime] is returned. 89 | 90 | The time points in [newtime] are assumed to be evenly spaced, and their frequency will 91 | be used to calculate the low-pass cutoff of the interpolation filter. 
92 | 93 | [window] lobes of the sinc function will be used. [window] should be an integer. 94 | """ 95 | ## Find the cutoff frequency ## 96 | cutoff = 1/np.mean(np.diff(newtime)) * cutoff_mult 97 | # print "Doing lanczos interpolation with cutoff=%0.3f and %d lobes." % (cutoff, window) 98 | 99 | ## Build up sinc matrix ## 100 | sincmat = np.zeros((len(newtime), len(oldtime))) 101 | for ndi in range(len(newtime)): 102 | sincmat[ndi,:] = lanczosfun(cutoff, newtime[ndi]-oldtime, window) 103 | 104 | if rectify: 105 | newdata = np.hstack([np.dot(sincmat, np.clip(data, -np.inf, 0)), 106 | np.dot(sincmat, np.clip(data, 0, np.inf))]) 107 | else: 108 | ## Construct new signal by multiplying the sinc matrix by the data ## 109 | newdata = np.dot(sincmat, data) 110 | 111 | return newdata 112 | 113 | def sincupinterp2D(data, oldtime, newtimes, cutoff, window=1): 114 | """Uses sinc interpolation to upsample the columns of [data], assuming that the i'th 115 | row of data comes from oldtime[i]. A new matrix with the same number of columns 116 | and a number of rows given by the length of [newtime] is returned. 117 | The times points in [oldtime] are assumed to be evenly spaced, and their frequency 118 | will be used to calculate the low-pass cutoff of the sinc interpolation filter. 119 | [window] lobes of the sinc function will be used. [window] should be an integer. 120 | Setting [window] to 1 yields a Lanczos filter. 121 | """ 122 | #cutoff = 1/np.mean(np.diff(oldtime)) 123 | print("Doing sinc interpolation with cutoff=%0.3f and %d lobes."%(cutoff, window)) 124 | 125 | sincmat = np.zeros((len(newtimes), len(oldtime))) 126 | for ndi in range(len(newtimes)): 127 | sincmat[ndi,:] = sincfun(cutoff, newtimes[ndi]-oldtime, window, False) 128 | 129 | newdata = np.dot(sincmat, data) 130 | return newdata 131 | 132 | def sincfun(B, t, window=np.inf, causal=False, renorm=True): 133 | """Compute the sinc function with some cutoff frequency [B] at some time [t]. 134 | [t] can be a scalar or any shaped numpy array. 135 | If given a [window], only the lowest-order [window] lobes of the sinc function 136 | will be non-zero. 137 | If [causal], only past values (i.e. t<0) will have non-zero weights. 138 | """ 139 | val = 2*B*np.sin(2*np.pi*B*t)/(2*np.pi*B*t+1e-20) 140 | if t.shape: 141 | val[np.abs(t)>window/(2*B)] = 0 142 | if causal: 143 | val[t<0] = 0 144 | if not np.sum(val)==0.0 and renorm: 145 | val = val/np.sum(val) 146 | elif np.abs(t)>window/(2*B): 147 | val = 0 148 | if causal and t<0: 149 | val = 0 150 | return val 151 | 152 | def lanczosfun(cutoff, t, window=3): 153 | """Compute the lanczos function with some cutoff frequency [B] at some time [t]. 154 | [t] can be a scalar or any shaped numpy array. 155 | If given a [window], only the lowest-order [window] lobes of the sinc function 156 | will be non-zero. 157 | """ 158 | t = t * cutoff 159 | val = window * np.sin(np.pi*t) * np.sin(np.pi*t/window) / (np.pi**2 * t**2) 160 | val[t==0] = 1.0 161 | val[np.abs(t)>window] = 0.0 162 | return val# / (val.sum() + 1e-10) 163 | 164 | def expinterp2D(data, oldtime, newtime, theta): 165 | intmat = np.zeros((len(newtime), len(oldtime))) 166 | for ndi in range(len(newtime)): 167 | intmat[ndi,:] = expfun(theta, newtime[ndi]-oldtime) 168 | 169 | ## Construct new signal by multiplying the sinc matrix by the data ## 170 | newdata = np.dot(intmat, data) 171 | return newdata 172 | 173 | def expfun(theta, t): 174 | """Computes an exponential weighting function for interpolation. 
175 | """ 176 | val = np.exp(-t*theta) 177 | val[t<0] = 0.0 178 | if not np.sum(val)==0.0: 179 | val = val/np.sum(val) 180 | return val 181 | 182 | def gabor_xfm(data, oldtimes, newtimes, freqs, sigma): 183 | sinvals = np.vstack([np.sin(oldtimes*f*2*np.pi) for f in freqs]) 184 | cosvals = np.vstack([np.cos(oldtimes*f*2*np.pi) for f in freqs]) 185 | outvals = np.zeros((len(newtimes), len(freqs)), dtype=np.complex128) 186 | for ti,t in enumerate(newtimes): 187 | ## Build gaussian function 188 | gaussvals = np.exp(-0.5*(oldtimes-t)**2/(2*sigma**2))*data 189 | ## Take product with sin/cos vals 190 | sprod = np.dot(sinvals, gaussvals) 191 | cprod = np.dot(cosvals, gaussvals) 192 | ## Store the output 193 | outvals[ti,:] = cprod + 1j*sprod 194 | 195 | return outvals 196 | 197 | def gabor_xfm2D(ddata, oldtimes, newtimes, freqs, sigma): 198 | return np.vstack([gabor_xfm(d, oldtimes, newtimes, freqs, sigma).T for d in ddata]) 199 | 200 | def test_interp(**kwargs): 201 | """Tests sincinterp2D passing it the given [kwargs] and interpolating known signals 202 | between the two time domains. 203 | """ 204 | oldtime = np.linspace(0, 10, 100) 205 | newtime = np.linspace(0, 10, 49) 206 | data = np.zeros((4, 100)) 207 | ## The first row has a single nonzero value 208 | data[0,50] = 1.0 209 | ## The second row has a few nonzero values in a row 210 | data[1,45:55] = 1.0 211 | ## The third row has a few nonzero values separated by zeros 212 | data[2,40:45] = 1.0 213 | data[2,55:60] = 1.0 214 | ## The fourth row has different values 215 | data[3,40:45] = 1.0 216 | data[3,55:60] = 2.0 217 | 218 | ## Interpolate the data 219 | interpdata = sincinterp2D(data.T, oldtime, newtime, **kwargs).T 220 | 221 | ## Plot the results 222 | from matplotlib.pyplot import figure, show 223 | fig = figure() 224 | for d in range(4): 225 | ax = fig.add_subplot(4,1,d+1) 226 | ax.plot(newtime, interpdata[d,:], 'go-') 227 | ax.plot(oldtime, data[d,:], 'bo-') 228 | 229 | #ax.tight() 230 | show() 231 | return newtime, interpdata -------------------------------------------------------------------------------- /decoding/utils_ridge/ridge.py: -------------------------------------------------------------------------------- 1 | #import scipy 2 | from functools import reduce 3 | import numpy as np 4 | import logging 5 | from utils_ridge.utils import mult_diag, counter 6 | import random 7 | import itertools as itools 8 | 9 | zs = lambda v: (v-v.mean(0))/v.std(0) ## z-score function 10 | 11 | def ridge(stim, resp, alpha, singcutoff=1e-10, normalpha=False): 12 | """Uses ridge regression to find a linear transformation of [stim] that approximates 13 | [resp]. The regularization parameter is [alpha]. 14 | Parameters 15 | ---------- 16 | stim : array_like, shape (T, N) 17 | Stimuli with T time points and N features. 18 | resp : array_like, shape (T, M) 19 | Responses with T time points and M separate responses. 20 | alpha : float or array_like, shape (M,) 21 | Regularization parameter. Can be given as a single value (which is applied to 22 | all M responses) or separate values for each response. 23 | normalpha : boolean 24 | Whether ridge parameters should be normalized by the largest singular value of stim. Good for 25 | comparing models with different numbers of parameters. 26 | Returns 27 | ------- 28 | wt : array_like, shape (N, M) 29 | Linear regression weights. 
30 | """ 31 | try: 32 | U,S,Vh = np.linalg.svd(stim, full_matrices=False) 33 | except np.linalg.LinAlgError: 34 | from text.regression.svd_dgesvd import svd_dgesvd 35 | U,S,Vh = svd_dgesvd(stim, full_matrices=False) 36 | 37 | UR = np.dot(U.T, np.nan_to_num(resp)) 38 | 39 | # Expand alpha to a collection if it's just a single value 40 | if isinstance(alpha, (float,int)): 41 | alpha = np.ones(resp.shape[1]) * alpha 42 | 43 | # Normalize alpha by the LSV norm 44 | norm = S[0] 45 | if normalpha: 46 | nalphas = alpha * norm 47 | else: 48 | nalphas = alpha 49 | 50 | # Compute weights for each alpha 51 | ualphas = np.unique(nalphas) 52 | wt = np.zeros((stim.shape[1], resp.shape[1])) 53 | for ua in ualphas: 54 | selvox = np.nonzero(nalphas==ua)[0] 55 | #awt = reduce(np.dot, [Vh.T, np.diag(S/(S**2+ua**2)), UR[:,selvox]]) 56 | awt = Vh.T.dot(np.diag(S/(S**2+ua**2))).dot(UR[:,selvox]) 57 | wt[:,selvox] = awt 58 | 59 | return wt 60 | 61 | def ridge_corr(Rstim, Pstim, Rresp, Presp, alphas, normalpha=False, dtype=np.single, corrmin=0.2, 62 | singcutoff=1e-10, use_corr=True, logger=logging.getLogger("ridge_corr")): 63 | """Uses ridge regression to find a linear transformation of [Rstim] that approximates [Rresp]. 64 | Then tests by comparing the transformation of [Pstim] to [Presp]. This procedure is repeated 65 | for each regularization parameter alpha in [alphas]. The correlation between each prediction and 66 | each response for each alpha is returned. Note that the regression weights are NOT returned. 67 | Parameters 68 | ---------- 69 | Rstim : array_like, shape (TR, N) 70 | Training stimuli with TR time points and N features. Each feature should be Z-scored across time. 71 | Pstim : array_like, shape (TP, N) 72 | Test stimuli with TP time points and N features. Each feature should be Z-scored across time. 73 | Rresp : array_like, shape (TR, M) 74 | Training responses with TR time points and M responses (voxels, neurons, what-have-you). 75 | Each response should be Z-scored across time. 76 | Presp : array_like, shape (TP, M) 77 | Test responses with TP time points and M responses. 78 | alphas : list or array_like, shape (A,) 79 | Ridge parameters to be tested. Should probably be log-spaced. np.logspace(0, 3, 20) works well. 80 | normalpha : boolean 81 | Whether ridge parameters should be normalized by the Frobenius norm of Rstim. Good for 82 | comparing models with different numbers of parameters. 83 | dtype : np.dtype 84 | All data will be cast as this dtype for computation. np.single is used by default for memory 85 | efficiency. 86 | corrmin : float in [0..1] 87 | Purely for display purposes. After each alpha is tested, the number of responses with correlation 88 | greater than corrmin minus the number of responses with correlation less than negative corrmin 89 | will be printed. For long-running regressions this vague metric of non-centered skewness can 90 | give you a rough sense of how well the model is working before it's done. 91 | singcutoff : float 92 | The first step in ridge regression is computing the singular value decomposition (SVD) of the 93 | stimulus Rstim. If Rstim is not full rank, some singular values will be approximately equal 94 | to zero and the corresponding singular vectors will be noise. These singular values/vectors 95 | should be removed both for speed (the fewer multiplications the better!) and accuracy. Any 96 | singular values less than singcutoff will be removed. 97 | use_corr : boolean 98 | If True, this function will use correlation as its metric of model fit. 
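A quick sanity check on the SVD form used in `ridge` above (a sketch, not part of ridge.py): because the weights are computed as `Vh.T @ diag(S/(S**2 + alpha**2)) @ U.T @ resp`, they match the normal-equations solution `(X.T X + alpha**2 I)^-1 X.T Y`. In other words, the module effectively squares `alpha` before applying the penalty.

```python
import numpy as np
from utils_ridge.ridge import ridge

rng = np.random.RandomState(0)
X = rng.randn(200, 30)   # stimuli: 200 time points, 30 features
Y = rng.randn(200, 5)    # responses: 200 time points, 5 voxels
alpha = 10.0

wt_svd = ridge(X, Y, alpha)                                         # SVD-based solution above
wt_ref = np.linalg.solve(X.T @ X + alpha**2 * np.eye(30), X.T @ Y)  # closed-form reference
print(np.allclose(wt_svd, wt_ref))  # True
```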
If False, this function 99 | will instead use variance explained (R-squared) as its metric of model fit. For ridge regression 100 | this can make a big difference -- highly regularized solutions will have very small norms and 101 | will thus explain very little variance while still leading to high correlations, as correlation 102 | is scale-free while R**2 is not. 103 | Returns 104 | ------- 105 | Rcorrs : array_like, shape (A, M) 106 | The correlation between each predicted response and each column of Presp for each alpha. 107 | 108 | """ 109 | ## Calculate SVD of stimulus matrix 110 | logger.info("Doing SVD...") 111 | try: 112 | U,S,Vh = np.linalg.svd(Rstim, full_matrices=False) 113 | except np.linalg.LinAlgError as e: 114 | logger.info("NORMAL SVD FAILED, trying more robust dgesvd..") 115 | from text.regression.svd_dgesvd import svd_dgesvd 116 | U,S,Vh = svd_dgesvd(Rstim, full_matrices=False) 117 | 118 | ## Truncate tiny singular values for speed 119 | origsize = S.shape[0] 120 | ngoodS = np.sum(S>singcutoff) 121 | nbad = origsize-ngoodS 122 | U = U[:,:ngoodS] 123 | S = S[:ngoodS] 124 | Vh = Vh[:ngoodS] 125 | logger.info("Dropped %d tiny singular values.. (U is now %s)"%(nbad, str(U.shape))) 126 | 127 | ## Normalize alpha by the Frobenius norm 128 | #frob = np.sqrt((S**2).sum()) ## Frobenius! 129 | frob = S[0] 130 | #frob = S.sum() 131 | logger.info("Training stimulus has Frobenius norm: %0.03f"%frob) 132 | if normalpha: 133 | nalphas = alphas * frob 134 | else: 135 | nalphas = alphas 136 | 137 | ## Precompute some products for speed 138 | UR = np.dot(U.T, Rresp) ## Precompute this matrix product for speed 139 | PVh = np.dot(Pstim, Vh.T) ## Precompute this matrix product for speed 140 | 141 | #Prespnorms = np.apply_along_axis(np.linalg.norm, 0, Presp) ## Precompute test response norms 142 | zPresp = zs(Presp) 143 | Prespvar = Presp.var(0) 144 | Rcorrs = [] ## Holds training correlations for each alpha 145 | for na, a in zip(nalphas, alphas): 146 | #D = np.diag(S/(S**2+a**2)) ## Reweight singular vectors by the ridge parameter 147 | D = S/(S**2+na**2) ## Reweight singular vectors by the (normalized?) 
ridge parameter 148 | 149 | pred = np.dot(mult_diag(D, PVh, left=False), UR) ## Best (1.75 seconds to prediction in test) 150 | # pred = np.dot(mult_diag(D, np.dot(Pstim, Vh.T), left=False), UR) ## Better (2.0 seconds to prediction in test) 151 | 152 | # pvhd = reduce(np.dot, [Pstim, Vh.T, D]) ## Pretty good (2.4 seconds to prediction in test) 153 | # pred = np.dot(pvhd, UR) 154 | 155 | # wt = reduce(np.dot, [Vh.T, D, UR]).astype(dtype) ## Bad (14.2 seconds to prediction in test) 156 | # wt = reduce(np.dot, [Vh.T, D, U.T, Rresp]).astype(dtype) ## Worst 157 | # pred = np.dot(Pstim, wt) ## Predict test responses 158 | 159 | if use_corr: 160 | #prednorms = np.apply_along_axis(np.linalg.norm, 0, pred) ## Compute predicted test response norms 161 | #Rcorr = np.array([np.corrcoef(Presp[:,ii], pred[:,ii].ravel())[0,1] for ii in range(Presp.shape[1])]) ## Slowly compute correlations 162 | #Rcorr = np.array(np.sum(np.multiply(Presp, pred), 0)).squeeze()/(prednorms*Prespnorms) ## Efficiently compute correlations 163 | Rcorr = (zPresp*zs(pred)).mean(0) 164 | else: 165 | ## Compute variance explained 166 | resvar = (Presp-pred).var(0) 167 | Rcorr = np.clip(1-(resvar/Prespvar), 0, 1) 168 | 169 | Rcorr[np.isnan(Rcorr)] = 0 170 | Rcorrs.append(Rcorr) 171 | 172 | log_template = "Training: alpha=%0.3f, mean corr=%0.5f, max corr=%0.5f, over-under(%0.2f)=%d" 173 | log_msg = log_template % (a, 174 | np.mean(Rcorr), 175 | np.max(Rcorr), 176 | corrmin, 177 | (Rcorr>corrmin).sum()-(-Rcorr>corrmin).sum()) 178 | if logger is not None: 179 | logger.info(log_msg) 180 | else: 181 | print (log_msg) 182 | 183 | return Rcorrs 184 | 185 | def bootstrap_ridge(Rstim, Rresp, alphas, nboots, chunklen, nchunks, dtype=np.single, 186 | corrmin=0.2, joined=None, singcutoff=1e-10, normalpha=False, single_alpha=False, 187 | use_corr=True, logger=logging.getLogger("ridge_corr")): 188 | """Uses ridge regression with a bootstrapped held-out set to get optimal alpha values for each response. 189 | [nchunks] random chunks of length [chunklen] will be taken from [Rstim] and [Rresp] for each regression 190 | run. [nboots] total regression runs will be performed. The best alpha value for each response will be 191 | averaged across the bootstraps to estimate the best alpha for that response. 192 | 193 | If [joined] is given, it should be a list of lists where the STRFs for all the voxels in each sublist 194 | will be given the same regularization parameter (the one that is the best on average). 195 | 196 | Parameters 197 | ---------- 198 | Rstim : array_like, shape (TR, N) 199 | Training stimuli with TR time points and N features. Each feature should be Z-scored across time. 200 | Rresp : array_like, shape (TR, M) 201 | Training responses with TR time points and M different responses (voxels, neurons, what-have-you). 202 | Each response should be Z-scored across time. 203 | alphas : list or array_like, shape (A,) 204 | Ridge parameters that will be tested. Should probably be log-spaced. np.logspace(0, 3, 20) works well. 205 | nboots : int 206 | The number of bootstrap samples to run. 15 to 30 works well. 207 | chunklen : int 208 | On each sample, the training data is broken into chunks of this length. This should be a few times 209 | longer than your delay/STRF. e.g. for a STRF with 3 delays, I use chunks of length 10. 210 | nchunks : int 211 | The number of training chunks held out to test ridge parameters for each bootstrap sample. 
The product 212 | of nchunks and chunklen is the total number of training samples held out for each sample, and this 213 | product should be about 20 percent of the total length of the training data. 214 | dtype : np.dtype 215 | All data will be cast as this dtype for computation. np.single is used by default for memory efficiency, 216 | as using np.double will thrash most machines on a big problem. If you want to do regression on 217 | complex variables, this should be changed to np.complex128. 218 | corrmin : float in [0..1] 219 | Purely for display purposes. After each alpha is tested for each bootstrap sample, the number of 220 | responses with correlation greater than this value will be printed. For long-running regressions this 221 | can give a rough sense of how well the model works before it's done. 222 | joined : None or list of array_like indices 223 | If you want the STRFs for two (or more) responses to be directly comparable, you need to ensure that 224 | the regularization parameter that they use is the same. To do that, supply a list of the response sets 225 | that should use the same ridge parameter here. For example, if you have four responses, joined could 226 | be [np.array([0,1]), np.array([2,3])], in which case responses 0 and 1 will use the same ridge parameter 227 | (which will be parameter that is best on average for those two), and likewise for responses 2 and 3. 228 | singcutoff : float 229 | The first step in ridge regression is computing the singular value decomposition (SVD) of the 230 | stimulus Rstim. If Rstim is not full rank, some singular values will be approximately equal 231 | to zero and the corresponding singular vectors will be noise. These singular values/vectors 232 | should be removed both for speed (the fewer multiplications the better!) and accuracy. Any 233 | singular values less than singcutoff will be removed. 234 | normalpha : boolean 235 | Whether ridge parameters (alphas) should be normalized by the Frobenius norm of Rstim. Good for rigorously 236 | comparing models with different numbers of parameters. 237 | single_alpha : boolean 238 | Whether to use a single alpha for all responses. Good for identification/decoding. 239 | use_corr : boolean 240 | If True, this function will use correlation as its metric of model fit. If False, this function 241 | will instead use variance explained (R-squared) as its metric of model fit. For ridge regression 242 | this can make a big difference -- highly regularized solutions will have very small norms and 243 | will thus explain very little variance while still leading to high correlations, as correlation 244 | is scale-free while R**2 is not. 245 | 246 | Returns 247 | ------- 248 | wt : array_like, shape (N, M) 249 | Regression weights for N features and M responses. 250 | corrs : array_like, shape (M,) 251 | Validation set correlations. Predicted responses for the validation set are obtained using the regression 252 | weights: pred = np.dot(Pstim, wt), and then the correlation between each predicted response and each 253 | column in Presp is found. 254 | alphas : array_like, shape (M,) 255 | The regularization coefficient (alpha) selected for each voxel using bootstrap cross-validation. 256 | bootstrap_corrs : array_like, shape (A, M, B) 257 | Correlation between predicted and actual responses on randomly held out portions of the training set, 258 | for each of A alphas, M voxels, and B bootstrap samples. 
259 | valinds : array_like, shape (TH, B) 260 | The indices of the training data that were used as "validation" for each bootstrap sample. 261 | """ 262 | nresp, nvox = Rresp.shape 263 | bestalphas = np.zeros((nboots, nvox)) ## Will hold the best alphas for each voxel 264 | valinds = [] ## Will hold the indices into the validation data for each bootstrap 265 | 266 | Rcmats = [] 267 | for bi in counter(range(nboots), countevery=1, total=nboots): 268 | logger.info("Selecting held-out test set..") 269 | allinds = range(nresp) 270 | indchunks = list(zip(*[iter(allinds)]*chunklen)) 271 | random.shuffle(indchunks) 272 | heldinds = list(itools.chain(*indchunks[:nchunks])) 273 | notheldinds = list(set(allinds)-set(heldinds)) 274 | valinds.append(heldinds) 275 | 276 | RRstim = Rstim[notheldinds,:] 277 | PRstim = Rstim[heldinds,:] 278 | RRresp = Rresp[notheldinds,:] 279 | PRresp = Rresp[heldinds,:] 280 | 281 | ## Run ridge regression using this test set 282 | Rcmat = ridge_corr(RRstim, PRstim, RRresp, PRresp, alphas, 283 | dtype=dtype, corrmin=corrmin, singcutoff=singcutoff, 284 | normalpha=normalpha, use_corr=use_corr) 285 | 286 | Rcmats.append(Rcmat) 287 | 288 | ## Find weights for each voxel 289 | try: 290 | U,S,Vh = np.linalg.svd(Rstim, full_matrices=False) 291 | except np.linalg.LinAlgError as e: 292 | logger.info("NORMAL SVD FAILED, trying more robust dgesvd..") 293 | from text.regression.svd_dgesvd import svd_dgesvd 294 | U,S,Vh = svd_dgesvd(Rstim, full_matrices=False) 295 | 296 | ## Normalize alpha by the Frobenius norm 297 | #frob = np.sqrt((S**2).sum()) ## Frobenius! 298 | frob = S[0] 299 | #frob = S.sum() 300 | logger.info("Total training stimulus has Frobenius norm: %0.03f"%frob) 301 | if normalpha: 302 | nalphas = alphas * frob 303 | else: 304 | nalphas = alphas 305 | 306 | allRcorrs = np.dstack(Rcmats) 307 | if not single_alpha: 308 | logger.info("Finding best alpha for each response..") 309 | if joined is None: 310 | ## Find best alpha for each voxel 311 | meanbootcorrs = allRcorrs.mean(2) 312 | bestalphainds = np.argmax(meanbootcorrs, 0) 313 | valphas = nalphas[bestalphainds] 314 | else: 315 | ## Find best alpha for each group of voxels 316 | valphas = np.zeros((nvox,)) 317 | for jl in joined: 318 | jcorrs = allRcorrs[:,jl,:].mean(1).mean(1) ## Mean across voxels in the set, then mean across bootstraps 319 | bestalpha = np.argmax(jcorrs) 320 | valphas[jl] = nalphas[bestalpha] 321 | else: 322 | logger.info("Finding single best alpha..") 323 | meanbootcorr = allRcorrs.mean(2).mean(1) 324 | bestalphaind = np.argmax(meanbootcorr) 325 | bestalpha = alphas[bestalphaind] 326 | valphas = np.array([bestalpha]*nvox) 327 | logger.info("Best alpha = %0.3f"%bestalpha) 328 | 329 | logger.info("Computing weights for each response using entire training set..") 330 | UR = np.dot(U.T, np.nan_to_num(Rresp)) 331 | wt = np.zeros((Rstim.shape[1], Rresp.shape[1])) 332 | for ai,alpha in enumerate(nalphas): 333 | selvox = np.nonzero(valphas==alpha)[0] 334 | awt = reduce(np.dot, [Vh.T, np.diag(S/(S**2+alpha**2)), UR[:,selvox]]) 335 | wt[:,selvox] = awt 336 | 337 | return wt, valphas, allRcorrs -------------------------------------------------------------------------------- /decoding/utils_ridge/stimulus_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from os.path import join, dirname 4 | 5 | from utils_ridge.textgrid import TextGrid 6 | 7 | def load_textgrids(stories, data_dir: str): 8 | base = join(data_dir, "train_stimulus") 
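A usage sketch for the `bootstrap_ridge` routine shown above, on small random data with hypothetical shapes (not taken from the repository's training scripts). Note that, as written, the function returns three values -- the weights, the per-voxel alphas, and the bootstrap correlations -- rather than the five listed in its docstring.

```python
import numpy as np
from utils_ridge.ridge import bootstrap_ridge

rng = np.random.RandomState(0)
T, N, M = 400, 20, 8                                # time points, stimulus features, voxels
Rstim = rng.randn(T, N)
Rresp = Rstim @ rng.randn(N, M) + rng.randn(T, M)   # noisy linear responses

alphas = np.logspace(0, 3, 10)
wt, valphas, bscorrs = bootstrap_ridge(Rstim, Rresp, alphas,
                                       nboots=3, chunklen=10, nchunks=8)
print(wt.shape, valphas.shape, bscorrs.shape)  # (20, 8) (8,) (10, 8, 3)
```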
9 | grids = {} 10 | for story in stories: 11 | grid_path = os.path.join(base, "%s.TextGrid" % story) 12 | grids[story] = TextGrid(open(grid_path).read()) 13 | return grids 14 | 15 | class TRFile(object): 16 | def __init__(self, trfilename, expectedtr=2.0045): 17 | """Loads data from [trfilename], should be output from stimulus presentation code. 18 | """ 19 | self.trtimes = [] 20 | self.soundstarttime = -1 21 | self.soundstoptime = -1 22 | self.otherlabels = [] 23 | self.expectedtr = expectedtr 24 | 25 | if trfilename is not None: 26 | self.load_from_file(trfilename) 27 | 28 | 29 | def load_from_file(self, trfilename): 30 | """Loads TR data from report with given [trfilename]. 31 | """ 32 | ## Read the report file and populate the datastructure 33 | for ll in open(trfilename): 34 | timestr = ll.split()[0] 35 | label = " ".join(ll.split()[1:]) 36 | time = float(timestr) 37 | 38 | if label in ("init-trigger", "trigger"): 39 | self.trtimes.append(time) 40 | 41 | elif label=="sound-start": 42 | self.soundstarttime = time 43 | 44 | elif label=="sound-stop": 45 | self.soundstoptime = time 46 | 47 | else: 48 | self.otherlabels.append((time, label)) 49 | 50 | ## Fix weird TR times 51 | itrtimes = np.diff(self.trtimes) 52 | badtrtimes = np.nonzero(itrtimes>(itrtimes.mean()*1.5))[0] 53 | newtrs = [] 54 | for btr in badtrtimes: 55 | ## Insert new TR where it was missing.. 56 | newtrtime = self.trtimes[btr]+self.expectedtr 57 | newtrs.append((newtrtime,btr)) 58 | 59 | for ntr,btr in newtrs: 60 | self.trtimes.insert(btr+1, ntr) 61 | 62 | def simulate(self, ntrs): 63 | """Simulates [ntrs] TRs that occur at the expected TR. 64 | """ 65 | self.trtimes = list(np.arange(ntrs)*self.expectedtr) 66 | 67 | def get_reltriggertimes(self): 68 | """Returns the times of all trigger events relative to the sound. 69 | """ 70 | return np.array(self.trtimes)-self.soundstarttime 71 | 72 | @property 73 | def avgtr(self): 74 | """Returns the average TR for this run. 75 | """ 76 | return np.diff(self.trtimes).mean() 77 | 78 | def load_simulated_trfiles(respdict, tr=2.0, start_time=10.0, pad=5): 79 | trdict = dict() 80 | for story, resps in respdict.items(): 81 | trf = TRFile(None, tr) 82 | trf.soundstarttime = start_time 83 | trf.simulate(resps - pad) 84 | trdict[story] = [trf] 85 | return trdict -------------------------------------------------------------------------------- /decoding/utils_ridge/textgrid.py: -------------------------------------------------------------------------------- 1 | # Natural Language Toolkit: TextGrid analysis 2 | # 3 | # Copyright (C) 2001-2011 NLTK Project 4 | # Author: Margaret Mitchell 5 | # Steven Bird (revisions) 6 | # URL: 7 | # For license information, see LICENSE.TXT 8 | # 9 | 10 | """ 11 | Tools for reading TextGrid files, the format used by Praat. 12 | 13 | Module contents 14 | =============== 15 | 16 | The textgrid corpus reader provides 4 data items and 1 function 17 | for each textgrid file. For each tier in the file, the reader 18 | provides 10 data items and 2 functions. 19 | 20 | For the full textgrid file: 21 | 22 | - size 23 | The number of tiers in the file. 24 | 25 | - xmin 26 | First marked time of the file. 27 | 28 | - xmax 29 | Last marked time of the file. 30 | 31 | - t_time 32 | xmax - xmin. 33 | 34 | - text_type 35 | The style of TextGrid format: 36 | - ooTextFile: Organized by tier. 37 | - ChronTextFile: Organized by time. 38 | - OldooTextFile: Similar to ooTextFile. 39 | 40 | - to_chron() 41 | Convert given file to a ChronTextFile format. 
42 | 43 | - to_oo() 44 | Convert given file to an ooTextFile format. 45 | 46 | For each tier: 47 | 48 | - text_type 49 | The style of TextGrid format, as above. 50 | 51 | - classid 52 | The style of transcription on this tier: 53 | - IntervalTier: Transcription is marked as intervals. 54 | - TextTier: Transcription is marked as single points. 55 | 56 | - nameid 57 | The name of the tier. 58 | 59 | - xmin 60 | First marked time of the tier. 61 | 62 | - xmax 63 | Last marked time of the tier. 64 | 65 | - size 66 | Number of entries in the tier. 67 | 68 | - transcript 69 | The raw transcript for the tier. 70 | 71 | - simple_transcript 72 | The transcript formatted as a list of tuples: (time1, time2, utterance). 73 | 74 | - tier_info 75 | List of (classid, nameid, xmin, xmax, size, transcript). 76 | 77 | - min_max() 78 | A tuple of (xmin, xmax). 79 | 80 | - time(non_speech_marker) 81 | Returns the utterance time of a given tier. 82 | Excludes entries that begin with a non-speech marker. 83 | 84 | """ 85 | 86 | # needs more cleanup, subclassing, epydoc docstrings 87 | 88 | import sys 89 | import re 90 | 91 | TEXTTIER = "TextTier" 92 | INTERVALTIER = "IntervalTier" 93 | 94 | OOTEXTFILE = re.compile(r"""(?x) 95 | xmin\ =\ (.*)[\r\n]+ 96 | xmax\ =\ (.*)[\r\n]+ 97 | [\s\S]+?size\ =\ (.*)[\r\n]+ 98 | """) 99 | 100 | CHRONTEXTFILE = re.compile(r"""(?x) 101 | [\r\n]+(\S+)\ 102 | (\S+)\ +!\ Time\ domain.\ *[\r\n]+ 103 | (\S+)\ +!\ Number\ of\ tiers.\ *[\r\n]+" 104 | """) 105 | 106 | OLDOOTEXTFILE = re.compile(r"""(?x) 107 | [\r\n]+(\S+) 108 | [\r\n]+(\S+) 109 | [\r\n]+.+[\r\n]+(\S+) 110 | """) 111 | 112 | 113 | 114 | ################################################################# 115 | # TextGrid Class 116 | ################################################################# 117 | 118 | class TextGrid(object): 119 | """ 120 | Class to manipulate the TextGrid format used by Praat. 121 | Separates each tier within this file into its own Tier 122 | object. Each TextGrid object has 123 | a number of tiers (size), xmin, xmax, a text type to help 124 | with the different styles of TextGrid format, and tiers with their 125 | own attributes. 126 | """ 127 | 128 | def __init__(self, read_file): 129 | """ 130 | Takes open read file as input, initializes attributes 131 | of the TextGrid file. 132 | @type read_file: An open TextGrid file, mode "r". 133 | @param size: Number of tiers. 134 | @param xmin: xmin. 135 | @param xmax: xmax. 136 | @param t_time: Total time of TextGrid file. 137 | @param text_type: TextGrid format. 138 | @type tiers: A list of tier objects. 139 | """ 140 | 141 | self.read_file = read_file 142 | self.size = 0 143 | self.xmin = 0 144 | self.xmax = 0 145 | self.t_time = 0 146 | self.text_type = self._check_type() 147 | self.tiers = self._find_tiers() 148 | 149 | def __iter__(self): 150 | for tier in self.tiers: 151 | yield tier 152 | 153 | def next(self): 154 | if self.idx == (self.size - 1): 155 | raise StopIteration 156 | self.idx += 1 157 | return self.tiers[self.idx] 158 | 159 | @staticmethod 160 | def load(file): 161 | """ 162 | @param file: a file in TextGrid format 163 | """ 164 | 165 | return TextGrid(open(file).read()) 166 | 167 | def _load_tiers(self, header): 168 | """ 169 | Iterates over each tier and grabs tier information. 
170 | """ 171 | 172 | tiers = [] 173 | if self.text_type == "ChronTextFile": 174 | m = re.compile(header) 175 | tier_headers = m.findall(self.read_file) 176 | tier_re = " \d+.?\d* \d+.?\d*[\r\n]+\"[^\"]*\"" 177 | for i in range(0, self.size): 178 | tier_info = [tier_headers[i]] + \ 179 | re.findall(str(i + 1) + tier_re, self.read_file) 180 | tier_info = "\n".join(tier_info) 181 | tiers.append(Tier(tier_info, self.text_type, self.t_time)) 182 | return tiers 183 | 184 | tier_re = header + "[\s\S]+?(?=" + header + "|$$)" 185 | m = re.compile(tier_re) 186 | tier_iter = m.finditer(self.read_file) 187 | for iterator in tier_iter: 188 | (begin, end) = iterator.span() 189 | tier_info = self.read_file[begin:end] 190 | tiers.append(Tier(tier_info, self.text_type, self.t_time)) 191 | return tiers 192 | 193 | def _check_type(self): 194 | """ 195 | Figures out the TextGrid format. 196 | """ 197 | 198 | m = re.match("(.*)[\r\n](.*)[\r\n](.*)[\r\n](.*)", self.read_file) 199 | try: 200 | type_id = m.group(1).strip() 201 | except AttributeError: 202 | raise TypeError("Cannot read file -- try TextGrid.load()") 203 | xmin = m.group(4) 204 | if type_id == "File type = \"ooTextFile\"": 205 | if "xmin" not in xmin: 206 | text_type = "OldooTextFile" 207 | else: 208 | text_type = "ooTextFile" 209 | elif type_id == "\"Praat chronological TextGrid text file\"": 210 | text_type = "ChronTextFile" 211 | else: 212 | raise TypeError("Unknown format '(%s)'", (type_id)) 213 | return text_type 214 | 215 | def _find_tiers(self): 216 | """ 217 | Splits the textgrid file into substrings corresponding to tiers. 218 | """ 219 | 220 | if self.text_type == "ooTextFile": 221 | m = OOTEXTFILE 222 | header = " +item \[" 223 | elif self.text_type == "ChronTextFile": 224 | m = CHRONTEXTFILE 225 | header = "\"\S+\" \".*\" \d+\.?\d* \d+\.?\d*" 226 | elif self.text_type == "OldooTextFile": 227 | m = OLDOOTEXTFILE 228 | header = "\".*\"[\r\n]+\".*\"" 229 | 230 | file_info = m.findall(self.read_file)[0] 231 | self.xmin = float(file_info[0]) 232 | self.xmax = float(file_info[1]) 233 | self.t_time = self.xmax - self.xmin 234 | self.size = int(file_info[2]) 235 | tiers = self._load_tiers(header) 236 | return tiers 237 | 238 | def to_chron(self): 239 | """ 240 | @return: String in Chronological TextGrid file format. 241 | """ 242 | 243 | chron_file = "" 244 | chron_file += "\"Praat chronological TextGrid text file\"\n" 245 | chron_file += str(self.xmin) + " " + str(self.xmax) 246 | chron_file += " ! Time domain.\n" 247 | chron_file += str(self.size) + " ! Number of tiers.\n" 248 | for tier in self.tiers: 249 | idx = (self.tiers.index(tier)) + 1 250 | tier_header = "\"" + tier.classid + "\" \"" \ 251 | + tier.nameid + "\" " + str(tier.xmin) \ 252 | + " " + str(tier.xmax) 253 | chron_file += tier_header + "\n" 254 | transcript = tier.simple_transcript 255 | for (xmin, xmax, utt) in transcript: 256 | chron_file += str(idx) + " " + str(xmin) 257 | chron_file += " " + str(xmax) +"\n" 258 | chron_file += "\"" + utt + "\"\n" 259 | return chron_file 260 | 261 | def to_oo(self): 262 | """ 263 | @return: A string in OoTextGrid file format. 264 | """ 265 | 266 | oo_file = "" 267 | oo_file += "File type = \"ooTextFile\"\n" 268 | oo_file += "Object class = \"TextGrid\"\n\n" 269 | oo_file += "xmin = ", self.xmin, "\n" 270 | oo_file += "xmax = ", self.xmax, "\n" 271 | oo_file += "tiers? 
\n" 272 | oo_file += "size = ", self.size, "\n" 273 | oo_file += "item []:\n" 274 | for i in range(len(self.tiers)): 275 | oo_file += "%4s%s [%s]" % ("", "item", i + 1) 276 | _curr_tier = self.tiers[i] 277 | for (x, y) in _curr_tier.header: 278 | oo_file += "%8s%s = \"%s\"" % ("", x, y) 279 | if _curr_tier.classid != TEXTTIER: 280 | for (xmin, xmax, text) in _curr_tier.simple_transcript: 281 | oo_file += "%12s%s = %s" % ("", "xmin", xmin) 282 | oo_file += "%12s%s = %s" % ("", "xmax", xmax) 283 | oo_file += "%12s%s = \"%s\"" % ("", "text", text) 284 | else: 285 | for (time, mark) in _curr_tier.simple_transcript: 286 | oo_file += "%12s%s = %s" % ("", "time", time) 287 | oo_file += "%12s%s = %s" % ("", "mark", mark) 288 | return oo_file 289 | 290 | 291 | ################################################################# 292 | # Tier Class 293 | ################################################################# 294 | 295 | class Tier(object): 296 | """ 297 | A container for each tier. 298 | """ 299 | 300 | def __init__(self, tier, text_type, t_time): 301 | """ 302 | Initializes attributes of the tier: class, name, xmin, xmax 303 | size, transcript, total time. 304 | Utilizes text_type to guide how to parse the file. 305 | @type tier: a tier object; single item in the TextGrid list. 306 | @param text_type: TextGrid format 307 | @param t_time: Total time of TextGrid file. 308 | @param classid: Type of tier (point or interval). 309 | @param nameid: Name of tier. 310 | @param xmin: xmin of the tier. 311 | @param xmax: xmax of the tier. 312 | @param size: Number of entries in the tier 313 | @param transcript: The raw transcript for the tier. 314 | """ 315 | 316 | self.tier = tier 317 | self.text_type = text_type 318 | self.t_time = t_time 319 | self.classid = "" 320 | self.nameid = "" 321 | self.xmin = 0 322 | self.xmax = 0 323 | self.size = 0 324 | self.transcript = "" 325 | self.tier_info = "" 326 | self._make_info() 327 | self.simple_transcript = self.make_simple_transcript() 328 | if self.classid != TEXTTIER: 329 | self.mark_type = "intervals" 330 | else: 331 | self.mark_type = "points" 332 | self.header = [("class", self.classid), ("name", self.nameid), \ 333 | ("xmin", self.xmin), ("xmax", self.xmax), ("size", self.size)] 334 | 335 | def __iter__(self): 336 | return self 337 | 338 | def _make_info(self): 339 | """ 340 | Figures out most attributes of the tier object: 341 | class, name, xmin, xmax, transcript. 342 | """ 343 | 344 | trans = "([\S\s]*)" 345 | if self.text_type == "ChronTextFile": 346 | classid = "\"(.*)\" +" 347 | nameid = "\"(.*)\" +" 348 | xmin = "(\d+\.?\d*) +" 349 | xmax = "(\d+\.?\d*) *[\r\n]+" 350 | # No size values are given in the Chronological Text File format. 
351 | self.size = None 352 | size = "" 353 | elif self.text_type == "ooTextFile": 354 | classid = " +class = \"(.*)\" *[\r\n]+" 355 | nameid = " +name = \"(.*)\" *[\r\n]+" 356 | xmin = " +xmin = (\d+\.?\d*) *[\r\n]+" 357 | xmax = " +xmax = (\d+\.?\d*) *[\r\n]+" 358 | size = " +\S+: size = (\d+) *[\r\n]+" 359 | elif self.text_type == "OldooTextFile": 360 | classid = "\"(.*)\" *[\r\n]+" 361 | nameid = "\"(.*)\" *[\r\n]+" 362 | xmin = "(\d+\.?\d*) *[\r\n]+" 363 | xmax = "(\d+\.?\d*) *[\r\n]+" 364 | size = "(\d+) *[\r\n]+" 365 | m = re.compile(classid + nameid + xmin + xmax + size + trans) 366 | self.tier_info = m.findall(self.tier)[0] 367 | self.classid = self.tier_info[0] 368 | self.nameid = self.tier_info[1] 369 | self.xmin = float(self.tier_info[2]) 370 | self.xmax = float(self.tier_info[3]) 371 | if self.size != None: 372 | self.size = int(self.tier_info[4]) 373 | self.transcript = self.tier_info[-1] 374 | 375 | def make_simple_transcript(self): 376 | """ 377 | @return: Transcript of the tier, in form [(start_time end_time label)] 378 | """ 379 | 380 | if self.text_type == "ChronTextFile": 381 | trans_head = "" 382 | trans_xmin = " (\S+)" 383 | trans_xmax = " (\S+)[\r\n]+" 384 | trans_text = "\"([\S\s]*?)\"" 385 | elif self.text_type == "ooTextFile": 386 | trans_head = " +\S+ \[\d+\]: *[\r\n]+" 387 | trans_xmin = " +\S+ = (\S+) *[\r\n]+" 388 | trans_xmax = " +\S+ = (\S+) *[\r\n]+" 389 | trans_text = " +\S+ = \"([^\"]*?)\"" 390 | elif self.text_type == "OldooTextFile": 391 | trans_head = "" 392 | trans_xmin = "(.*)[\r\n]+" 393 | trans_xmax = "(.*)[\r\n]+" 394 | trans_text = "\"([\S\s]*?)\"" 395 | if self.classid == TEXTTIER: 396 | trans_xmin = "" 397 | trans_m = re.compile(trans_head + trans_xmin + trans_xmax + trans_text) 398 | self.simple_transcript = trans_m.findall(self.transcript) 399 | return self.simple_transcript 400 | 401 | def transcript(self): 402 | """ 403 | @return: Transcript of the tier, as it appears in the file. 404 | """ 405 | 406 | return self.transcript 407 | 408 | def time(self, non_speech_char="."): 409 | """ 410 | @return: Utterance time of a given tier. 411 | Screens out entries that begin with a non-speech marker. 412 | """ 413 | 414 | total = 0.0 415 | if self.classid != TEXTTIER: 416 | for (time1, time2, utt) in self.simple_transcript: 417 | utt = utt.strip() 418 | if utt and not utt[0] == ".": 419 | total += (float(time2) - float(time1)) 420 | return total 421 | 422 | def tier_name(self): 423 | """ 424 | @return: Tier name of a given tier. 425 | """ 426 | 427 | return self.nameid 428 | 429 | def classid(self): 430 | """ 431 | @return: Type of transcription on tier. 432 | """ 433 | 434 | return self.classid 435 | 436 | def min_max(self): 437 | """ 438 | @return: (xmin, xmax) tuple for a given tier. 439 | """ 440 | 441 | return (self.xmin, self.xmax) 442 | 443 | def __repr__(self): 444 | return "<%s \"%s\" (%.2f, %.2f) %.2f%%>" % (self.classid, self.nameid, self.xmin, self.xmax, 100*self.time()/self.t_time) 445 | 446 | def __str__(self): 447 | return self.__repr__() + "\n " + "\n ".join(" ".join(row) for row in self.simple_transcript) 448 | 449 | def demo_TextGrid(demo_data): 450 | print ("** Demo of the TextGrid class. **") 451 | 452 | fid = TextGrid(demo_data) 453 | print ("Tiers:", fid.size) 454 | 455 | for i, tier in enumerate(fid): 456 | print ("\n***") 457 | print ("Tier:", i + 1) 458 | print (tier) 459 | 460 | def demo(): 461 | # Each demo demonstrates different TextGrid formats. 
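A parsing sketch using the `TextGrid` and `Tier` classes above (not part of the original file): it reads the `demo_data1` example string defined further down in this module and pulls `(start, stop, word)` tuples from the first interval tier, which is essentially what `load_textgrids` and the word-sequence utilities elsewhere in this repository do with the real stimulus TextGrids.

```python
from utils_ridge.textgrid import TextGrid, demo_data1

grid = TextGrid(demo_data1)
print(grid.size, grid.xmin, grid.xmax)      # number of tiers, start time, end time

utterances = grid.tiers[0]                  # the "utterances" IntervalTier
for start, stop, word in utterances.simple_transcript:
    if word.strip():                        # skip empty intervals
        print(float(start), float(stop), word)
```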
462 | print ("Format 1") 463 | demo_TextGrid(demo_data1) 464 | print ("\nFormat 2") 465 | demo_TextGrid(demo_data2) 466 | print ("\nFormat 3") 467 | demo_TextGrid(demo_data3) 468 | 469 | 470 | demo_data1 = """File type = "ooTextFile" 471 | Object class = "TextGrid" 472 | 473 | xmin = 0 474 | xmax = 2045.144149659864 475 | tiers? 476 | size = 3 477 | item []: 478 | item [1]: 479 | class = "IntervalTier" 480 | name = "utterances" 481 | xmin = 0 482 | xmax = 2045.144149659864 483 | intervals: size = 5 484 | intervals [1]: 485 | xmin = 0 486 | xmax = 2041.4217474125382 487 | text = "" 488 | intervals [2]: 489 | xmin = 2041.4217474125382 490 | xmax = 2041.968276643991 491 | text = "this" 492 | intervals [3]: 493 | xmin = 2041.968276643991 494 | xmax = 2042.5281632653062 495 | text = "is" 496 | intervals [4]: 497 | xmin = 2042.5281632653062 498 | xmax = 2044.0487352585324 499 | text = "a" 500 | intervals [5]: 501 | xmin = 2044.0487352585324 502 | xmax = 2045.144149659864 503 | text = "demo" 504 | item [2]: 505 | class = "TextTier" 506 | name = "notes" 507 | xmin = 0 508 | xmax = 2045.144149659864 509 | points: size = 3 510 | points [1]: 511 | time = 2041.4217474125382 512 | mark = ".begin_demo" 513 | points [2]: 514 | time = 2043.8338291031832 515 | mark = "voice gets quiet here" 516 | points [3]: 517 | time = 2045.144149659864 518 | mark = ".end_demo" 519 | item [3]: 520 | class = "IntervalTier" 521 | name = "phones" 522 | xmin = 0 523 | xmax = 2045.144149659864 524 | intervals: size = 12 525 | intervals [1]: 526 | xmin = 0 527 | xmax = 2041.4217474125382 528 | text = "" 529 | intervals [2]: 530 | xmin = 2041.4217474125382 531 | xmax = 2041.5438290324326 532 | text = "D" 533 | intervals [3]: 534 | xmin = 2041.5438290324326 535 | xmax = 2041.7321032910372 536 | text = "I" 537 | intervals [4]: 538 | xmin = 2041.7321032910372 539 | xmax = 2041.968276643991 540 | text = "s" 541 | intervals [5]: 542 | xmin = 2041.968276643991 543 | xmax = 2042.232189031843 544 | text = "I" 545 | intervals [6]: 546 | xmin = 2042.232189031843 547 | xmax = 2042.5281632653062 548 | text = "z" 549 | intervals [7]: 550 | xmin = 2042.5281632653062 551 | xmax = 2044.0487352585324 552 | text = "eI" 553 | intervals [8]: 554 | xmin = 2044.0487352585324 555 | xmax = 2044.2487352585324 556 | text = "dc" 557 | intervals [9]: 558 | xmin = 2044.2487352585324 559 | xmax = 2044.3102321849011 560 | text = "d" 561 | intervals [10]: 562 | xmin = 2044.3102321849011 563 | xmax = 2044.5748932104329 564 | text = "E" 565 | intervals [11]: 566 | xmin = 2044.5748932104329 567 | xmax = 2044.8329108578437 568 | text = "m" 569 | intervals [12]: 570 | xmin = 2044.8329108578437 571 | xmax = 2045.144149659864 572 | text = "oU" 573 | """ 574 | 575 | demo_data2 = """File type = "ooTextFile" 576 | Object class = "TextGrid" 577 | 578 | 0 579 | 2.8 580 | 581 | 2 582 | "IntervalTier" 583 | "utterances" 584 | 0 585 | 2.8 586 | 3 587 | 0 588 | 1.6229213249309031 589 | "" 590 | 1.6229213249309031 591 | 2.341428074708195 592 | "demo" 593 | 2.341428074708195 594 | 2.8 595 | "" 596 | "IntervalTier" 597 | "phones" 598 | 0 599 | 2.8 600 | 6 601 | 0 602 | 1.6229213249309031 603 | "" 604 | 1.6229213249309031 605 | 1.6428291382019483 606 | "dc" 607 | 1.6428291382019483 608 | 1.65372183721983721 609 | "d" 610 | 1.65372183721983721 611 | 1.94372874328943728 612 | "E" 613 | 1.94372874328943728 614 | 2.13821938291038210 615 | "m" 616 | 2.13821938291038210 617 | 2.341428074708195 618 | "oU" 619 | 2.341428074708195 620 | 2.8 621 | "" 622 | """ 623 | 624 | demo_data3 = 
""""Praat chronological TextGrid text file" 625 | 0 2.8 ! Time domain. 626 | 2 ! Number of tiers. 627 | "IntervalTier" "utterances" 0 2.8 628 | "IntervalTier" "utterances" 0 2.8 629 | 1 0 1.6229213249309031 630 | "" 631 | 2 0 1.6229213249309031 632 | "" 633 | 2 1.6229213249309031 1.6428291382019483 634 | "dc" 635 | 2 1.6428291382019483 1.65372183721983721 636 | "d" 637 | 2 1.65372183721983721 1.94372874328943728 638 | "E" 639 | 2 1.94372874328943728 2.13821938291038210 640 | "m" 641 | 2 2.13821938291038210 2.341428074708195 642 | "oU" 643 | 1 1.6229213249309031 2.341428074708195 644 | "demo" 645 | 1 2.341428074708195 2.8 646 | "" 647 | 2 2.341428074708195 2.8 648 | "" 649 | """ 650 | 651 | if __name__ == "__main__": 652 | demo() -------------------------------------------------------------------------------- /decoding/utils_ridge/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tables 3 | #from matplotlib.pyplot import figure, show 4 | import scipy.linalg 5 | 6 | def make_delayed(stim, delays, circpad=False): 7 | """Creates non-interpolated concatenated delayed versions of [stim] with the given [delays] 8 | (in samples). 9 | 10 | If [circpad], instead of being padded with zeros, [stim] will be circularly shifted. 11 | """ 12 | nt,ndim = stim.shape 13 | dstims = [] 14 | for di,d in enumerate(delays): 15 | dstim = np.zeros((nt, ndim)) 16 | if d<0: ## negative delay 17 | dstim[:d,:] = stim[-d:,:] 18 | if circpad: 19 | dstim[d:,:] = stim[:-d,:] 20 | elif d>0: 21 | dstim[d:,:] = stim[:-d,:] 22 | if circpad: 23 | dstim[:d,:] = stim[-d:,:] 24 | else: ## d==0 25 | dstim = stim.copy() 26 | dstims.append(dstim) 27 | return np.hstack(dstims) 28 | 29 | def best_corr_vec(wvec, vocab, SU, n=10): 30 | """Returns the [n] words from [vocab] most similar to the given [wvec], where each word is represented 31 | as a row in [SU]. Similarity is computed using correlation.""" 32 | wvec = wvec - np.mean(wvec) 33 | nwords = len(vocab) 34 | corrs = np.nan_to_num([np.corrcoef(wvec, SU[wi,:]-np.mean(SU[wi,:]))[1,0] for wi in range(nwords-1)]) 35 | scorrs = np.argsort(corrs) 36 | words = list(reversed([(corrs[i],vocab[i]) for i in scorrs[-n:]])) 37 | return words 38 | 39 | def get_word_prob(): 40 | """Returns the probabilities of all the words in the mechanical turk video labels. 41 | """ 42 | import constants as c 43 | import cPickle 44 | data = cPickle.load(open(c.datafile)) # Read in the words from the labels 45 | wordcount = dict() 46 | totalcount = 0 47 | for label in data: 48 | for word in label: 49 | totalcount += 1 50 | if word in wordcount: 51 | wordcount[word] += 1 52 | else: 53 | wordcount[word] = 1 54 | 55 | wordprob = dict([(word, float(wc)/totalcount) for word, wc in wordcount.items()]) 56 | return wordprob 57 | 58 | def best_prob_vec(wvec, vocab, space, wordprobs): 59 | """Orders the words by correlation with the given [wvec], but also weights the correlations by the prior 60 | probability of the word appearing in the mechanical turk video labels. 
61 | """ 62 | words = best_corr_vec(wvec, vocab, space, n=len(vocab)) ## get correlations for all words 63 | ## weight correlations by the prior probability of the word in the labels 64 | weightwords = [] 65 | for wcorr,word in words: 66 | if word in wordprobs: 67 | weightwords.append((wordprobs[word]*wcorr, word)) 68 | 69 | return sorted(weightwords, key=lambda ww: ww[0]) 70 | 71 | def find_best_words(vectors, vocab, wordspace, actual, display=True, num=15): 72 | cwords = [] 73 | for si in range(len(vectors)): 74 | cw = best_corr_vec(vectors[si], vocab, wordspace, n=num) 75 | cwords.append(cw) 76 | if display: 77 | print ("Closest words to scene %d:" % si) 78 | print ([b[1] for b in cw]) 79 | print ("Actual words:") 80 | print (actual[si]) 81 | print ("") 82 | return cwords 83 | 84 | def find_best_stims_for_word(wordvector, decstims, n): 85 | """Returns a list of the indexes of the [n] stimuli in [decstims] (should be decoded stimuli) 86 | that lie closest to the vector [wordvector], which should be taken from the same space as the 87 | stimuli. 88 | """ 89 | scorrs = np.array([np.corrcoef(wordvector, ds)[0,1] for ds in decstims]) 90 | scorrs[np.isnan(scorrs)] = -1 91 | return np.argsort(scorrs)[-n:][::-1] 92 | 93 | def princomp(x, use_dgesvd=False): 94 | """Does principal components analysis on [x]. 95 | Returns coefficients, scores and latent variable values. 96 | Translated from MATLAB princomp function. Unlike the matlab princomp function, however, the 97 | rows of the returned value 'coeff' are the principal components, not the columns. 98 | """ 99 | 100 | n,p = x.shape 101 | #cx = x-np.tile(x.mean(0), (n,1)) ## column-centered x 102 | cx = x-x.mean(0) 103 | r = np.min([n-1,p]) ## maximum possible rank of cx 104 | 105 | if use_dgesvd: 106 | from svd_dgesvd import svd_dgesvd 107 | U,sigma,coeff = svd_dgesvd(cx, full_matrices=False) 108 | else: 109 | U,sigma,coeff = np.linalg.svd(cx, full_matrices=False) 110 | 111 | sigma = np.diag(sigma) 112 | score = np.dot(cx, coeff.T) 113 | sigma = sigma/np.sqrt(n-1) 114 | 115 | latent = sigma**2 116 | 117 | return coeff, score, latent 118 | 119 | def eigprincomp(x, npcs=None, norm=False, weights=None): 120 | """Does principal components analysis on [x]. 121 | Returns coefficients (eigenvectors) and eigenvalues. 122 | If given, only the [npcs] greatest eigenvectors/values will be returned. 123 | If given, the covariance matrix will be computed using [weights] on the samples. 124 | """ 125 | n,p = x.shape 126 | #cx = x-np.tile(x.mean(0), (n,1)) ## column-centered x 127 | cx = x-x.mean(0) 128 | r = np.min([n-1,p]) ## maximum possible rank of cx 129 | 130 | xcov = np.cov(cx.T) 131 | if norm: 132 | xcov /= n 133 | 134 | if npcs is not None: 135 | latent,coeff = scipy.linalg.eigh(xcov, eigvals=(p-npcs,p-1)) 136 | else: 137 | latent,coeff = np.linalg.eigh(xcov) 138 | 139 | ## Transpose coeff, reverse its rows 140 | return coeff.T[::-1], latent[::-1] 141 | 142 | def weighted_cov(x, weights=None): 143 | """If given [weights], the covariance will be computed using those weights on the samples. 144 | Otherwise the simple covariance will be returned. 
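A small illustration (not part of the original file) of what `make_delayed`, defined near the top of this file, produces: each delay shifts the stimulus matrix down by that many samples (zero-padded unless `circpad=True`), and the shifted copies are stacked horizontally, giving the lagged design matrix that `get_stim` in `utils_stim.py` builds with `config.STIM_DELAYS`.

```python
import numpy as np
from utils_ridge.util import make_delayed

stim = np.arange(1, 6, dtype=float).reshape(5, 1)   # 5 time points, 1 feature
print(make_delayed(stim, [1, 2]))
# [[0. 0.]
#  [1. 0.]
#  [2. 1.]
#  [3. 2.]
#  [4. 3.]]
```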
145 | """ 146 | if weights is None: 147 | return np.cov(x) 148 | else: 149 | w = weights/weights.sum() ## Normalize the weights 150 | dmx = (x.T-(w*x).sum(1)).T ## Subtract the WEIGHTED mean 151 | wfact = 1/(1-(w**2).sum()) ## Compute the weighting factor 152 | return wfact*np.dot(w*dmx, dmx.T.conj()) ## Take the weighted inner product 153 | 154 | def test_weighted_cov(): 155 | """Runs a test on the weighted_cov function, creating a dataset for which the covariance is known 156 | for two different populations, and weights are used to reproduce the individual covariances. 157 | """ 158 | T = 1000 ## number of time points 159 | N = 100 ## A signals 160 | M = 100 ## B signals 161 | snr = 5 ## signal to noise ratio 162 | 163 | ## Create the two datasets 164 | siga = np.random.rand(T) 165 | noisea = np.random.rand(T, N) 166 | respa = (noisea.T+snr*siga).T 167 | 168 | sigb = np.random.rand(T) 169 | noiseb = np.random.rand(T, M) 170 | respb = (noiseb.T+snr*sigb).T 171 | 172 | ## Compute self-covariance matrixes 173 | cova = np.cov(respa) 174 | covb = np.cov(respb) 175 | 176 | ## Compute the full covariance matrix 177 | allresp = np.hstack([respa, respb]) 178 | fullcov = np.cov(allresp) 179 | 180 | ## Make weights that will recover individual covariances 181 | wta = np.ones([N+M,]) 182 | wta[N:] = 0 183 | 184 | wtb = np.ones([N+M,]) 185 | wtb[:N] = 0 186 | 187 | recova = weighted_cov(allresp, wta) 188 | recovb = weighted_cov(allresp, wtb) 189 | 190 | return locals() 191 | 192 | def fixPCs(orig, new): 193 | """Finds and fixes sign-flips in PCs by finding the coefficient with the greatest 194 | magnitude in the [orig] PCs, then negating the [new] PCs if that coefficient has 195 | a different sign. 196 | """ 197 | flipped = [] 198 | for o,n in zip(orig, new): 199 | maxind = np.abs(o).argmax() 200 | if o[maxind]*n[maxind]>0: 201 | ## Same sign, no need to flip 202 | flipped.append(n) 203 | else: 204 | ## Different sign, flip 205 | flipped.append(-n) 206 | 207 | return np.vstack(flipped) 208 | 209 | 210 | def plot_model_comparison(corrs1, corrs2, name1, name2, thresh=0.35): 211 | fig = figure(figsize=(8,8)) 212 | ax = fig.add_subplot(1,1,1) 213 | 214 | good1 = corrs1>thresh 215 | good2 = corrs2>thresh 216 | better1 = corrs1>corrs2 217 | #both = np.logical_and(good1, good2) 218 | neither = np.logical_not(np.logical_or(good1, good2)) 219 | only1 = np.logical_and(good1, better1) 220 | only2 = np.logical_and(good2, np.logical_not(better1)) 221 | 222 | ptalpha = 0.3 223 | ax.plot(corrs1[neither], corrs2[neither], 'ko', alpha=ptalpha) 224 | #ax.plot(corrs1[both], corrs2[both], 'go', alpha=ptalpha) 225 | ax.plot(corrs1[only1], corrs2[only1], 'ro', alpha=ptalpha) 226 | ax.plot(corrs1[only2], corrs2[only2], 'bo', alpha=ptalpha) 227 | 228 | lims = [-0.5, 1.0] 229 | 230 | ax.plot([thresh, thresh], [lims[0], thresh], 'r-') 231 | ax.plot([lims[0], thresh], [thresh,thresh], 'b-') 232 | 233 | ax.text(lims[0]+0.05, thresh, "$n=%d$"%np.sum(good2), horizontalalignment="left", verticalalignment="bottom") 234 | ax.text(thresh, lims[0]+0.05, "$n=%d$"%np.sum(good1), horizontalalignment="left", verticalalignment="bottom") 235 | 236 | ax.plot(lims, lims, '-', color="gray") 237 | ax.set_xlim(lims) 238 | ax.set_ylim(lims) 239 | ax.set_xlabel(name1) 240 | ax.set_ylabel(name2) 241 | 242 | show() 243 | return fig 244 | 245 | import matplotlib.colors 246 | bwr = matplotlib.colors.LinearSegmentedColormap.from_list("bwr", ((0.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 0.0, 0.0))) 247 | bkr = 
matplotlib.colors.LinearSegmentedColormap.from_list("bkr", ((0.0, 0.0, 1.0), (0.0, 0.0, 0.0), (1.0, 0.0, 0.0))) 248 | bgr = matplotlib.colors.LinearSegmentedColormap.from_list("bgr", ((0.0, 0.0, 1.0), (0.5, 0.5, 0.5), (1.0, 0.0, 0.0))) 249 | 250 | def plot_model_comparison2(corrFile1, corrFile2, name1, name2, thresh=0.35): 251 | fig = figure(figsize=(9,10)) 252 | #ax = fig.add_subplot(3,1,[1,2], aspect="equal") 253 | ax = fig.add_axes([0.25, 0.4, 0.6, 0.5], aspect="equal") 254 | 255 | corrs1 = tables.openFile(corrFile1).root.semcorr.read() 256 | corrs2 = tables.openFile(corrFile2).root.semcorr.read() 257 | maxcorr = np.clip(np.vstack([corrs1, corrs2]).max(0), 0, thresh)/thresh 258 | corrdiff = (corrs1-corrs2) + 0.5 259 | colors = (bgr(corrdiff).T*maxcorr).T 260 | colors[:,3] = 1.0 ## Don't scale alpha 261 | 262 | ptalpha = 0.8 263 | ax.scatter(corrs1, corrs2, s=10, c=colors, alpha=ptalpha, edgecolors="none") 264 | lims = [-0.5, 1.0] 265 | 266 | ax.plot([thresh, thresh], [lims[0], thresh], color="gray") 267 | ax.plot([lims[0], thresh], [thresh,thresh], color="gray") 268 | 269 | good1 = corrs1>thresh 270 | good2 = corrs2>thresh 271 | ax.text(lims[0]+0.05, thresh, "$n=%d$"%np.sum(good2), horizontalalignment="left", verticalalignment="bottom") 272 | ax.text(thresh, lims[0]+0.05, "$n=%d$"%np.sum(good1), horizontalalignment="left", verticalalignment="bottom") 273 | 274 | ax.plot(lims, lims, '-', color="gray") 275 | ax.set_xlim(lims) 276 | ax.set_ylim(lims) 277 | ax.set_xlabel(name1+" model") 278 | ax.set_ylabel(name2+" model") 279 | 280 | fig.canvas.draw() 281 | show() 282 | ## Add over-under comparison 283 | #ax_left = ax.get_window_extent()._bbox.x0 284 | #ax_right = ax.get_window_extent()._bbox.x1 285 | #ax_width = ax_right-ax_left 286 | #print ax_left, ax_right 287 | #ax2 = fig.add_axes([ax_left, 0.1, ax_width, 0.2]) 288 | ax2 = fig.add_axes([0.25, 0.1, 0.6, 0.25])#, sharex=ax) 289 | #ax2 = fig.add_subplot(3, 1, 3) 290 | #plot_model_overunder_comparison(corrs1, corrs2, name1, name2, thresh=thresh, ax=ax2) 291 | plot_model_histogram_comparison(corrs1, corrs2, name1, name2, thresh=thresh, ax=ax2) 292 | 293 | fig.suptitle("Model comparison: %s vs. %s"%(name1, name2)) 294 | show() 295 | return fig 296 | 297 | 298 | def plot_model_overunder_comparison(corrs1, corrs2, name1, name2, thresh=0.35, ax=None): 299 | """Plots over-under difference between two models. 300 | """ 301 | if ax is None: 302 | fig = figure(figsize=(8,8)) 303 | ax = fig.add_subplot(1,1,1) 304 | 305 | maxcorr = max(corrs1.max(), corrs2.max()) 306 | vals = np.linspace(0, maxcorr, 500) 307 | overunder = lambda c: np.array([np.sum(c>v)-np.sum(c<-v) for v in vals]) 308 | 309 | ou1 = overunder(corrs1) 310 | ou2 = overunder(corrs2) 311 | 312 | oud = ou2-ou1 313 | 314 | ax.fill_between(vals, 0, np.clip(oud, 0, 1e9), facecolor="blue") 315 | ax.fill_between(vals, 0, np.clip(oud, -1e9, 0), facecolor="red") 316 | 317 | yl = np.max(np.abs(np.array(ax.get_ylim()))) 318 | ax.plot([thresh, thresh], [-yl, yl], '-', color="gray") 319 | ax.set_ylim(-yl, yl) 320 | ax.set_xlim(0, maxcorr) 321 | ax.set_xlabel("Voxel correlation") 322 | ax.set_ylabel("%s better %s better"%(name1, name2)) 323 | 324 | show() 325 | return ax 326 | 327 | def plot_model_histogram_comparison(corrs1, corrs2, name1, name2, thresh=0.35, ax=None): 328 | """Plots over-under difference between two models. 
329 | """ 330 | if ax is None: 331 | fig = figure(figsize=(8,8)) 332 | ax = fig.add_subplot(1,1,1) 333 | 334 | maxcorr = max(corrs1.max(), corrs2.max()) 335 | nbins = 100 336 | hist1 = np.histogram(corrs1, nbins, range=(-1,1)) 337 | hist2 = np.histogram(corrs2, nbins, range=(-1,1)) 338 | 339 | ouhist1 = hist1[0][nbins/2:]-hist1[0][:nbins/2][::-1] 340 | ouhist2 = hist2[0][nbins/2:]-hist2[0][:nbins/2][::-1] 341 | 342 | oud = ouhist2-ouhist1 343 | bwidth = 2.0/nbins 344 | barlefts = hist1[1][nbins/2:-1] 345 | 346 | #ax.fill_between(vals, 0, np.clip(oud, 0, 1e9), facecolor="blue") 347 | #ax.fill_between(vals, 0, np.clip(oud, -1e9, 0), facecolor="red") 348 | 349 | ax.bar(barlefts, np.clip(oud, 0, 1e9), bwidth, facecolor="blue") 350 | ax.bar(barlefts, np.clip(oud, -1e9, 0), bwidth, facecolor="red") 351 | 352 | yl = np.max(np.abs(np.array(ax.get_ylim()))) 353 | ax.plot([thresh, thresh], [-yl, yl], '-', color="gray") 354 | ax.set_ylim(-yl, yl) 355 | ax.set_xlim(0, maxcorr) 356 | ax.set_xlabel("Voxel correlation") 357 | ax.set_ylabel("%s better %s better"%(name1, name2)) 358 | 359 | show() 360 | return ax 361 | 362 | 363 | def plot_model_comparison_rois(corrs1, corrs2, name1, name2, roivoxels, roinames, thresh=0.35): 364 | """Plots model correlation comparisons per ROI. 365 | """ 366 | fig = figure() 367 | ptalpha = 0.3 368 | 369 | for ri in range(len(roinames)): 370 | ax = fig.add_subplot(4, 4, ri+1) 371 | ax.plot(corrs1[roivoxels[ri]], corrs2[roivoxels[ri]], 'bo', alpha=ptalpha) 372 | lims = [-0.3, 1.0] 373 | ax.plot(lims, lims, '-', color="gray") 374 | ax.set_xlim(lims) 375 | ax.set_ylim(lims) 376 | ax.set_title(roinames[ri]) 377 | 378 | show() 379 | return fig 380 | 381 | def save_table_file(filename, filedict): 382 | """Saves the variables in [filedict] in a hdf5 table file at [filename]. 383 | """ 384 | hf = tables.openFile(filename, mode="w", title="save_file") 385 | for vname, var in filedict.items(): 386 | hf.createArray("/", vname, var) 387 | hf.close() -------------------------------------------------------------------------------- /decoding/utils_ridge/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | #import scipy.stats 3 | import random 4 | import sys 5 | 6 | def zscore(mat, return_unzvals=False): 7 | """Z-scores the rows of [mat] by subtracting off the mean and dividing 8 | by the standard deviation. 9 | If [return_unzvals] is True, a matrix will be returned that can be used 10 | to return the z-scored values to their original state. 11 | """ 12 | zmat = np.empty(mat.shape, mat.dtype) 13 | unzvals = np.zeros((zmat.shape[0], 2), mat.dtype) 14 | for ri in range(mat.shape[0]): 15 | unzvals[ri,0] = np.std(mat[ri,:]) 16 | unzvals[ri,1] = np.mean(mat[ri,:]) 17 | zmat[ri,:] = (mat[ri,:]-unzvals[ri,1]) / (1e-10+unzvals[ri,0]) 18 | 19 | if return_unzvals: 20 | return zmat, unzvals 21 | 22 | return zmat 23 | 24 | def center(mat, return_uncvals=False): 25 | """Centers the rows of [mat] by subtracting off the mean, but doesn't 26 | divide by the SD. 27 | Can be undone like zscore. 
28 | """ 29 | cmat = np.empty(mat.shape) 30 | uncvals = np.ones((mat.shape[0], 2)) 31 | for ri in range(mat.shape[0]): 32 | uncvals[ri,1] = np.mean(mat[ri,:]) 33 | cmat[ri,:] = mat[ri,:]-uncvals[ri,1] 34 | 35 | if return_uncvals: 36 | return cmat, uncvals 37 | 38 | return cmat 39 | 40 | def unzscore(mat, unzvals): 41 | """Un-Z-scores the rows of [mat] by multiplying by unzvals[:,0] (the standard deviations) 42 | and then adding unzvals[:,1] (the row means). 43 | """ 44 | unzmat = np.empty(mat.shape) 45 | for ri in range(mat.shape[0]): 46 | unzmat[ri,:] = mat[ri,:]*(1e-10+unzvals[ri,0])+unzvals[ri,1] 47 | return unzmat 48 | 49 | def gaussianize(vec): 50 | """Uses a look-up table to force the values in [vec] to be gaussian.""" 51 | ranks = np.argsort(np.argsort(vec)) 52 | cranks = (ranks+1).astype(float)/(ranks.max()+2) 53 | vals = scipy.stats.norm.isf(1-cranks) 54 | zvals = vals/vals.std() 55 | return zvals 56 | 57 | def gaussianize_mat(mat): 58 | """Gaussianizes each column of [mat].""" 59 | gmat = np.empty(mat.shape) 60 | for ri in range(mat.shape[1]): 61 | gmat[:,ri] = gaussianize(mat[:,ri]) 62 | return gmat 63 | 64 | def make_delayed(stim, delays, circpad=False): 65 | """Creates non-interpolated concatenated delayed versions of [stim] with the given [delays] 66 | (in samples). 67 | 68 | If [circpad], instead of being padded with zeros, [stim] will be circularly shifted. 69 | """ 70 | nt,ndim = stim.shape 71 | dstims = [] 72 | for di,d in enumerate(delays): 73 | dstim = np.zeros((nt, ndim)) 74 | if d<0: ## negative delay 75 | dstim[:d,:] = stim[-d:,:] 76 | if circpad: 77 | dstim[d:,:] = stim[:-d,:] 78 | elif d>0: 79 | dstim[d:,:] = stim[:-d,:] 80 | if circpad: 81 | dstim[:d,:] = stim[-d:,:] 82 | else: ## d==0 83 | dstim = stim.copy() 84 | dstims.append(dstim) 85 | return np.hstack(dstims) 86 | 87 | def mult_diag(d, mtx, left=True): 88 | """Multiply a full matrix by a diagonal matrix. 89 | This function should always be faster than dot. 90 | Input: 91 | d -- 1D (N,) array (contains the diagonal elements) 92 | mtx -- 2D (N,N) array 93 | Output: 94 | mult_diag(d, mts, left=True) == dot(diag(d), mtx) 95 | mult_diag(d, mts, left=False) == dot(mtx, diag(d)) 96 | 97 | By Pietro Berkes 98 | From http://mail.scipy.org/pipermail/numpy-discussion/2007-March/026807.html 99 | """ 100 | if left: 101 | return (d*mtx.T).T 102 | else: 103 | return d*mtx 104 | 105 | import time 106 | import logging 107 | def counter(iterable, countevery=100, total=None, logger=logging.getLogger("counter")): 108 | """Logs a status and timing update to [logger] every [countevery] draws from [iterable]. 109 | If [total] is given, log messages will include the estimated time remaining. 
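Two quick notes on the helpers above, with a sketch that is not part of the original file: `mult_diag` is just a broadcasting shortcut for multiplying by a diagonal matrix (it is used in `ridge_corr` during the alpha sweep), and `gaussianize` calls `scipy.stats.norm.isf` even though the `scipy.stats` import at the top of this file is commented out, so it would raise a `NameError` if used as-is.

```python
import numpy as np
from utils_ridge.utils import mult_diag

d = np.array([1.0, 2.0, 3.0])
mtx = np.arange(9, dtype=float).reshape(3, 3)

# Equivalent to dot(diag(d), mtx) and dot(mtx, diag(d)), without forming diag(d).
print(np.allclose(mult_diag(d, mtx, left=True),  np.diag(d) @ mtx))   # True
print(np.allclose(mult_diag(d, mtx, left=False), mtx @ np.diag(d)))   # True
```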
import time
import logging
def counter(iterable, countevery=100, total=None, logger=logging.getLogger("counter")):
    """Logs a status and timing update to [logger] every [countevery] draws from [iterable].
    If [total] is given, log messages will include the estimated time remaining.
    """
    start_time = time.time()

    ## Check if the iterable has a __len__ function, use it if no total length is supplied
    if total is None:
        if hasattr(iterable, "__len__"):
            total = len(iterable)

    for count, thing in enumerate(iterable):
        yield thing

        if not count%countevery:
            current_time = time.time()
            rate = float(count+1)/(current_time-start_time)

            if rate>1: ## more than 1 item/second
                ratestr = "%0.2f items/second"%rate
            else: ## less than 1 item/second
                ratestr = "%0.2f seconds/item"%(rate**-1)

            if total is not None:
                remitems = total-(count+1)
                remtime = remitems/rate
                timestr = ", %s remaining" % time.strftime('%H:%M:%S', time.gmtime(remtime))
                itemstr = "%d/%d"%(count+1, total)
            else:
                timestr = ""
                itemstr = "%d"%(count+1)

            formatted_str = "%s items complete (%s%s)"%(itemstr,ratestr,timestr)
            if logger is None:
                print(formatted_str)
            else:
                logger.info(formatted_str)

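# --- Illustrative usage sketch (added for this document, not part of the original
# --- utils.py): counter() wraps any iterable and logs a progress/rate message every
# --- [countevery] items; with [total] it also estimates the time remaining.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    for _ in counter(range(10), countevery=5, total=10):
        time.sleep(0.01)   # stand-in for real per-item work
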
--------------------------------------------------------------------------------
/decoding/utils_stim.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import json

import config
from utils_ridge.stimulus_utils import TRFile, load_textgrids, load_simulated_trfiles
from utils_ridge.dsutils import make_word_ds
from utils_ridge.interpdata import lanczosinterp2D
from utils_ridge.util import make_delayed

def get_story_wordseqs(stories):
    """loads words and word times of stimulus stories
    """
    grids = load_textgrids(stories, config.DATA_TRAIN_DIR)
    with open(os.path.join(config.DATA_TRAIN_DIR, "respdict.json"), "r") as f:
        respdict = json.load(f)
    trfiles = load_simulated_trfiles(respdict)
    wordseqs = make_word_ds(grids, trfiles)
    return wordseqs

def get_stim(stories, features, tr_stats = None):
    """extract quantitative features of stimulus stories
    """
    word_seqs = get_story_wordseqs(stories)
    word_vecs = {story : features.make_stim(word_seqs[story].data) for story in stories}
    word_mat = np.vstack([word_vecs[story] for story in stories])
    word_mean, word_std = word_mat.mean(0), word_mat.std(0)

    ds_vecs = {story : lanczosinterp2D(word_vecs[story], word_seqs[story].data_times, word_seqs[story].tr_times)
        for story in stories}
    ds_mat = np.vstack([ds_vecs[story][5+config.TRIM:-config.TRIM] for story in stories])
    if tr_stats is None:
        r_mean, r_std = ds_mat.mean(0), ds_mat.std(0)
        r_std[r_std == 0] = 1
    else:
        r_mean, r_std = tr_stats
    ds_mat = np.nan_to_num(np.dot((ds_mat - r_mean), np.linalg.inv(np.diag(r_std))))
    del_mat = make_delayed(ds_mat, config.STIM_DELAYS)
    if tr_stats is None: return del_mat, (r_mean, r_std), (word_mean, word_std)
    else: return del_mat

def predict_word_rate(resp, wt, vox, mean_rate):
    """predict word rate at each acquisition time
    """
    delresp = make_delayed(resp[:, vox], config.RESP_DELAYS)
    rate = ((delresp.dot(wt) + mean_rate)).reshape(-1).clip(min = 0)
    return np.round(rate).astype(int)

def predict_word_times(word_rate, resp, starttime = 0, tr = 2):
    """predict evenly spaced word times from word rate
    """
    half = tr / 2
    trf = TRFile(None, tr)
    trf.soundstarttime = starttime
    trf.simulate(resp.shape[0])
    tr_times = trf.get_reltriggertimes() + half

    word_times = []
    for mid, num in zip(tr_times, word_rate):
        if num < 1: continue
        word_times.extend(np.linspace(mid - half, mid + half, num, endpoint = False) + half / num)
    return np.array(word_times), tr_times
--------------------------------------------------------------------------------
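
For orientation, below is a minimal, hypothetical sketch of how the word rate helpers in `utils_stim.py` chain together. It assumes it is run from the `decoding/` directory (so `config` and `utils_stim` import exactly as above) and uses synthetic responses, arbitrary voxel indices, and all-zero weights purely as stand-ins; in the actual pipeline the word rate weights and voxels come from the model fit by `train_WR.py`.

```python
import numpy as np

import config
from utils_stim import predict_word_rate, predict_word_times

# Synthetic stand-ins: 50 TRs x 200 voxels of fake responses, 20 hypothetical
# word-rate voxels, and all-zero regression weights (real weights come from
# the word rate model saved by train_WR.py).
resp = np.random.randn(50, 200)
vox = np.arange(20)
wt = np.zeros(len(vox) * len(config.RESP_DELAYS))

rate = predict_word_rate(resp, wt, vox, mean_rate=3.5)       # integer word count per TR
word_times, tr_times = predict_word_times(rate, resp, tr=2)  # evenly spaced times within each TR
print(len(word_times), "predicted word times across", len(tr_times), "TRs")
```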