├── README.md ├── Scaling_Law_Tutorial.ipynb ├── corrs.png ├── extract_speech_features.py ├── ridge_utils ├── DataSequence.py ├── SemanticModel.py ├── dsutils.py ├── interpdata.py ├── npp.py ├── ridge.py ├── stimulus_utils.py ├── textgrid.py ├── tokenization_helpers.py ├── util.py └── utils.py └── speech_model_configs.json /README.md: -------------------------------------------------------------------------------- 1 | # Encoding Model Scaling Laws 2 | Repository for the 2023 NeurIPS paper "[Scaling laws for language encoding models in fMRI](https://arxiv.org/abs/2305.11863)". 3 | 4 | ![Encoding model performance for OPT-30B](https://github.com/HuthLab/encoding-model-scaling-laws/blob/main/corrs.png) 5 | 6 | This repository provides feature extraction code, as well as encoding model features and weights from the analyses in the paper “Scaling Laws for Language Encoding Models in fMRI”. 7 | 8 | The repository uses a [Box folder](https://utexas.box.com/v/EncodingModelScalingLaws) to host larger data files, including weights, response data, and features. 9 | 10 | Please see the tutorial notebook or the boxnote for instructions on how to use the provided data. If you use this repository or any derivatives, please cite our paper: 11 | 12 | ``` 13 | @article{antonello2023scaling, 14 | title={Scaling laws for language encoding models in fMRI}, 15 | author={Richard J. Antonello and Aditya R. Vaidya and Alexander G. Huth}, 16 | journal={Advances in Neural Information Processing Systems}, 17 | volume={36}, 18 | year={2023} 19 | } 20 | ``` 21 | 22 | ## Speech models 23 | 24 | Feature extraction from audio-based models is not as straightforward as for LMs because audio models are usually bidirectional, and because of this we created a separate feature extraction pipeline. 25 | To maintain the causality of the features, we extract features from these models with a sliding window over the stimulus. 26 | In this paper, the stride is 0.1 s and the size is 16.1 s. 
27 | At every iteration of the sliding window $[t-16.1, t]$, we select the output vector for the final "token" of the model’s output, and consider it _the_ feature vector for time $t$. 28 | This ensures that features at time $t$ are computed only from the first $t$ seconds of audio. 29 | 30 | Because feature extraction is more complex, it is broken out into a separate script: https://github.com/HuthLab/encoding-model-scaling-laws/blob/main/extract_speech_features.py . 31 | The function `extract_speech_features` implements feature extraction (with striding, etc.) for a single audio file. 32 | The rest of the script mainly handles stimulus selection and saving data. 33 | (If you want to extract features without saving them, you can import this function into another script or notebook.) 34 | 35 | Download the folder `story_data/story_audio` from the Box. 36 | This command will then extract features from whisper-tiny for all audio files (using the above sliding window parameters), and save the features to the folder `features_cnk0.1_ctx16.0/whisper-tiny`: 37 | 38 | ```bash 39 | python3 ./extract_speech_features.py --stimulus_dir story_audio/ --model whisper-tiny --chunksz 100 --contextsz 16000 --use_featext --batchsz 64 40 | ``` 41 | 42 | `chunksz` denotes the stride (in milliseconds), and the window size is `chunksz+contextsz`. 43 | `--use_featext` passes the audio to the model's `FeatureExtractor` before a forward pass. 44 | You can extract features for specific stories with the `--stories` option. 45 | 46 | Features from each layer will be saved in a different directory. The script also saves the associated timestamps (i.e. `times[t]` is the time at which hidden state `features[t]` occurred). 47 | 48 | Here is the directory structure after running the script with `--stories wheretheressmoke` (this takes ~10 min.
on a Titan Xp): 49 | ``` 50 | $ tree features_cnk0.1_ctx16.0/ 51 | 52 | features_cnk0.1_ctx16.0/ 53 | └── whisper-tiny 54 | ├── encoder.0 55 | │   └── wheretheressmoke.npz 56 | ├── encoder.1 57 | │   └── wheretheressmoke.npz 58 | ├── encoder.2 59 | │   └── wheretheressmoke.npz 60 | ├── encoder.3 61 | │   └── wheretheressmoke.npz 62 | ├── encoder.4 63 | │   └── wheretheressmoke.npz 64 | ├── wheretheressmoke.npz 65 | └── wheretheressmoke_times.npz 66 | ``` 67 | 68 | ### Using extracted speech features 69 | 70 | As with word-level features, features from speech models must be downsampled to the rate of fMRI acquisition before being used in encoding models. 71 | This code will downsample the features: 72 | 73 | ```python 74 | from pathlib import Path 75 | 76 | import numpy as np 77 | from ridge_utils.interpdata import lanczosinterp2D 78 | 79 | chunk_sz, context_sz = 0.1, 16.0 80 | model = 'whisper-tiny' 81 | base_features_path = Path(f"features_cnk{chunk_sz:0.1f}_ctx{context_sz:0.1f}/{model}") 82 | story = 'wheretheressmoke' 83 | times = np.load(base_features_path / f"{story}_times.npz")['times'][:,1] # end time of each window. shape: (time,) 84 | features = np.load(base_features_path / f"{story}.npz")['features'] # shape: (time, model dim.) 85 | 86 | # you will need `wordseqs` from the notebook 87 | downsampled_features = lanczosinterp2D(features, times, wordseqs[story].tr_times) 88 | ``` 89 | 90 | `downsampled_features` can then be used like features from OPT or LLaMa. 91 | (Note that features in the Box are already downsampled, so this step is not necessary for them.)
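
The sliding-window arithmetic described under "Speech models" can be sanity-checked without loading a model. The following is a minimal sketch, not the repository's implementation: `snippet_bounds` is a hypothetical helper, it assumes 16 kHz audio, and it ignores the minimum-input-length filtering that `extract_speech_features` applies to very short windows.

```python
def snippet_bounds(num_samples, chunk_sec=0.1, context_sec=16.0, sample_rate=16000):
    """Return (start, end) sample indices for each causal window.

    Hypothetical helper for illustration only; it is not part of the repository.
    """
    chunk = round(chunk_sec * sample_rate)              # stride, in samples
    window = chunk + round(context_sec * sample_rate)   # chunk + context, in samples
    bounds = []
    # Window ends advance by one chunk; early windows are clipped at sample 0,
    # so they see less than the full context.
    for end in range(chunk, num_samples + 1, chunk):
        bounds.append((max(0, end - window), end))
    return bounds


# A 20 s clip at 16 kHz with the README's parameters (0.1 s stride, 16.1 s window)
bounds = snippet_bounds(num_samples=20 * 16000)
print(len(bounds))   # 200 windows, one per 0.1 s
print(bounds[0])     # (0, 1600)
print(bounds[-1])    # (62400, 320000)
```

Each window ends at the time its feature vector is assigned to, which is why the features stay causal. Only windows ending at or after 16.1 s span the full context.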
92 | -------------------------------------------------------------------------------- /corrs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuthLab/encoding-model-scaling-laws/8dd1126a90693787ff6d68f4c4013cf5a144c55f/corrs.png -------------------------------------------------------------------------------- /extract_speech_features.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Feature extraction for ASR models supported by Hugging Face. 5 | """ 6 | 7 | import argparse 8 | import collections 9 | import copy # for freezing specific parts of networks during randomization 10 | import itertools 11 | import json 12 | import os 13 | import operator 14 | from pathlib import Path 15 | import re 16 | from typing import Dict, Iterable, List, Optional 17 | 18 | import cottoncandy as cc 19 | import ipdb 20 | import numpy as np 21 | from tqdm import tqdm 22 | import torch 23 | import torchaudio 24 | from transformers import AutoModel, AutoModelForPreTraining, PreTrainedModel,\ 25 | AutoFeatureExtractor, WhisperModel 26 | 27 | try: 28 | import database_utils 29 | IS_STIMULIDB_AVAILABLE = True 30 | except: 31 | IS_STIMULIDB_AVAILABLE = False 32 | 33 | # Resample to this sample rate. 16kHz is used by most models. 34 | # TODO: programatically verify the input sample rate of each model? 
35 | TARGET_SAMPLE_RATE = 16000 36 | 37 | def extract_speech_features(model: PreTrainedModel, model_config: dict, wav: torch.Tensor, 38 | chunksz_sec: float, contextsz_sec: float, 39 | num_sel_frames = 1, frame_skip = 5, sel_layers: Optional[List[int]]=None, 40 | batchsz: int = 1, 41 | return_numpy: bool = True, move_to_cpu: bool = True, 42 | disable_tqdm: bool = False, feature_extractor=None, 43 | sampling_rate: int = TARGET_SAMPLE_RATE, require_full_context: bool = False, 44 | stereo: bool = False): 45 | assert (num_sel_frames == 1), f"'num_sel_frames` must be 1 to ensure causal feature extraction, but got {num_sel_frames}. "\ 46 | "This option will be deprecated in the future." 47 | if stereo: 48 | raise NotImplementedError("stereo not implemented") 49 | else: 50 | assert wav.ndim == 1, f"input `wav` must be 1-D but got {wav.ndim}" 51 | if return_numpy: assert move_to_cpu, "'move_to_cpu' must be true if returning numpy arrays" 52 | 53 | # Whisper needs special handling 54 | is_whisper_model = isinstance(model, WhisperModel) 55 | 56 | # Compute chunks & context sizes in terms of samples & context 57 | chunksz_samples = int(chunksz_sec * sampling_rate) 58 | contextsz_samples = int(contextsz_sec * sampling_rate) 59 | 60 | # `snippet_ends` has the last (exclusive) sample for each snippet 61 | snippet_ends = [] 62 | if not require_full_context: 63 | # Add all snippets that are _less_ than the total input size 64 | # (context+chunk) 65 | snippet_ends.append(torch.arange(chunksz_samples, contextsz_samples+chunksz_samples, chunksz_samples)) 66 | 67 | # Add all snippets that are exactly the length of the requested input 68 | # (`Tensor.unfold` is basically a sliding window). 69 | if wav.shape[0] >= chunksz_samples+contextsz_samples: 70 | # `unfold` fails if `wav.shape[0]` is less than the window size. 
71 | snippet_ends.append( 72 | torch.arange(wav.shape[0]).unfold(0, chunksz_samples+contextsz_samples, chunksz_samples)[:,-1]+1 73 | ) 74 | 75 | snippet_ends = torch.cat(snippet_ends, dim=0) # shape: (num_snippets,) 76 | 77 | if snippet_ends.shape[0] == 0: 78 | raise ValueError(f"No snippets possible! Stimulus is probably too short ({wav.shape[0]} samples). Consider reducing context size or setting `require_full_context=False`") 79 | 80 | # 2-D array where `[i,0]` and `[i,1]` are the start and end, respectively, 81 | # of snippet `i` in samples. Shape: (num_snippets, 2) 82 | snippet_times = torch.stack([torch.maximum(torch.zeros_like(snippet_ends), 83 | snippet_ends-(contextsz_samples+chunksz_samples)), 84 | snippet_ends], dim=1) 85 | 86 | # Remove snippets that are not long enough. (Seems easier to filter 87 | # after generating the snippet bounds than handling it above in each case) 88 | # TODO: is there any way to programmatically check this in HuggingFace? 89 | # doesn't seem so (unlike s3prl). 90 | if 'min_input_length' in model_config: 91 | # this is stored originally in **samples**!!! 92 | min_length_samples = model_config['min_input_length'] 93 | elif 'win_ms' in model.config: 94 | min_length_samples = model.config['win_ms'] / 1000. * TARGET_SAMPLE_RATE 95 | 96 | snippet_times = snippet_times[(snippet_times[:,1] - snippet_times[:,0]) >= min_length_samples] 97 | snippet_times_sec = snippet_times / sampling_rate # snippet_times, but in sec. 98 | 99 | module_features = collections.defaultdict(list) 100 | out_features = [] # the final output of the model 101 | times = [] # times are shared across all layers 102 | 103 | #assert (frames_per_chunk % frame_skip) == 0, "These must be divisible" 104 | frame_len_sec = model_config['stride'] / TARGET_SAMPLE_RATE # length of an output frame (sec.)
105 | 106 | snippet_length_samples = snippet_times[:,1] - snippet_times[:,0] # shape: (num_snippets,) 107 | if require_full_context: 108 | assert all(snippet_length_samples == snippet_length_samples[0]), "uneven snippet lengths!" 109 | snippet_length_samples = snippet_length_samples[0] 110 | assert snippet_length_samples.ndim == 0 111 | 112 | # Set up the iterator over batches of snippets 113 | if require_full_context: 114 | # This case is simpler, so handle it explicitly 115 | snippet_batches = snippet_times.T.split(batchsz, dim=1) 116 | else: 117 | # First, batch the snippets that are of different lengths. 118 | snippet_batches = snippet_times.tensor_split(torch.where(snippet_length_samples.diff() != 0)[0]+1, dim=0) 119 | # Then, split any batches that are too big to fit into the given 120 | # batch size. 121 | snippet_iter = [] 122 | for batch in snippet_batches: 123 | # split, *then* transpose 124 | if batch.shape[0] > batchsz: 125 | snippet_iter += batch.T.split(batchsz,dim=1) 126 | else: 127 | snippet_iter += [batch.T] 128 | snippet_batches = snippet_iter 129 | 130 | snippet_iter = snippet_batches 131 | if not disable_tqdm: 132 | snippet_iter = tqdm(snippet_iter, desc='snippet batches', leave=False) 133 | snippet_iter = enumerate(snippet_iter) 134 | 135 | 136 | # Iterate with a sliding window. stride = chunk_sz 137 | for batch_idx, (snippet_starts, snippet_ends) in snippet_iter: 138 | if ((snippet_ends - snippet_starts) < (contextsz_samples + chunksz_samples)).any() and require_full_context: 139 | raise ValueError("This shouldn't happen with require_full_context") 140 | 141 | # If we don't have enough samples, skip this chunk. 
142 | if (snippet_ends - snippet_starts < min_length_samples).any(): 143 | print('If this is true for any, then you might be losing more snippets than just the offending (too short) snippet') 144 | assert False 145 | 146 | # Construct the input waveforms for the batch 147 | batched_wav_in_list = [] 148 | for batch_snippet_idx, (snippet_start, snippet_end) in enumerate(zip(snippet_starts, snippet_ends)): 149 | # Stacking might be inefficient, so populate a pre-allocated array. 150 | #batched_wav_in[batch_snippet_idx, :] = wav[snippet_start:snippet_end] 151 | # But stacking makes variable batch size easier! 152 | batched_wav_in_list.append(wav[snippet_start:snippet_end]) 153 | batched_wav_in = torch.stack(batched_wav_in_list, dim=0) 154 | 155 | # The final batch may be incomplete if batchsz doesn't evenly divide 156 | # the number of snippets. 157 | if (snippet_starts.shape[0] != batched_wav_in.shape[0]) and (snippet_starts.shape[0] != batchsz): 158 | batched_wav_in = batched_wav_in[:snippet_starts.shape[0]] 159 | 160 | # Take the last 1 or 2 activations, and time-wise put it at the 161 | # end of chunk. 162 | output_inds = np.array([-1 - frame_skip*i for i in reversed(range(num_sel_frames))]) 163 | 164 | # Use a pre-processor if given (e.g. to normalize the waveform), and 165 | # then feed into the model. 166 | if feature_extractor is not None: 167 | # This step seems to be NOT differentiable, since the feature 168 | # extractor first converts the Tensor to a numpy array, then back 169 | # into a Tensor. 170 | # If you want to backprop through the stimulus, you might have to 171 | # re-implement the feature extraction in PyTorch (in particular, the 172 | # normalization) 173 | 174 | if stereo: raise NotImplementedError("Support handling multi-channel audio with feature extractor") 175 | # It looks like most feature extractors (e.g. 176 | # Wav2Vec2FeatureExtractor) accept mono audio (i.e. 1-dimensional), 177 | # but it's unclear if they support stereo as well. 
178 | 179 | feature_extractor_kwargs = {} 180 | if is_whisper_model: 181 | # Because Whisper auto-pads all inputs to 30 sec., we'll use 182 | # the attention mask to figure out when the "last" relevant 183 | # input was. 184 | features_key = 'input_features' 185 | feature_extractor_kwargs['return_attention_mask'] = True 186 | else: 187 | features_key = 'input_values' 188 | 189 | preprocessed_snippets = feature_extractor(list(batched_wav_in.cpu().numpy()), 190 | return_tensors='pt', 191 | sampling_rate=sampling_rate, 192 | **feature_extractor_kwargs) 193 | if is_whisper_model: 194 | chunk_features = model.encoder(preprocessed_snippets[features_key].to(model.device)) 195 | 196 | # Now we need to figure out which output index to use, since 2 197 | # conv layers downsample the inputs before passing them into 198 | # the encoder's Transformer layers. We can redo the encoder's 199 | # 1-D conv's on the attention mask to find the final output that 200 | # was influenced by the snippet. 201 | contributing_outs = preprocessed_snippets.attention_mask # 1 if part of waveform, 0 otherwise. shape: (batchsz, 3000) 202 | # Taking [0] works because all snippets have the same length. 203 | # Add the dimension back for `conv1d` to work 204 | # TODO: assert that all clips are the same length? 
205 | contributing_outs = contributing_outs[0].unsqueeze(0) 206 | 207 | contributing_outs = torch.nn.functional.conv1d(contributing_outs, 208 | torch.ones((1,1)+model.encoder.conv1.kernel_size).to(contributing_outs), 209 | stride=model.encoder.conv1.stride, 210 | padding=model.encoder.conv1.padding, 211 | dilation=model.encoder.conv1.dilation, 212 | groups=model.encoder.conv1.groups) 213 | # shape: (batchsz, 1500) 214 | contributing_outs = torch.nn.functional.conv1d(contributing_outs, 215 | torch.ones((1,1)+model.encoder.conv2.kernel_size).to(contributing_outs), 216 | stride=model.encoder.conv2.stride, 217 | padding=model.encoder.conv2.padding, 218 | dilation=model.encoder.conv2.dilation, 219 | groups=model.encoder.conv2.groups) 220 | 221 | final_output = contributing_outs[0].nonzero().squeeze(-1).max() 222 | else: 223 | # sampling rates must match if not using a pre-processor 224 | assert sampling_rate == TARGET_SAMPLE_RATE, f"sampling rate mismatch! {sampling_rate} != {TARGET_SAMPLE_RATE}" 225 | 226 | chunk_features = model(preprocessed_snippets[features_key].to(model.device)) 227 | else: 228 | chunk_features = model(batched_wav_in) 229 | 230 | # Make sure we have enough outputs 231 | if(chunk_features['last_hidden_state'].shape[1] < (num_sel_frames-1) * frame_skip - 1): 232 | print("Skipping batch:", batch_idx, "only had", chunk_features['last_hidden_state'].shape[1], 233 | "outputs, whereas", (num_sel_frames-1) * frame_skip - 1, "were needed.") 234 | continue 235 | 236 | assert len(output_inds) == 1, "Only one output per evaluation is "\ 237 | "supported for Hugging Face (because they don't provide the downsampling rate)" 238 | 239 | if is_whisper_model: 240 | output_inds = [final_output] 241 | 242 | for out_idx, output_offset in enumerate(output_inds): 243 | times.append(torch.stack([snippet_starts, snippet_ends], dim=1)) 244 | 245 | output_representation = chunk_features['last_hidden_state'][:, output_offset, :] # shape: (batchsz, hidden_size) 246 | if
move_to_cpu: output_representation = output_representation.cpu() 247 | if return_numpy: output_representation = output_representation.numpy() 248 | out_features.append(output_representation) 249 | 250 | # Collect features from individual layers 251 | # NOTE: outs['hidden_states'] might have an extra element at 252 | # the beginning for the feature extractor. 253 | # e.g. 25 "layers" --> CNN output + 24 transformer layers' output 254 | for layer_idx, layer_activations in enumerate(chunk_features['hidden_states']): 255 | # Only save layers that the user wants (if specified) 256 | if sel_layers: 257 | if layer_idx not in sel_layers: continue 258 | 259 | layer_representation = layer_activations[:, output_offset, :] # shape: (batchsz, hidden_size) 260 | if move_to_cpu: layer_representation = layer_representation.cpu() 261 | if return_numpy: layer_representation = layer_representation.numpy() # TODO: convert to numpy at the end 262 | 263 | if is_whisper_model: 264 | # Leave the option open for using decoder layers in the 265 | # future 266 | module_name = f"encoder.{layer_idx}" 267 | else: 268 | module_name = f"layer.{layer_idx}" 269 | 270 | module_features[module_name].append(layer_representation) 271 | 272 | out_features = np.concatenate(out_features, axis=0) if return_numpy else torch.cat(out_features, dim=0) # shape: (timesteps, features) 273 | module_features = {name: (np.concatenate(features, axis=0) if return_numpy else torch.cat(features, dim=0))\ 274 | for name, features in module_features.items()} 275 | 276 | assert all(features.shape[0] == out_features.shape[0] for features in module_features.values()),\ 277 | "Missing timesteps in the module activations!! (possible PyTorch bug)" 278 | times = torch.cat(times, dim=0) / TARGET_SAMPLE_RATE # convert samples --> seconds. shape: (timesteps,) 279 | if return_numpy: times = times.numpy() 280 | 281 | del chunk_features # possible memory leak. 
remove if unneeded 282 | return {'final_outputs': out_features, 'times': times, 283 | 'module_features': module_features} 284 | 285 | 286 | if __name__ == "__main__": 287 | parser = argparse.ArgumentParser() 288 | parser.add_argument('--stimulus_dir', type=Path, 289 | default='./processed_stimuli/', 290 | help="Directory with preprocessed stimuli wav's.") 291 | parser.add_argument('--bucket', type=str, 292 | help="Bucket to save extracted features to. If blank, save to local filesystem.") 293 | parser.add_argument('--model', type=str, required=True) 294 | parser.add_argument('--use_featext', action='store_true') 295 | parser.add_argument('--batchsz', type=int, default=1, 296 | help='Number of audio clips to evaluate at once. (Only uses one GPU.)') 297 | parser.add_argument('--chunksz', type=float, default=100, 298 | help="Divide the stimulus waveform into chunks of this many *milliseconds*.") 299 | parser.add_argument('--contextsz', type=float, default=8000, 300 | help="Use these many milliseconds as context for each chunk.") 301 | parser.add_argument('--layers', nargs='+', type=int, help="Only save the " 302 | "features from these layers. Usually doesn't speed up execution " 303 | "time, but may speed up upload time and reduce total disk usage. " 304 | "NOTE: only works with numbered layers (currently).") 305 | parser.add_argument('--full_context', action='store_true', 306 | help="Only extract the representation for a stimulus if it is as long as the feature extractor's specified context (context_sz)") 307 | parser.add_argument('--resample', action='store_true', 308 | help='Resample the stimuli to the necessary sample rate ' 309 | 'and convert stereo to mono if needed. If this flag is ' 310 | 'not supplied, an assertion will fail if either ' 311 | 'condition is not met.') 312 | parser.add_argument('--stride', type=float, 313 | help='Extract features every seconds. If using --custom_stimuli, consider changing this argument. 
Don\'t use this for extracting story features to train encoding models (use --chunksz instead). 0.5 is a good value.') 314 | parser.add_argument('--pad_silence', action='store_true', 315 | help='Pad short clips (less than context_sz+chunk_sz) with silence at the beginning') 316 | 317 | # Arguments for choosing stories 318 | stimulus_sel_args = parser.add_argument_group('stimulus_sel', 'Stimulus selection') 319 | stimulus_sel_args.add_argument('--sessions', nargs='+', type=str, 320 | help="Only process stories presented in these sessions. " 321 | "Can be used in conjunction with --subject to take an intersection.") 322 | stimulus_sel_args.add_argument('--stories', '--stimuli', nargs='+', type=str, 323 | help="Only process the given stories. " 324 | "Overrides --sessions and --subjects.") 325 | stimulus_sel_args.add_argument('--recursive', action='store_true', 326 | help='Recursively find .wav and .flac in the stimulus_dir.') 327 | stimulus_sel_args.add_argument('--custom_stimuli', type=str, 328 | help='Use custom (non-story) stimuli, stored in ' 329 | '"{stimulus_dir}/{custom_stimuli}".
If this flag ' 330 | 'is not set, use story stimuli.') 331 | stimulus_sel_args.add_argument('--overwrite', action='store_true', 332 | help='Overwrite existing features (default behavior is to skip)') 333 | 334 | 335 | args = parser.parse_args() 336 | 337 | if (args.bucket is not None) and (args.bucket != ''): 338 | cci_features = cc.get_interface(args.bucket, verbose=False) 339 | print("Saving features to bucket", cci_features.bucket_name) 340 | else: 341 | cci_features = None 342 | print('Saving features to local filesystem.') 343 | 344 | model_name = args.model 345 | with open('speech_model_configs.json', 'r') as f: 346 | model_config = json.load(f)[model_name] 347 | model_hf_path = model_config['huggingface_hub'] 348 | print('Loading model', model_name, 'from the Hugging Face Hub...') 349 | model = AutoModel.from_pretrained(model_hf_path, output_hidden_states=True).cuda() 350 | feature_extractor = None 351 | if args.use_featext: 352 | feature_extractor = AutoFeatureExtractor.from_pretrained(model_hf_path) 353 | 354 | ## Stimulus selection 355 | # Using CLI arguments, find stimuli and their locations. 
356 | stories = set() 357 | 358 | if args.stories is not None: 359 | stories.update(args.stories) 360 | 361 | if args.sessions is not None: 362 | assert IS_STIMULIDB_AVAILABLE, "database_utils is unavailable but is needed to access Stimuli DB" 363 | cci_stim = cc.get_interface('stimulidb', verbose=False) 364 | sess_to_story = cci_stim.download_json('sess_to_story') # IMO this should be added to database_utils 365 | 366 | for session in args.sessions: 367 | train_stories, test_story = sess_to_story[session] 368 | stories.add(test_story) 369 | for story in train_stories: 370 | stories.add(story) 371 | 372 | stimulus_dir = args.stimulus_dir 373 | assert stimulus_dir.exists(), f"Stimulus dir {str(stimulus_dir)} does not exist" 374 | assert stimulus_dir.is_dir(), f"Stimulus dir {str(stimulus_dir)} is not a directory" 375 | 376 | stimulus_paths: Dict[str, Path] = {} # map of stimulus name --> file path. We also use this as the list of stimuli 377 | 378 | if args.custom_stimuli: # optionally use non-story stimuli 379 | custom_stimuli_dir = stimulus_dir / args.custom_stimuli 380 | assert custom_stimuli_dir.exists(), f"dir {str(custom_stimuli_dir)} does not exist" 381 | stimulus_dir = custom_stimuli_dir 382 | 383 | # We haven't selected any stories yet, so just select all stories in the 384 | # stimulus directory. 385 | if len(stories) == 0: 386 | # Look for all files ending in '.flac' and '.wav'. If there are two 387 | # files with the same basename (i.e. without the suffix), then prefer 388 | # the FLAC file. 
389 | if args.recursive: 390 | stimulus_glob_wav_iter = stimulus_dir.rglob('*.wav') 391 | stimulus_glob_flac_iter = stimulus_dir.rglob('*.flac') 392 | else: 393 | stimulus_glob_wav_iter = stimulus_dir.glob('*.wav') 394 | stimulus_glob_flac_iter = stimulus_dir.glob('*.flac') 395 | 396 | for stimulus_path in itertools.chain(stimulus_glob_wav_iter, stimulus_glob_flac_iter): 397 | # Use 'relative_to' to preserve directory structure when using 398 | # --recursive 399 | stimulus_name = str(stimulus_path.relative_to(stimulus_dir).with_suffix('')) 400 | # If stimulus already exists, overwrite the path with the 401 | # most recent extension 402 | stimulus_paths[stimulus_name] = stimulus_path 403 | else: 404 | for story in stories: 405 | # Find the associated sound file for each stimulus. 406 | # First extension found is preferred. 407 | for ext in ['flac', 'wav']: 408 | stimulus_path = stimulus_dir / f"{story}.{ext}" 409 | if stimulus_path.exists() and stimulus_path.is_file(): 410 | stimulus_paths[story] = stimulus_path 411 | break 412 | 413 | missing_stories = set(stories).difference(set(stimulus_paths.keys())) 414 | if len(missing_stories) > 0: 415 | raise RuntimeError(f"missing stimuli for stories: " + ' '.join(missing_stories)) 416 | 417 | assert len(stimulus_paths) > 0, "no stimuli to process!" 418 | 419 | # Make sure that all preprocessed stimuli exist and are readable. 420 | for stimulus_name, stimulus_local_path in stimulus_paths.items(): 421 | wav, sample_rate = torchaudio.load(stimulus_local_path) 422 | if not args.resample: 423 | assert wav.shape[0] == 1, f"stimulus '{stimulus_local_path}' is not mono-channel" 424 | 425 | # chunk size in seconds and samples, respectively 426 | chunksz_sec = args.chunksz / 1000. 427 | 428 | # context size in terms of chunks 429 | assert (args.contextsz % args.chunksz) == 0, "These must be divisible" 430 | contextsz_sec = args.contextsz / 1000. 
431 | 432 | model_save_path = f"features_cnk{chunksz_sec:0.1f}_ctx{contextsz_sec:0.1f}/{model_name}" 433 | if args.stride: 434 | # If using a custom stride length (e.g. for snippets), store in a 435 | # separate directory. 436 | model_save_path = os.path.join(model_save_path, f"stride_{args.stride}") 437 | if args.custom_stimuli: 438 | # Save custom (non-story) stimuli in their own subdirectory 439 | model_save_path = os.path.join(model_save_path, 'custom_stimuli', args.custom_stimuli) 440 | print('Saving features to:', model_save_path) 441 | 442 | ## Feature extraction loop 443 | # Go through each stimulus and save resulting features 444 | torch.set_grad_enabled(False) # VERY important! (for memory) 445 | model.eval() 446 | # Sort stimuli alphabetically. Allows us, in theory, to resume partial/failed jobs 447 | stimulus_paths = collections.OrderedDict(sorted(stimulus_paths.items(), key=lambda x: x[0])) 448 | for stimulus_name, stimulus_local_path in tqdm(stimulus_paths.items(), desc='Processing stories'): 449 | wav, sample_rate = torchaudio.load(stimulus_local_path) 450 | if not args.resample: 451 | # Perform checks on the original waveform 452 | assert wav.shape[0] == 1, f"stimulus '{stimulus_local_path}' is not mono-channel" 453 | assert sample_rate == TARGET_SAMPLE_RATE 454 | else: 455 | # Resample & convert to mono as needed 456 | if wav.shape[0] != 1: wav = wav.mean(0, keepdims=True) # convert to mono 457 | if sample_rate != TARGET_SAMPLE_RATE: # resample to 16 kHz 458 | wav = torchaudio.functional.resample(wav, sample_rate, TARGET_SAMPLE_RATE) 459 | sample_rate = TARGET_SAMPLE_RATE 460 | 461 | wav.squeeze_(0) # shape: (num_samples,) 462 | 463 | assert sample_rate == TARGET_SAMPLE_RATE, f"Expected sample rate {TARGET_SAMPLE_RATE} but got {sample_rate}" 464 | 465 | features_save_path = os.path.join(model_save_path, stimulus_name) 466 | times_save_path = f"{features_save_path}_times" 467 | if not args.overwrite: 468 | if cci_features is None: 469 | if 
os.path.exists(times_save_path + '.npz'): 470 | print(f"Skipping {stimulus_name}, timestamps found at {times_save_path}") 471 | continue 472 | else: 473 | if cci_features.exists_object(times_save_path): 474 | print(f"Skipping {stimulus_name}, timestamps found at {times_save_path}") 475 | continue 476 | 477 | # Call a separate function to do the actual feature extraction 478 | extract_features_kwargs = { 479 | 'model': model, 'model_config': model_config, 480 | 'wav': wav.to(model.device), 481 | 'chunksz_sec': chunksz_sec, 'contextsz_sec': contextsz_sec, 482 | 'sel_layers': args.layers, 'feature_extractor': feature_extractor, 483 | 'require_full_context': args.full_context or args.pad_silence, 484 | 'batchsz': args.batchsz, 'return_numpy': False 485 | } 486 | 487 | if args.stride: 488 | # Set the context_sz so that the total span length (context+chunk) 489 | # is the same as in non-stride mode, and so that the chunk_sz is 490 | # the "new" stride length. 491 | extract_features_kwargs['contextsz_sec'] = chunksz_sec + contextsz_sec - args.stride 492 | extract_features_kwargs['chunksz_sec'] = args.stride 493 | 494 | if args.pad_silence: 495 | # Pad with `context_sz` sec. 
of silence, so that the first 496 | # (non-silence) output is at time `chunk_sz` 497 | wav = torch.cat([torch.zeros(int(extract_features_kwargs['contextsz_sec']*TARGET_SAMPLE_RATE)), wav], axis=0) 498 | extract_features_kwargs['wav'] = wav.to(model.device) 499 | 500 | extracted_features = extract_speech_features(**extract_features_kwargs) 501 | out_features, times, module_features = [extracted_features[k] for k in \ 502 | ['final_outputs', 'times', 'module_features']] 503 | del extracted_features # free up some memory after we've selected the outputs we want; maybe unnecessary 504 | 505 | # Remove the 'silence' we added at the beginning 506 | if args.pad_silence: 507 | times = torch.clip(times - extract_features_kwargs['contextsz_sec'], 0, torch.inf) 508 | assert torch.all(times >= 0), "padding is smaller than the correction (subtraction)!" 509 | assert torch.all(times[:,1] > 0), f"insufficient padding for require_full_context! ({times[times[:,1]<=0,1]})" 510 | 511 | if cci_features is None: 512 | # If no bucket was given, save locally.
513 | os.makedirs(os.path.dirname(features_save_path), exist_ok=True) 514 | np.savez_compressed(features_save_path + '.npz', features=out_features.numpy()) 515 | np.savez_compressed(times_save_path + '.npz', times=times.numpy()) 516 | else: 517 | cci_features.upload_raw_array(features_save_path, out_features.numpy()) 518 | cci_features.upload_raw_array(times_save_path, times.numpy()) 519 | 520 | module_save_paths = {module: os.path.join(model_save_path, module, stimulus_name) for module in module_features.keys()} 521 | 522 | # This is the "save name" of the module (not its original name) 523 | for module_name, features in module_features.items(): 524 | features_save_path = module_save_paths[module_name] 525 | times_save_path = f"{features_save_path}_times" 526 | if cci_features is None: 527 | os.makedirs(os.path.dirname(features_save_path), exist_ok=True) 528 | np.savez_compressed(features_save_path + '.npz', features=features.numpy()) 529 | # "times" should be the same for all modules 530 | else: 531 | cci_features.upload_raw_array(features_save_path, features.numpy()) 532 | -------------------------------------------------------------------------------- /ridge_utils/DataSequence.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools as itools 3 | from ridge_utils.interpdata import sincinterp2D, gabor_xfm2D, lanczosinterp2D 4 | 5 | class DataSequence(object): 6 | """DataSequence class provides a nice interface for handling data that is both continuous 7 | and discretely chunked. For example, semantic projections of speech stimuli must be 8 | considered both at the level of single words (which are continuous throughout the stimulus) 9 | and at the level of TRs (which contain discrete chunks of words). 
10 | """ 11 | def __init__(self, data, split_inds, data_times=None, tr_times=None): 12 | """Initializes the DataSequence with the given [data] object (which can be any iterable) 13 | and a collection of [split_inds], which should be the indices where the data is split into 14 | separate TR chunks. 15 | """ 16 | self.data = data 17 | self.split_inds = split_inds 18 | self.data_times = data_times 19 | self.tr_times = tr_times 20 | 21 | def mapdata(self, fun): 22 | """Creates a new DataSequence where each element of [data] is produced by mapping the 23 | function [fun] onto this DataSequence's [data]. 24 | 25 | The [split_inds] are preserved exactly. 26 | """ 27 | return DataSequence(list(map(fun, self.data)), self.split_inds) 28 | 29 | def chunks(self): 30 | """Splits the stored [data] into the discrete chunks and returns them. 31 | """ 32 | return np.split(self.data, self.split_inds) 33 | 34 | def data_to_chunk_ind(self, dataind): 35 | """Returns the index of the chunk containing the data with the given index. 36 | """ 37 | zc = np.zeros((len(self.data),)) 38 | zc[dataind] = 1.0 39 | ch = np.array([ch.sum() for ch in np.split(zc, self.split_inds)]) 40 | return np.nonzero(ch)[0][0] 41 | 42 | def chunk_to_data_ind(self, chunkind): 43 | """Returns the indexes of the data contained in the chunk with the given index. 44 | """ 45 | return list(np.split(np.arange(len(self.data)), self.split_inds)[chunkind]) 46 | 47 | def chunkmeans(self): 48 | """Splits the stored [data] into the discrete chunks, then takes the mean of each chunk 49 | (this is assuming that [data] is a numpy array) and returns the resulting matrix with 50 | one row per chunk.
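Example (an illustrative sketch with made-up data, assuming [data] is a 2D numpy array): a chunk that contains no data is left as a row of zeros instead of producing NaN.

```python
import numpy as np

# Hypothetical data: three time points, two features.
data = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Repeating an index makes the middle chunk empty: data[0:2], data[2:2], data[2:].
split_inds = [2, 2]

outmat = np.zeros((len(split_inds) + 1, data.shape[1]))
for ci, c in enumerate(np.split(data, split_inds)):
    if len(c):  # skip empty chunks so their rows stay at zero
        outmat[ci] = c.mean(0)
```

The empty-chunk guard matters for TRs during which no word was spoken.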
51 | """ 52 | dsize = self.data.shape[1] 53 | outmat = np.zeros((len(self.split_inds)+1, dsize)) 54 | for ci, c in enumerate(self.chunks()): 55 | if len(c): 56 | outmat[ci] = np.vstack(c).mean(0) 57 | 58 | return outmat 59 | 60 | def chunksums(self, interp="rect", **kwargs): 61 | """Splits the stored [data] into the discrete chunks, then takes the sum of each chunk 62 | (this is assuming that [data] is a numpy array) and returns the resulting matrix with 63 | one row per chunk. 64 | 65 | If [interp] is "sinc", the signal will be downsampled using a truncated sinc filter 66 | instead of a rectangular filter. 67 | 68 | If [interp] is "lanczos", the signal will be downsampled using a Lanczos filter. If [interp] is "gabor", the magnitude of a Gabor transform of the signal is returned. 69 | 70 | [kwargs] are passed to the interpolation function. 71 | """ 72 | if interp=="sinc": 73 | ## downsample using sinc filter 74 | return sincinterp2D(self.data, self.data_times, self.tr_times, **kwargs) 75 | elif interp=="lanczos": 76 | ## downsample using Lanczos filter 77 | return lanczosinterp2D(self.data, self.data_times, self.tr_times, **kwargs) 78 | elif interp=="gabor": 79 | ## downsample using Gabor filter 80 | return np.abs(gabor_xfm2D(self.data.T, self.data_times, self.tr_times, **kwargs)).T 81 | else: 82 | dsize = self.data.shape[1] 83 | outmat = np.zeros((len(self.split_inds)+1, dsize)) 84 | for ci, c in enumerate(self.chunks()): 85 | if len(c): 86 | outmat[ci] = np.vstack(c).sum(0) 87 | 88 | return outmat 89 | 90 | def copy(self): 91 | """Returns a copy of this DataSequence. 92 | """ 93 | return DataSequence(list(self.data), self.split_inds.copy(), self.data_times, self.tr_times) 94 | 95 | @classmethod 96 | def from_grid(cls, grid_transcript, trfile): 97 | """Creates a new DataSequence from a [grid_transcript] and a [trfile]. 98 | grid_transcript should be the product of the 'make_simple_transcript' method of TextGrid.
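Example (a minimal sketch with made-up timings, not taken from any real TextGrid or TR file): how word onsets are turned into [split_inds]. The arrays `word_starts` and `trtimes` and the TR length `tr` are illustrative stand-ins for the values this method reads from its arguments.

```python
import numpy as np

# Hypothetical word onsets (seconds), TR trigger times, and TR length.
word_starts = np.array([0.1, 0.8, 1.9, 2.3, 4.5, 5.0])
trtimes = np.array([0.0, 2.0, 4.0])
tr = 2.0

# For each TR, count the words that start before that TR ends. Dropping the
# last count leaves only the boundaries between consecutive chunks.
split_inds = [int((word_starts < (t + tr)).sum()) for t in trtimes][:-1]

# Number of words that fall into each TR.
chunk_sizes = [len(c) for c in np.split(np.arange(len(word_starts)), split_inds)]
```

With these numbers, three words start in the first TR, one in the second, and two in the third.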
99 | """ 100 | data_entries = list(zip(*grid_transcript))[2] 101 | if isinstance(data_entries[0], str): 102 | data = list(map(str.lower, list(zip(*grid_transcript))[2])) 103 | else: 104 | data = data_entries 105 | word_starts = np.array(list(map(float, list(zip(*grid_transcript))[0]))) 106 | word_ends = np.array(list(map(float, list(zip(*grid_transcript))[1]))) 107 | word_avgtimes = (word_starts + word_ends)/2.0 108 | 109 | tr = trfile.avgtr 110 | trtimes = trfile.get_reltriggertimes() 111 | 112 | split_inds = [(word_starts<(t+tr)).sum() for t in trtimes][:-1] 113 | return cls(data, split_inds, word_avgtimes, trtimes+tr/2.0) 114 | 115 | @classmethod 116 | def from_chunks(cls, chunks): 117 | """The inverse operation of DataSequence.chunks(), this function concatenates 118 | the [chunks] and infers split_inds. 119 | """ 120 | lens = list(map(len, chunks)) 121 | split_inds = np.cumsum(lens)[:-1] 122 | #data = reduce(list.__add__, map(list, chunks)) ## 2.26s for 10k 6-w chunks 123 | data = list(itools.chain(*map(list, chunks))) ## 19.6ms for 10k 6-w chunks 124 | return cls(data, split_inds) 125 | -------------------------------------------------------------------------------- /ridge_utils/SemanticModel.py: -------------------------------------------------------------------------------- 1 | import tables 2 | import pickle 3 | import numpy as np 4 | 5 | import logging 6 | logger = logging.getLogger("SemanticModel") 7 | 8 | class SemanticModel(object): 9 | """This class defines a semantic vector-space model based on HAL or LSA with some 10 | prescribed preprocessing pipeline. 11 | 12 | It contains two important variables: vocab and data. 13 | vocab is a 1D list (or array) of words. 14 | data is a 2D array (features by words) of word-feature values. 15 | """ 16 | def __init__(self, data, vocab): 17 | """Initializes a SemanticModel with the given [data] and [vocab]. 
18 | """ 19 | self.data = data 20 | self.vocab = vocab 21 | 22 | def get_ndim(self): 23 | """Returns the number of dimensions in this model. 24 | """ 25 | return self.data.shape[0] 26 | ndim = property(get_ndim) 27 | 28 | def get_vindex(self): 29 | """Return {vocab: index} dictionary. 30 | """ 31 | if "_vindex" not in dir(self): 32 | self._vindex = dict([(v,i) for (i,v) in enumerate(self.vocab)]) 33 | return self._vindex 34 | vindex = property(get_vindex) 35 | 36 | def __getitem__(self, word): 37 | """Returns the vector corresponding to the given [word]. 38 | """ 39 | return self.data[:,self.vindex[word]] 40 | 41 | def load_root(self, rootfile, vocab): 42 | """Load the SVD-generated semantic vector space from [rootfile], assumed to be 43 | an HDF5 file. 44 | """ 45 | roothf = tables.open_file(rootfile) 46 | self.data = roothf.get_node("/R").read() 47 | self.vocab = vocab 48 | roothf.close() 49 | 50 | def load_ascii_root(self, rootfile, vocab): 51 | """Loads the SVD-generated semantic vector space from [rootfile], assumed to be 52 | an ASCII dense matrix output from SVDLIBC. 53 | """ 54 | vtfile = open(rootfile) 55 | nrows, ncols = map(int, vtfile.readline().split()) 56 | Vt = np.zeros((nrows,ncols)) 57 | nrows_done = 0 58 | for row in vtfile: 59 | Vt[nrows_done,:] = list(map(float, row.split())) ## map() is lazy in Python 3, so materialize the row 60 | nrows_done += 1 61 | 62 | self.data = Vt 63 | self.vocab = vocab 64 | 65 | def restrict_by_occurrence(self, min_rank=60, max_rank=60000): 66 | """Restricts the data to words that have an occurrence rank lower than 67 | [min_rank] and higher than [max_rank].
68 | """ 69 | logger.debug("Restricting words by occurrence..") 70 | nwords = self.data.shape[1] 71 | wordranks = np.argsort(np.argsort(self.data[0,:])) 72 | goodwords = np.nonzero(np.logical_and((nwords-wordranks)>min_rank, 73 | (nwords-wordranks)window/(2*B)] = 0 144 | if causal: 145 | val[t<0] = 0 146 | if not np.sum(val)==0.0 and renorm: 147 | val = val/np.sum(val) 148 | elif np.abs(t)>window/(2*B): 149 | val = 0 150 | if causal and t<0: 151 | val = 0 152 | return val 153 | 154 | def lanczosfun(cutoff, t, window=3): 155 | """Compute the lanczos function with some cutoff frequency [B] at some time [t]. 156 | [t] can be a scalar or any shaped numpy array. 157 | If given a [window], only the lowest-order [window] lobes of the sinc function 158 | will be non-zero. 159 | """ 160 | t = t * cutoff 161 | val = window * np.sin(np.pi*t) * np.sin(np.pi*t/window) / (np.pi**2 * t**2) 162 | val[t==0] = 1.0 163 | val[np.abs(t)>window] = 0.0 164 | return val# / (val.sum() + 1e-10) 165 | 166 | def expinterp2D(data, oldtime, newtime, theta): 167 | intmat = np.zeros((len(newtime), len(oldtime))) 168 | for ndi in range(len(newtime)): 169 | intmat[ndi,:] = expfun(theta, newtime[ndi]-oldtime) 170 | 171 | ## Construct new signal by multiplying the sinc matrix by the data ## 172 | newdata = np.dot(intmat, data) 173 | return newdata 174 | 175 | def expfun(theta, t): 176 | """Computes an exponential weighting function for interpolation. 
177 | """ 178 | val = np.exp(-t*theta) 179 | val[t<0] = 0.0 180 | if not np.sum(val)==0.0: 181 | val = val/np.sum(val) 182 | return val 183 | 184 | def gabor_xfm(data, oldtimes, newtimes, freqs, sigma): 185 | sinvals = np.vstack([np.sin(oldtimes*f*2*np.pi) for f in freqs]) 186 | cosvals = np.vstack([np.cos(oldtimes*f*2*np.pi) for f in freqs]) 187 | outvals = np.zeros((len(newtimes), len(freqs)), dtype=np.complex128) 188 | for ti,t in enumerate(newtimes): 189 | ## Build gaussian function 190 | gaussvals = np.exp(-0.5*(oldtimes-t)**2/(2*sigma**2))*data 191 | ## Take product with sin/cos vals 192 | sprod = np.dot(sinvals, gaussvals) 193 | cprod = np.dot(cosvals, gaussvals) 194 | ## Store the output 195 | outvals[ti,:] = cprod + 1j*sprod 196 | 197 | return outvals 198 | 199 | def gabor_xfm2D(ddata, oldtimes, newtimes, freqs, sigma): 200 | return np.vstack([gabor_xfm(d, oldtimes, newtimes, freqs, sigma).T for d in ddata]) 201 | 202 | def test_interp(**kwargs): 203 | """Tests sincinterp2D passing it the given [kwargs] and interpolating known signals 204 | between the two time domains. 
205 | """ 206 | oldtime = np.linspace(0, 10, 100) 207 | newtime = np.linspace(0, 10, 49) 208 | data = np.zeros((4, 100)) 209 | ## The first row has a single nonzero value 210 | data[0,50] = 1.0 211 | ## The second row has a few nonzero values in a row 212 | data[1,45:55] = 1.0 213 | ## The third row has a few nonzero values separated by zeros 214 | data[2,40:45] = 1.0 215 | data[2,55:60] = 1.0 216 | ## The fourth row has different values 217 | data[3,40:45] = 1.0 218 | data[3,55:60] = 2.0 219 | 220 | ## Interpolate the data 221 | interpdata = sincinterp2D(data.T, oldtime, newtime, **kwargs).T 222 | 223 | ## Plot the results 224 | from matplotlib.pyplot import figure, show 225 | fig = figure() 226 | for d in range(4): 227 | ax = fig.add_subplot(4,1,d+1) 228 | ax.plot(newtime, interpdata[d,:], 'go-') 229 | ax.plot(oldtime, data[d,:], 'bo-') 230 | 231 | #ax.tight() 232 | show() 233 | return newtime, interpdata 234 | -------------------------------------------------------------------------------- /ridge_utils/npp.py: -------------------------------------------------------------------------------- 1 | """This module contains one-line functions that should, by all rights, be in numpy.
2 | """ 3 | import numpy as np 4 | 5 | ## Demean -- remove the mean from each column 6 | demean = lambda v: v-v.mean(0) 7 | demean.__doc__ = """Removes the mean from each column of [v].""" 8 | dm = demean 9 | 10 | ## Z-score -- z-score each column 11 | def zscore(v): 12 | s = v.std(0) 13 | m = v - v.mean(0) 14 | for i in range(len(s)): 15 | if s[i] != 0.: 16 | m[:, i] /= s[i] 17 | return m 18 | 19 | # zscore = lambda v: (v-v.mean(0))/v.std(0) 20 | zscore.__doc__ = """Z-scores (standardizes) each column of [v].""" 21 | zs = zscore 22 | 23 | ## Rescale -- make each column have unit variance 24 | rescale = lambda v: v/v.std(0) 25 | rescale.__doc__ = """Rescales each column of [v] to have unit variance.""" 26 | rs = rescale 27 | 28 | ## Matrix corr -- find correlation between each column of c1 and the corresponding column of c2 29 | mcorr = lambda c1,c2: (zs(c1)*zs(c2)).mean(0) 30 | mcorr.__doc__ = """Matrix correlation. Find the correlation between each column of [c1] and the corresponding column of [c2].""" 31 | 32 | ## Cross corr -- find corr. between each row of c1 and EACH row of c2 33 | xcorr = lambda c1,c2: np.dot(zs(c1.T).T,zs(c2.T)) / (c1.shape[1]) 34 | xcorr.__doc__ = """Cross-column correlation. Finds the correlation between each row of [c1] and each row of [c2].""" 35 | -------------------------------------------------------------------------------- /ridge_utils/ridge.py: -------------------------------------------------------------------------------- 1 | #import scipy 2 | import numpy as np 3 | import logging 4 | from ridge_utils.utils import mult_diag, counter 5 | import random 6 | import itertools as itools 7 | import joblib 8 | 9 | zs = lambda v: (v-v.mean(0))/v.std(0) ## z-score function 10 | 11 | ridge_logger = logging.getLogger("ridge_corr") 12 | 13 | def ridge(stim, resp, alpha, singcutoff=1e-10, normalpha=False, logger=ridge_logger): 14 | """Uses ridge regression to find a linear transformation of [stim] that approximates 15 | [resp]. 
The regularization parameter is [alpha]. 16 | 17 | Parameters 18 | ---------- 19 | stim : array_like, shape (T, N) 20 | Stimuli with T time points and N features. 21 | resp : array_like, shape (T, M) 22 | Responses with T time points and M separate responses. 23 | alpha : float or array_like, shape (M,) 24 | Regularization parameter. Can be given as a single value (which is applied to 25 | all M responses) or separate values for each response. 26 | normalpha : boolean 27 | Whether ridge parameters should be normalized by the largest singular value of stim. Good for 28 | comparing models with different numbers of parameters. 29 | 30 | Returns 31 | ------- 32 | wt : array_like, shape (N, M) 33 | Linear regression weights. 34 | """ 35 | try: 36 | U,S,Vh = np.linalg.svd(stim, full_matrices=False) 37 | except np.linalg.LinAlgError: 38 | logger.info("NORMAL SVD FAILED, trying more robust dgesvd..") 39 | from text.regression.svd_dgesvd import svd_dgesvd 40 | U,S,Vh = svd_dgesvd(stim, full_matrices=False) 41 | 42 | UR = np.dot(U.T, np.nan_to_num(resp)) 43 | 44 | # Expand alpha to a collection if it's just a single value 45 | if isinstance(alpha, (float,int)): 46 | alpha = np.ones(resp.shape[1]) * alpha 47 | 48 | # Normalize alpha by the LSV norm 49 | norm = S[0] 50 | if normalpha: 51 | nalphas = alpha * norm 52 | else: 53 | nalphas = alpha 54 | 55 | # Compute weights for each alpha 56 | ualphas = np.unique(nalphas) 57 | wt = np.zeros((stim.shape[1], resp.shape[1])) 58 | for ua in ualphas: 59 | selvox = np.nonzero(nalphas==ua)[0] 60 | #awt = reduce(np.dot, [Vh.T, np.diag(S/(S**2+ua**2)), UR[:,selvox]]) 61 | awt = Vh.T.dot(np.diag(S/(S**2+ua**2))).dot(UR[:,selvox]) 62 | wt[:,selvox] = awt 63 | 64 | return wt 65 | 66 | 67 | def ridge_corr_pred(Rstim, Pstim, Rresp, Presp, valphas, normalpha=False, 68 | singcutoff=1e-10, use_corr=True, logger=ridge_logger): 69 | """Uses ridge regression to find a linear transformation of [Rstim] that approximates [Rresp], 70 | then tests by 
comparing the transformation of [Pstim] to [Presp]. Returns the correlation 71 | between predicted and actual [Presp], without ever computing the regression weights. 72 | This function assumes that each voxel is assigned a separate alpha in [valphas]. 73 | 74 | Parameters 75 | ---------- 76 | Rstim : array_like, shape (TR, N) 77 | Training stimuli with TR time points and N features. Each feature should be Z-scored across time. 78 | Pstim : array_like, shape (TP, N) 79 | Test stimuli with TP time points and N features. Each feature should be Z-scored across time. 80 | Rresp : array_like, shape (TR, M) 81 | Training responses with TR time points and M responses (voxels, neurons, what-have-you). 82 | Each response should be Z-scored across time. 83 | Presp : array_like, shape (TP, M) 84 | Test responses with TP time points and M responses. 85 | valphas : list or array_like, shape (M,) 86 | Ridge parameter for each voxel. 87 | normalpha : boolean 88 | Whether ridge parameters should be normalized by the largest singular value (LSV) norm of 89 | Rstim. Good for comparing models with different numbers of parameters. 90 | corrmin : float in [0..1] 91 | Purely for display purposes. After each alpha is tested, the number of responses with correlation 92 | greater than corrmin minus the number of responses with correlation less than negative corrmin 93 | will be printed. For long-running regressions this vague metric of non-centered skewness can 94 | give you a rough sense of how well the model is working before it's done. 95 | singcutoff : float 96 | The first step in ridge regression is computing the singular value decomposition (SVD) of the 97 | stimulus Rstim. If Rstim is not full rank, some singular values will be approximately equal 98 | to zero and the corresponding singular vectors will be noise. These singular values/vectors 99 | should be removed both for speed (the fewer multiplications the better!) and accuracy. 
Any 100 | singular values less than singcutoff will be removed. 101 | use_corr : boolean 102 | If True, this function will use correlation as its metric of model fit. If False, this function 103 | will instead use variance explained (R-squared) as its metric of model fit. For ridge regression 104 | this can make a big difference -- highly regularized solutions will have very small norms and 105 | will thus explain very little variance while still leading to high correlations, as correlation 106 | is scale-free while R**2 is not. 107 | 108 | Returns 109 | ------- 110 | corr : array_like, shape (M,) 111 | The correlation between each predicted response and each column of Presp. 112 | 113 | """ 114 | ## Calculate SVD of stimulus matrix 115 | logger.info("Doing SVD...") 116 | try: 117 | U,S,Vh = np.linalg.svd(Rstim, full_matrices=False) 118 | except np.linalg.LinAlgError: 119 | logger.info("NORMAL SVD FAILED, trying more robust dgesvd..") 120 | from text.regression.svd_dgesvd import svd_dgesvd 121 | U,S,Vh = svd_dgesvd(Rstim, full_matrices=False) 122 | 123 | ## Truncate tiny singular values for speed 124 | origsize = S.shape[0] 125 | joblib.dump(S, "singvals.jbl") 126 | ngoodS = np.sum(S > singcutoff) 127 | nbad = origsize-ngoodS 128 | U = U[:,:ngoodS] 129 | S = S[:ngoodS] 130 | Vh = Vh[:ngoodS] 131 | logger.info("Dropped %d tiny singular values.. 
(U is now %s)"%(nbad, str(U.shape))) 132 | 133 | ## Normalize alpha by the LSV norm 134 | norm = S[0] 135 | logger.info("Training stimulus has LSV norm: %0.03f"%norm) 136 | if normalpha: 137 | nalphas = valphas * norm 138 | else: 139 | nalphas = valphas 140 | 141 | ## Precompute some products for speed 142 | UR = np.dot(U.T, Rresp) ## Precompute this matrix product for speed 143 | PVh = np.dot(Pstim, Vh.T) ## Precompute this matrix product for speed 144 | 145 | #Prespnorms = np.apply_along_axis(np.linalg.norm, 0, Presp) ## Precompute test response norms 146 | zPresp = zs(Presp) 147 | #Prespvar = Presp.var(0) 148 | Prespvar_actual = Presp.var(0) 149 | Prespvar = (np.ones_like(Prespvar_actual) + Prespvar_actual) / 2.0 150 | logger.info("Average difference between actual & assumed Prespvar: %0.3f" % (Prespvar_actual - Prespvar).mean()) 151 | 152 | ualphas = np.unique(nalphas) 153 | corr = np.zeros((Rresp.shape[1],)) 154 | for ua in ualphas: 155 | selvox = np.nonzero(nalphas==ua)[0] 156 | alpha_pred = PVh.dot(np.diag(S/(S**2+ua**2))).dot(UR[:,selvox]) 157 | 158 | if use_corr: 159 | corr[selvox] = (zPresp[:,selvox] * zs(alpha_pred)).mean(0) 160 | else: 161 | resvar = (Presp[:,selvox] - alpha_pred).var(0) 162 | Rsq = 1 - (resvar / Prespvar) 163 | corr[selvox] = np.sqrt(np.abs(Rsq)) * np.sign(Rsq) 164 | 165 | return corr 166 | 167 | 168 | def ridge_corr(Rstim, Pstim, Rresp, Presp, alphas, normalpha=False, corrmin=0.2, 169 | singcutoff=1e-10, use_corr=True, logger=ridge_logger): 170 | """Uses ridge regression to find a linear transformation of [Rstim] that approximates [Rresp], 171 | then tests by comparing the transformation of [Pstim] to [Presp]. This procedure is repeated 172 | for each regularization parameter alpha in [alphas]. The correlation between each prediction and 173 | each response for each alpha is returned. The regression weights are NOT returned, because 174 | computing the correlations without computing regression weights is much, MUCH faster. 
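Example (a minimal sketch on random data, not the function body itself): why the weights never need to be formed. The shapes and the value of `alpha` are arbitrary choices for illustration. The penalty here is alpha squared, which matches the S/(S**2+alpha**2) factor used in this module.

```python
import numpy as np

rng = np.random.default_rng(0)
Rstim = rng.standard_normal((50, 8))   # training stimuli (TR x N)
Rresp = rng.standard_normal((50, 3))   # training responses (TR x M)
Pstim = rng.standard_normal((20, 8))   # test stimuli (TP x N)
alpha = 10.0

U, S, Vh = np.linalg.svd(Rstim, full_matrices=False)
D = S / (S ** 2 + alpha ** 2)          # ridge-shrunk inverse singular values

# Prediction without materializing the (N x M) weight matrix: scaling the
# columns of Pstim @ Vh.T by D is the same as multiplying by diag(D).
pred_fast = ((Pstim @ Vh.T) * D) @ (U.T @ Rresp)

# Reference: closed-form ridge weights, then predict.
wt = np.linalg.solve(Rstim.T @ Rstim + alpha ** 2 * np.eye(8), Rstim.T @ Rresp)
pred_ref = Pstim @ wt
max_abs_diff = np.abs(pred_fast - pred_ref).max()
```

Only D changes between alphas, so the SVD and both matrix products can be computed once and reused for every alpha.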
175 | 176 | Parameters 177 | ---------- 178 | Rstim : array_like, shape (TR, N) 179 | Training stimuli with TR time points and N features. Each feature should be Z-scored across time. 180 | Pstim : array_like, shape (TP, N) 181 | Test stimuli with TP time points and N features. Each feature should be Z-scored across time. 182 | Rresp : array_like, shape (TR, M) 183 | Training responses with TR time points and M responses (voxels, neurons, what-have-you). 184 | Each response should be Z-scored across time. 185 | Presp : array_like, shape (TP, M) 186 | Test responses with TP time points and M responses. 187 | alphas : list or array_like, shape (A,) 188 | Ridge parameters to be tested. Should probably be log-spaced. np.logspace(0, 3, 20) works well. 189 | normalpha : boolean 190 | Whether ridge parameters should be normalized by the largest singular value (LSV) norm of 191 | Rstim. Good for comparing models with different numbers of parameters. 192 | corrmin : float in [0..1] 193 | Purely for display purposes. After each alpha is tested, the number of responses with correlation 194 | greater than corrmin minus the number of responses with correlation less than negative corrmin 195 | will be printed. For long-running regressions this vague metric of non-centered skewness can 196 | give you a rough sense of how well the model is working before it's done. 197 | singcutoff : float 198 | The first step in ridge regression is computing the singular value decomposition (SVD) of the 199 | stimulus Rstim. If Rstim is not full rank, some singular values will be approximately equal 200 | to zero and the corresponding singular vectors will be noise. These singular values/vectors 201 | should be removed both for speed (the fewer multiplications the better!) and accuracy. Any 202 | singular values less than singcutoff will be removed. 203 | use_corr : boolean 204 | If True, this function will use correlation as its metric of model fit. 
If False, this function 205 | will instead use variance explained (R-squared) as its metric of model fit. For ridge regression 206 | this can make a big difference -- highly regularized solutions will have very small norms and 207 | will thus explain very little variance while still leading to high correlations, as correlation 208 | is scale-free while R**2 is not. 209 | 210 | Returns 211 | ------- 212 | Rcorrs : array_like, shape (A, M) 213 | The correlation between each predicted response and each column of Presp for each alpha. 214 | 215 | """ 216 | ## Calculate SVD of stimulus matrix 217 | logger.info("Doing SVD...") 218 | try: 219 | U,S,Vh = np.linalg.svd(Rstim, full_matrices=False) 220 | except np.linalg.LinAlgError: 221 | logger.info("NORMAL SVD FAILED, trying more robust dgesvd..") 222 | from text.regression.svd_dgesvd import svd_dgesvd 223 | U,S,Vh = svd_dgesvd(Rstim, full_matrices=False) 224 | 225 | ## Truncate tiny singular values for speed 226 | origsize = S.shape[0] 227 | #joblib.dump((Rstim, U,S,Vh), "/hdd/singvals.jbl") 228 | ngoodS = np.sum(S > singcutoff) 229 | nbad = origsize-ngoodS 230 | U = U[:,:ngoodS] 231 | S = S[:ngoodS] 232 | Vh = Vh[:ngoodS] 233 | logger.info("Dropped %d tiny singular values.. 
(U is now %s)"%(nbad, str(U.shape))) 234 | 235 | ## Normalize alpha by the LSV norm 236 | norm = S[0] 237 | logger.info("Training stimulus has LSV norm: %0.03f"%norm) 238 | if normalpha: 239 | nalphas = alphas * norm 240 | else: 241 | nalphas = alphas 242 | 243 | ## Precompute some products for speed 244 | UR = np.dot(U.T, Rresp) ## Precompute this matrix product for speed 245 | PVh = np.dot(Pstim, Vh.T) ## Precompute this matrix product for speed 246 | 247 | #Prespnorms = np.apply_along_axis(np.linalg.norm, 0, Presp) ## Precompute test response norms 248 | zPresp = zs(Presp) 249 | #Prespvar = Presp.var(0) 250 | Prespvar_actual = Presp.var(0) 251 | Prespvar = (np.ones_like(Prespvar_actual) + Prespvar_actual) / 2.0 252 | logger.info("Average difference between actual & assumed Prespvar: %0.3f" % (Prespvar_actual - Prespvar).mean()) 253 | Rcorrs = [] ## Holds training correlations for each alpha 254 | for na, a in zip(nalphas, alphas): 255 | #D = np.diag(S/(S**2+a**2)) ## Reweight singular vectors by the ridge parameter 256 | D = S / (S ** 2 + na ** 2) ## Reweight singular vectors by the (normalized?) 
ridge parameter 257 | 258 | pred = np.dot(mult_diag(D, PVh, left=False), UR) ## Best (1.75 seconds to prediction in test) 259 | # pred = np.dot(mult_diag(D, np.dot(Pstim, Vh.T), left=False), UR) ## Better (2.0 seconds to prediction in test) 260 | 261 | # pvhd = reduce(np.dot, [Pstim, Vh.T, D]) ## Pretty good (2.4 seconds to prediction in test) 262 | # pred = np.dot(pvhd, UR) 263 | 264 | # wt = reduce(np.dot, [Vh.T, D, UR]).astype(dtype) ## Bad (14.2 seconds to prediction in test) 265 | # wt = reduce(np.dot, [Vh.T, D, U.T, Rresp]).astype(dtype) ## Worst 266 | # pred = np.dot(Pstim, wt) ## Predict test responses 267 | 268 | if use_corr: 269 | #prednorms = np.apply_along_axis(np.linalg.norm, 0, pred) ## Compute predicted test response norms 270 | #Rcorr = np.array([np.corrcoef(Presp[:,ii], pred[:,ii].ravel())[0,1] for ii in range(Presp.shape[1])]) ## Slowly compute correlations 271 | #Rcorr = np.array(np.sum(np.multiply(Presp, pred), 0)).squeeze()/(prednorms*Prespnorms) ## Efficiently compute correlations 272 | Rcorr = (zPresp * zs(pred)).mean(0) 273 | else: 274 | ## Compute variance explained 275 | resvar = (Presp - pred).var(0) 276 | Rsq = 1 - (resvar / Prespvar) 277 | Rcorr = np.sqrt(np.abs(Rsq)) * np.sign(Rsq) 278 | 279 | Rcorr[np.isnan(Rcorr)] = 0 280 | Rcorrs.append(Rcorr) 281 | 282 | log_template = "Training: alpha=%0.3f, mean corr=%0.5f, max corr=%0.5f, over-under(%0.2f)=%d" 283 | log_msg = log_template % (a, 284 | np.mean(Rcorr), 285 | np.max(Rcorr), 286 | corrmin, 287 | (Rcorr>corrmin).sum()-(-Rcorr>corrmin).sum()) 288 | logger.info(log_msg) 289 | 290 | return Rcorrs 291 | 292 | 293 | def bootstrap_ridge(Rstim, Rresp, Pstim, Presp, alphas, nboots, chunklen, nchunks, 294 | corrmin=0.2, joined=None, singcutoff=1e-10, normalpha=False, single_alpha=False, 295 | use_corr=True, return_wt=True, logger=ridge_logger): 296 | """Uses ridge regression with a bootstrapped held-out set to get optimal alpha values for each response. 
297 | [nchunks] random chunks of length [chunklen] will be taken from [Rstim] and [Rresp] for each regression 298 | run. [nboots] total regression runs will be performed. The best alpha value for each response will be 299 | averaged across the bootstraps to estimate the best alpha for that response. 300 | 301 | If [joined] is given, it should be a list of lists where the STRFs for all the voxels in each sublist 302 | will be given the same regularization parameter (the one that is the best on average). 303 | 304 | Parameters 305 | ---------- 306 | Rstim : array_like, shape (TR, N) 307 | Training stimuli with TR time points and N features. Each feature should be Z-scored across time. 308 | Rresp : array_like, shape (TR, M) 309 | Training responses with TR time points and M different responses (voxels, neurons, what-have-you). 310 | Each response should be Z-scored across time. 311 | Pstim : array_like, shape (TP, N) 312 | Test stimuli with TP time points and N features. Each feature should be Z-scored across time. 313 | Presp : array_like, shape (TP, M) 314 | Test responses with TP time points and M different responses. Each response should be Z-scored across 315 | time. 316 | alphas : list or array_like, shape (A,) 317 | Ridge parameters that will be tested. Should probably be log-spaced. np.logspace(0, 3, 20) works well. 318 | nboots : int 319 | The number of bootstrap samples to run. 15 to 30 works well. 320 | chunklen : int 321 | On each sample, the training data is broken into chunks of this length. This should be a few times 322 | longer than your delay/STRF. e.g. for a STRF with 3 delays, I use chunks of length 10. 323 | nchunks : int 324 | The number of training chunks held out to test ridge parameters for each bootstrap sample. The product 325 | of nchunks and chunklen is the total number of training samples held out for each sample, and this 326 | product should be about 20 percent of the total length of the training data. 
327 | corrmin : float in [0..1], default 0.2 328 | Purely for display purposes. After each alpha is tested for each bootstrap sample, the number of 329 | responses with correlation greater than this value will be printed. For long-running regressions this 330 | can give a rough sense of how well the model works before it's done. 331 | joined : None or list of array_like indices, default None 332 | If you want the STRFs for two (or more) responses to be directly comparable, you need to ensure that 333 | the regularization parameter that they use is the same. To do that, supply a list of the response sets 334 | that should use the same ridge parameter here. For example, if you have four responses, joined could 335 | be [np.array([0,1]), np.array([2,3])], in which case responses 0 and 1 will use the same ridge parameter 336 | (which will be the parameter that is best on average for those two), and likewise for responses 2 and 3. 337 | singcutoff : float, default 1e-10 338 | The first step in ridge regression is computing the singular value decomposition (SVD) of the 339 | stimulus Rstim. If Rstim is not full rank, some singular values will be approximately equal 340 | to zero and the corresponding singular vectors will be noise. These singular values/vectors 341 | should be removed both for speed (the fewer multiplications the better!) and accuracy. Any 342 | singular values less than singcutoff will be removed. 343 | normalpha : boolean, default False 344 | Whether ridge parameters (alphas) should be normalized by the largest singular value (LSV) 345 | norm of Rstim. Good for rigorously comparing models with different numbers of parameters. 346 | single_alpha : boolean, default False 347 | Whether to use a single alpha for all responses. Good for identification/decoding. 348 | use_corr : boolean, default True 349 | If True, this function will use correlation as its metric of model fit.
If False, this function 350 | will instead use variance explained (R-squared) as its metric of model fit. For ridge regression 351 | this can make a big difference -- highly regularized solutions will have very small norms and 352 | will thus explain very little variance while still leading to high correlations, as correlation 353 | is scale-free while R**2 is not. 354 | return_wt : boolean, default True 355 | If True, this function will compute and return the regression weights after finding the best 356 | alpha parameter for each voxel. However, for very large models this can lead to memory issues. 357 | If false, this function will _not_ compute weights, but will still compute prediction performance 358 | on the prediction dataset (Pstim, Presp). 359 | 360 | Returns 361 | ------- 362 | wt : array_like, shape (N, M) 363 | If [return_wt] is True, regression weights for N features and M responses. If [return_wt] is False, []. 364 | corrs : array_like, shape (M,) 365 | Validation set correlations. Predicted responses for the validation set are obtained using the regression 366 | weights: pred = np.dot(Pstim, wt), and then the correlation between each predicted response and each 367 | column in Presp is found. 368 | alphas : array_like, shape (M,) 369 | The regularization coefficient (alpha) selected for each voxel using bootstrap cross-validation. 370 | bootstrap_corrs : array_like, shape (A, M, B) 371 | Correlation between predicted and actual responses on randomly held out portions of the training set, 372 | for each of A alphas, M voxels, and B bootstrap samples. 373 | valinds : array_like, shape (TH, B) 374 | The indices of the training data that were used as "validation" for each bootstrap sample. 
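Example (a minimal sketch of the held-out selection step, with small made-up sizes): how [chunklen] and [nchunks] determine which training time points are held out in one bootstrap sample. The seed is set only to make the sketch repeatable.

```python
import itertools
import random

nresp, chunklen, nchunks = 20, 5, 2
random.seed(0)

allinds = range(nresp)
# Zipping the same iterator chunklen times yields contiguous, non-overlapping
# chunks; a trailing partial chunk would be dropped.
indchunks = list(zip(*[iter(allinds)] * chunklen))
random.shuffle(indchunks)

# The first nchunks shuffled chunks are held out; the rest are used for fitting.
heldinds = list(itertools.chain(*indchunks[:nchunks]))
notheldinds = sorted(set(allinds) - set(heldinds))
```

Holding out whole chunks rather than single time points keeps temporally adjacent samples together in either the fitting set or the held-out set.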
375 | """ 376 | nresp, nvox = Rresp.shape 377 | valinds = [] # Will hold the indices into the validation data for each bootstrap 378 | 379 | Rcmats = [] 380 | for bi in counter(range(nboots), countevery=1, total=nboots): 381 | logger.info("Selecting held-out test set..") 382 | allinds = range(nresp) 383 | indchunks = list(zip(*[iter(allinds)]*chunklen)) 384 | random.shuffle(indchunks) 385 | heldinds = list(itools.chain(*indchunks[:nchunks])) 386 | notheldinds = list(set(allinds)-set(heldinds)) 387 | valinds.append(heldinds) 388 | 389 | RRstim = Rstim[notheldinds,:] 390 | PRstim = Rstim[heldinds,:] 391 | RRresp = Rresp[notheldinds,:] 392 | PRresp = Rresp[heldinds,:] 393 | 394 | # Run ridge regression using this test set 395 | Rcmat = ridge_corr(RRstim, PRstim, RRresp, PRresp, alphas, 396 | corrmin=corrmin, singcutoff=singcutoff, 397 | normalpha=normalpha, use_corr=use_corr, 398 | logger=logger) 399 | 400 | Rcmats.append(Rcmat) 401 | 402 | # Find best alphas 403 | if nboots>0: 404 | allRcorrs = np.dstack(Rcmats) 405 | else: 406 | allRcorrs = None 407 | 408 | if not single_alpha: 409 | if nboots==0: 410 | raise ValueError("You must run at least one cross-validation step to assign " 411 | "different alphas to each response.") 412 | 413 | logger.info("Finding best alpha for each voxel..") 414 | if joined is None: 415 | # Find best alpha for each voxel 416 | meanbootcorrs = allRcorrs.mean(2) 417 | bestalphainds = np.argmax(meanbootcorrs, 0) 418 | valphas = alphas[bestalphainds] 419 | else: 420 | # Find best alpha for each group of voxels 421 | valphas = np.zeros((nvox,)) 422 | for jl in joined: 423 | # Mean across voxels in the set, then mean across bootstraps 424 | jcorrs = allRcorrs[:,jl,:].mean(1).mean(1) 425 | bestalpha = np.argmax(jcorrs) 426 | valphas[jl] = alphas[bestalpha] 427 | else: 428 | logger.info("Finding single best alpha..") 429 | if nboots==0: 430 | if len(alphas)==1: 431 | bestalphaind = 0 432 | bestalpha = alphas[0] 433 | else: 434 | raise 
ValueError("You must run at least one cross-validation step " 435 | "to choose best overall alpha, or only supply one " 436 | "possible alpha value.") 437 | else: 438 | meanbootcorr = allRcorrs.mean(2).mean(1) 439 | bestalphaind = np.argmax(meanbootcorr) 440 | bestalpha = alphas[bestalphaind] 441 | 442 | valphas = np.array([bestalpha]*nvox) 443 | logger.info("Best alpha = %0.3f"%bestalpha) 444 | 445 | if return_wt: 446 | # Find weights 447 | logger.info("Computing weights for each response using entire training set..") 448 | wt = ridge(Rstim, Rresp, valphas, singcutoff=singcutoff, normalpha=normalpha) 449 | 450 | # Predict responses on prediction set 451 | logger.info("Predicting responses for prediction set..") 452 | pred = np.dot(Pstim, wt) 453 | 454 | # Find prediction correlations 455 | nnpred = np.nan_to_num(pred) 456 | if use_corr: 457 | corrs = np.nan_to_num(np.array([np.corrcoef(Presp[:,ii], nnpred[:,ii].ravel())[0,1] 458 | for ii in range(Presp.shape[1])])) 459 | else: 460 | resvar = (Presp-pred).var(0) 461 | Rsqs = 1 - (resvar / Presp.var(0)) 462 | corrs = np.sqrt(np.abs(Rsqs)) * np.sign(Rsqs) 463 | 464 | return wt, corrs, valphas, allRcorrs, valinds 465 | else: 466 | # get correlations for prediction dataset directly 467 | corrs = ridge_corr_pred(Rstim, Pstim, Rresp, Presp, valphas, 468 | normalpha=normalpha, use_corr=use_corr, 469 | logger=logger, singcutoff=singcutoff) 470 | 471 | return [], corrs, valphas, allRcorrs, valinds 472 | -------------------------------------------------------------------------------- /ridge_utils/stimulus_utils.py: -------------------------------------------------------------------------------- 1 | from ridge_utils.textgrid import TextGrid 2 | import os 3 | import numpy as np 4 | from collections import defaultdict 5 | 6 | def load_grid(story, grid_dir): 7 | """Loads the TextGrid for the given [story] from the directory [grid_dir]. 
8 | The first file that starts with [story] will be loaded, so if there are 9 | multiple versions of a grid for a story, beware. 10 | """ 11 | gridfile = [os.path.join(grid_dir, gf) for gf in os.listdir(grid_dir) if gf.startswith(story)][0] 12 | return TextGrid(open(gridfile).read()) 13 | 14 | def load_grids_for_stories(stories, grid_dir): 15 | """Loads grids for the given [stories], puts them in a dictionary. 16 | """ 17 | return dict([(st, load_grid(st, grid_dir)) for st in stories]) 18 | 19 | def load_5tier_grids_for_stories(stories, rootdir): 20 | grids = dict() 21 | for story in stories: 22 | storydir = os.path.join(rootdir, [sd for sd in os.listdir(rootdir) if sd.startswith(story)][0]) 23 | storyfile = os.path.join(storydir, [sf for sf in os.listdir(storydir) if sf.endswith("TextGrid")][0]) 24 | grids[story] = TextGrid(open(storyfile).read()) 25 | return grids 26 | 27 | 28 | class TRFile(object): 29 | def __init__(self, trfilename, expectedtr=2.0045): 30 | """Loads data from [trfilename], should be output from stimulus presentation code. 31 | """ 32 | self.trtimes = [] 33 | self.soundstarttime = -1 34 | self.soundstoptime = -1 35 | self.otherlabels = [] 36 | self.expectedtr = expectedtr 37 | 38 | if trfilename is not None: 39 | self.load_from_file(trfilename) 40 | 41 | 42 | def load_from_file(self, trfilename): 43 | """Loads TR data from report with given [trfilename]. 
44 | """ 45 | ## Read the report file and populate the datastructure 46 | for ll in open(trfilename): 47 | timestr = ll.split()[0] 48 | label = " ".join(ll.split()[1:]) 49 | time = float(timestr) 50 | 51 | if label in ("init-trigger", "trigger") or label.startswith("first"): 52 | self.trtimes.append(time) 53 | 54 | elif label=="sound-start" or label.startswith("start") or label.startswith("START"): 55 | self.soundstarttime = time 56 | 57 | elif label=="sound-stop" or label.startswith("END"): 58 | self.soundstoptime = time 59 | 60 | else: 61 | self.otherlabels.append((time, label)) 62 | 63 | ## Fix weird TR times 64 | itrtimes = np.diff(self.trtimes) 65 | badtrtimes = np.nonzero(itrtimes>(itrtimes.mean()*1.5))[0] 66 | newtrs = [] 67 | for btr in badtrtimes: 68 | ## Insert new TR where it was missing.. 69 | newtrtime = self.trtimes[btr]+self.expectedtr 70 | newtrs.append((newtrtime,btr)) 71 | 72 | for ntr,btr in newtrs: 73 | self.trtimes.insert(btr+1, ntr) 74 | 75 | def simulate(self, ntrs): 76 | """Simulates [ntrs] TRs that occur at the expected TR. 77 | """ 78 | self.trtimes = list(np.arange(ntrs)*self.expectedtr) 79 | 80 | def get_reltriggertimes(self): 81 | """Returns the times of all trigger events relative to the sound. 82 | """ 83 | return np.array(self.trtimes)-self.soundstarttime 84 | 85 | @property 86 | def avgtr(self): 87 | """Returns the average TR for this run. 88 | """ 89 | return np.diff(self.trtimes).mean() 90 | 91 | def load_generic_trfiles(stories, root): 92 | """Loads a dictionary of generic TRFiles (i.e. not specifically from the session 93 | in which the data was collected.. this should be fine) for the given stories. 
94 | """ 95 | trdict = dict() 96 | 97 | for story in stories: 98 | try: 99 | trf = TRFile(os.path.join(root, "%s.report"%story)) 100 | trdict[story] = [trf] 101 | except Exception as e: 102 | print (e) 103 | 104 | return trdict 105 | 106 | -------------------------------------------------------------------------------- /ridge_utils/textgrid.py: -------------------------------------------------------------------------------- 1 | # Natural Language Toolkit: TextGrid analysis 2 | # 3 | # Copyright (C) 2001-2011 NLTK Project 4 | # Author: Margaret Mitchell 5 | # Steven Bird (revisions) 6 | # URL: 7 | # For license information, see LICENSE.TXT 8 | # 9 | 10 | """ 11 | Tools for reading TextGrid files, the format used by Praat. 12 | 13 | Module contents 14 | =============== 15 | 16 | The textgrid corpus reader provides 4 data items and 1 function 17 | for each textgrid file. For each tier in the file, the reader 18 | provides 10 data items and 2 functions. 19 | 20 | For the full textgrid file: 21 | 22 | - size 23 | The number of tiers in the file. 24 | 25 | - xmin 26 | First marked time of the file. 27 | 28 | - xmax 29 | Last marked time of the file. 30 | 31 | - t_time 32 | xmax - xmin. 33 | 34 | - text_type 35 | The style of TextGrid format: 36 | - ooTextFile: Organized by tier. 37 | - ChronTextFile: Organized by time. 38 | - OldooTextFile: Similar to ooTextFile. 39 | 40 | - to_chron() 41 | Convert given file to a ChronTextFile format. 42 | 43 | - to_oo() 44 | Convert given file to an ooTextFile format. 45 | 46 | For each tier: 47 | 48 | - text_type 49 | The style of TextGrid format, as above. 50 | 51 | - classid 52 | The style of transcription on this tier: 53 | - IntervalTier: Transcription is marked as intervals. 54 | - TextTier: Transcription is marked as single points. 55 | 56 | - nameid 57 | The name of the tier. 58 | 59 | - xmin 60 | First marked time of the tier. 61 | 62 | - xmax 63 | Last marked time of the tier. 
64 | 65 | - size 66 | Number of entries in the tier. 67 | 68 | - transcript 69 | The raw transcript for the tier. 70 | 71 | - simple_transcript 72 | The transcript formatted as a list of tuples: (time1, time2, utterance). 73 | 74 | - tier_info 75 | List of (classid, nameid, xmin, xmax, size, transcript). 76 | 77 | - min_max() 78 | A tuple of (xmin, xmax). 79 | 80 | - time(non_speech_marker) 81 | Returns the utterance time of a given tier. 82 | Excludes entries that begin with a non-speech marker. 83 | 84 | """ 85 | 86 | # needs more cleanup, subclassing, epydoc docstrings 87 | 88 | import sys 89 | import re 90 | 91 | TEXTTIER = "TextTier" 92 | INTERVALTIER = "IntervalTier" 93 | 94 | OOTEXTFILE = re.compile(r"""(?x) 95 | xmin\ =\ (.*)[\r\n]+ 96 | xmax\ =\ (.*)[\r\n]+ 97 | [\s\S]+?size\ =\ (.*)[\r\n]+ 98 | """) 99 | 100 | CHRONTEXTFILE = re.compile(r"""(?x) 101 | [\r\n]+(\S+)\ 102 | (\S+)\ +!\ Time\ domain.\ *[\r\n]+ 103 | (\S+)\ +!\ Number\ of\ tiers.\ *[\r\n]+" 104 | """) 105 | 106 | OLDOOTEXTFILE = re.compile(r"""(?x) 107 | [\r\n]+(\S+) 108 | [\r\n]+(\S+) 109 | [\r\n]+.+[\r\n]+(\S+) 110 | """) 111 | 112 | 113 | 114 | ################################################################# 115 | # TextGrid Class 116 | ################################################################# 117 | 118 | class TextGrid(object): 119 | """ 120 | Class to manipulate the TextGrid format used by Praat. 121 | Separates each tier within this file into its own Tier 122 | object. Each TextGrid object has 123 | a number of tiers (size), xmin, xmax, a text type to help 124 | with the different styles of TextGrid format, and tiers with their 125 | own attributes. 126 | """ 127 | 128 | def __init__(self, read_file): 129 | """ 130 | Takes open read file as input, initializes attributes 131 | of the TextGrid file. 132 | @type read_file: An open TextGrid file, mode "r". 133 | @param size: Number of tiers. 134 | @param xmin: xmin. 135 | @param xmax: xmax. 
136 | @param t_time: Total time of TextGrid file. 137 | @param text_type: TextGrid format. 138 | @type tiers: A list of tier objects. 139 | """ 140 | 141 | self.read_file = read_file 142 | self.size = 0 143 | self.xmin = 0 144 | self.xmax = 0 145 | self.t_time = 0 146 | self.text_type = self._check_type() 147 | self.tiers = self._find_tiers() 148 | 149 | def __iter__(self): 150 | for tier in self.tiers: 151 | yield tier 152 | 153 | def next(self): 154 | if self.idx == (self.size - 1): 155 | raise StopIteration 156 | self.idx += 1 157 | return self.tiers[self.idx] 158 | 159 | @staticmethod 160 | def load(file): 161 | """ 162 | @param file: a file in TextGrid format 163 | """ 164 | 165 | return TextGrid(open(file).read()) 166 | 167 | def _load_tiers(self, header): 168 | """ 169 | Iterates over each tier and grabs tier information. 170 | """ 171 | 172 | tiers = [] 173 | if self.text_type == "ChronTextFile": 174 | m = re.compile(header) 175 | tier_headers = m.findall(self.read_file) 176 | tier_re = " \d+.?\d* \d+.?\d*[\r\n]+\"[^\"]*\"" 177 | for i in range(0, self.size): 178 | tier_info = [tier_headers[i]] + \ 179 | re.findall(str(i + 1) + tier_re, self.read_file) 180 | tier_info = "\n".join(tier_info) 181 | tiers.append(Tier(tier_info, self.text_type, self.t_time)) 182 | return tiers 183 | 184 | tier_re = header + "[\s\S]+?(?=" + header + "|$$)" 185 | m = re.compile(tier_re) 186 | tier_iter = m.finditer(self.read_file) 187 | for iterator in tier_iter: 188 | (begin, end) = iterator.span() 189 | tier_info = self.read_file[begin:end] 190 | tiers.append(Tier(tier_info, self.text_type, self.t_time)) 191 | return tiers 192 | 193 | def _check_type(self): 194 | """ 195 | Figures out the TextGrid format. 
196 | """ 197 | 198 | m = re.match("(.*)[\r\n](.*)[\r\n](.*)[\r\n](.*)", self.read_file) 199 | try: 200 | type_id = m.group(1).strip() 201 | except AttributeError: 202 | raise TypeError("Cannot read file -- try TextGrid.load()") 203 | xmin = m.group(4) 204 | if type_id == "File type = \"ooTextFile\"": 205 | if "xmin" not in xmin: 206 | text_type = "OldooTextFile" 207 | else: 208 | text_type = "ooTextFile" 209 | elif type_id == "\"Praat chronological TextGrid text file\"": 210 | text_type = "ChronTextFile" 211 | else: 212 | raise TypeError("Unknown format '(%s)'" % type_id) 213 | return text_type 214 | 215 | def _find_tiers(self): 216 | """ 217 | Splits the textgrid file into substrings corresponding to tiers. 218 | """ 219 | 220 | if self.text_type == "ooTextFile": 221 | m = OOTEXTFILE 222 | header = " +item \[" 223 | elif self.text_type == "ChronTextFile": 224 | m = CHRONTEXTFILE 225 | header = "\"\S+\" \".*\" \d+\.?\d* \d+\.?\d*" 226 | elif self.text_type == "OldooTextFile": 227 | m = OLDOOTEXTFILE 228 | header = "\".*\"[\r\n]+\".*\"" 229 | 230 | file_info = m.findall(self.read_file)[0] 231 | self.xmin = float(file_info[0]) 232 | self.xmax = float(file_info[1]) 233 | self.t_time = self.xmax - self.xmin 234 | self.size = int(file_info[2]) 235 | tiers = self._load_tiers(header) 236 | return tiers 237 | 238 | def to_chron(self): 239 | """ 240 | @return: String in Chronological TextGrid file format. 241 | """ 242 | 243 | chron_file = "" 244 | chron_file += "\"Praat chronological TextGrid text file\"\n" 245 | chron_file += str(self.xmin) + " " + str(self.xmax) 246 | chron_file += " ! Time domain.\n" 247 | chron_file += str(self.size) + " ! 
Number of tiers.\n" 248 | for tier in self.tiers: 249 | idx = (self.tiers.index(tier)) + 1 250 | tier_header = "\"" + tier.classid + "\" \"" \ 251 | + tier.nameid + "\" " + str(tier.xmin) \ 252 | + " " + str(tier.xmax) 253 | chron_file += tier_header + "\n" 254 | transcript = tier.simple_transcript 255 | for (xmin, xmax, utt) in transcript: 256 | chron_file += str(idx) + " " + str(xmin) 257 | chron_file += " " + str(xmax) +"\n" 258 | chron_file += "\"" + utt + "\"\n" 259 | return chron_file 260 | 261 | def to_oo(self): 262 | """ 263 | @return: A string in OoTextGrid file format. 264 | """ 265 | 266 | oo_file = "" 267 | oo_file += "File type = \"ooTextFile\"\n" 268 | oo_file += "Object class = \"TextGrid\"\n\n" 269 | oo_file += "xmin = " + str(self.xmin) + "\n" 270 | oo_file += "xmax = " + str(self.xmax) + "\n" 271 | oo_file += "tiers? \n" 272 | oo_file += "size = " + str(self.size) + "\n" 273 | oo_file += "item []:\n" 274 | for i in range(len(self.tiers)): 275 | oo_file += "%4s%s [%s]" % ("", "item", i + 1) 276 | _curr_tier = self.tiers[i] 277 | for (x, y) in _curr_tier.header: 278 | oo_file += "%8s%s = \"%s\"" % ("", x, y) 279 | if _curr_tier.classid != TEXTTIER: 280 | for (xmin, xmax, text) in _curr_tier.simple_transcript: 281 | oo_file += "%12s%s = %s" % ("", "xmin", xmin) 282 | oo_file += "%12s%s = %s" % ("", "xmax", xmax) 283 | oo_file += "%12s%s = \"%s\"" % ("", "text", text) 284 | else: 285 | for (time, mark) in _curr_tier.simple_transcript: 286 | oo_file += "%12s%s = %s" % ("", "time", time) 287 | oo_file += "%12s%s = %s" % ("", "mark", mark) 288 | return oo_file 289 | 290 | 291 | ################################################################# 292 | # Tier Class 293 | ################################################################# 294 | 295 | class Tier(object): 296 | """ 297 | A container for each tier. 
298 | """ 299 | 300 | def __init__(self, tier, text_type, t_time): 301 | """ 302 | Initializes attributes of the tier: class, name, xmin, xmax 303 | size, transcript, total time. 304 | Utilizes text_type to guide how to parse the file. 305 | @type tier: a tier object; single item in the TextGrid list. 306 | @param text_type: TextGrid format 307 | @param t_time: Total time of TextGrid file. 308 | @param classid: Type of tier (point or interval). 309 | @param nameid: Name of tier. 310 | @param xmin: xmin of the tier. 311 | @param xmax: xmax of the tier. 312 | @param size: Number of entries in the tier 313 | @param transcript: The raw transcript for the tier. 314 | """ 315 | 316 | self.tier = tier 317 | self.text_type = text_type 318 | self.t_time = t_time 319 | self.classid = "" 320 | self.nameid = "" 321 | self.xmin = 0 322 | self.xmax = 0 323 | self.size = 0 324 | self.transcript = "" 325 | self.tier_info = "" 326 | self._make_info() 327 | self.simple_transcript = self.make_simple_transcript() 328 | if self.classid != TEXTTIER: 329 | self.mark_type = "intervals" 330 | else: 331 | self.mark_type = "points" 332 | self.header = [("class", self.classid), ("name", self.nameid), \ 333 | ("xmin", self.xmin), ("xmax", self.xmax), ("size", self.size)] 334 | 335 | def __iter__(self): 336 | return self 337 | 338 | def _make_info(self): 339 | """ 340 | Figures out most attributes of the tier object: 341 | class, name, xmin, xmax, transcript. 342 | """ 343 | 344 | trans = "([\S\s]*)" 345 | if self.text_type == "ChronTextFile": 346 | classid = "\"(.*)\" +" 347 | nameid = "\"(.*)\" +" 348 | xmin = "(\d+\.?\d*) +" 349 | xmax = "(\d+\.?\d*) *[\r\n]+" 350 | # No size values are given in the Chronological Text File format. 
351 | self.size = None 352 | size = "" 353 | elif self.text_type == "ooTextFile": 354 | classid = " +class = \"(.*)\" *[\r\n]+" 355 | nameid = " +name = \"(.*)\" *[\r\n]+" 356 | xmin = " +xmin = (\d+\.?\d*) *[\r\n]+" 357 | xmax = " +xmax = (\d+\.?\d*) *[\r\n]+" 358 | size = " +\S+: size = (\d+) *[\r\n]+" 359 | elif self.text_type == "OldooTextFile": 360 | classid = "\"(.*)\" *[\r\n]+" 361 | nameid = "\"(.*)\" *[\r\n]+" 362 | xmin = "(\d+\.?\d*) *[\r\n]+" 363 | xmax = "(\d+\.?\d*) *[\r\n]+" 364 | size = "(\d+) *[\r\n]+" 365 | m = re.compile(classid + nameid + xmin + xmax + size + trans) 366 | self.tier_info = m.findall(self.tier)[0] 367 | self.classid = self.tier_info[0] 368 | self.nameid = self.tier_info[1] 369 | self.xmin = float(self.tier_info[2]) 370 | self.xmax = float(self.tier_info[3]) 371 | if self.size != None: 372 | self.size = int(self.tier_info[4]) 373 | self.transcript = self.tier_info[-1] 374 | 375 | def make_simple_transcript(self): 376 | """ 377 | @return: Transcript of the tier, in form [(start_time end_time label)] 378 | """ 379 | 380 | if self.text_type == "ChronTextFile": 381 | trans_head = "" 382 | trans_xmin = " (\S+)" 383 | trans_xmax = " (\S+)[\r\n]+" 384 | trans_text = "\"([\S\s]*?)\"" 385 | elif self.text_type == "ooTextFile": 386 | trans_head = " +\S+ \[\d+\]: *[\r\n]+" 387 | trans_xmin = " +\S+ = (\S+) *[\r\n]+" 388 | trans_xmax = " +\S+ = (\S+) *[\r\n]+" 389 | trans_text = " +\S+ = \"([^\"]*?)\"" 390 | elif self.text_type == "OldooTextFile": 391 | trans_head = "" 392 | trans_xmin = "(.*)[\r\n]+" 393 | trans_xmax = "(.*)[\r\n]+" 394 | trans_text = "\"([\S\s]*?)\"" 395 | if self.classid == TEXTTIER: 396 | trans_xmin = "" 397 | trans_m = re.compile(trans_head + trans_xmin + trans_xmax + trans_text) 398 | self.simple_transcript = trans_m.findall(self.transcript) 399 | return self.simple_transcript 400 | 401 | def transcript(self): 402 | """ 403 | @return: Transcript of the tier, as it appears in the file. 
404 | """ 405 | 406 | return self.transcript 407 | 408 | def time(self, non_speech_char="."): 409 | """ 410 | @return: Utterance time of a given tier. 411 | Screens out entries that begin with a non-speech marker. 412 | """ 413 | 414 | total = 0.0 415 | if self.classid != TEXTTIER: 416 | for (time1, time2, utt) in self.simple_transcript: 417 | utt = utt.strip() 418 | if utt and not utt.startswith(non_speech_char): 419 | total += (float(time2) - float(time1)) 420 | return total 421 | 422 | def tier_name(self): 423 | """ 424 | @return: Tier name of a given tier. 425 | """ 426 | 427 | return self.nameid 428 | 429 | def classid(self): 430 | """ 431 | @return: Type of transcription on tier. 432 | """ 433 | 434 | return self.classid 435 | 436 | def min_max(self): 437 | """ 438 | @return: (xmin, xmax) tuple for a given tier. 439 | """ 440 | 441 | return (self.xmin, self.xmax) 442 | 443 | def __repr__(self): 444 | return "<%s \"%s\" (%.2f, %.2f) %.2f%%>" % (self.classid, self.nameid, self.xmin, self.xmax, 100*self.time()/self.t_time) 445 | 446 | def __str__(self): 447 | return self.__repr__() + "\n " + "\n ".join(" ".join(row) for row in self.simple_transcript) 448 | 449 | def demo_TextGrid(demo_data): 450 | print("** Demo of the TextGrid class. **") 451 | 452 | fid = TextGrid(demo_data) 453 | print("Tiers:", fid.size) 454 | 455 | for i, tier in enumerate(fid): 456 | print("\n***") 457 | print("Tier:", i + 1) 458 | print(tier) 459 | 460 | def demo(): 461 | # Each demo demonstrates different TextGrid formats. 462 | print("Format 1") 463 | demo_TextGrid(demo_data1) 464 | print("\nFormat 2") 465 | demo_TextGrid(demo_data2) 466 | print("\nFormat 3") 467 | demo_TextGrid(demo_data3) 468 | 469 | 470 | demo_data1 = """File type = "ooTextFile" 471 | Object class = "TextGrid" 472 | 473 | xmin = 0 474 | xmax = 2045.144149659864 475 | tiers? 
476 | size = 3 477 | item []: 478 | item [1]: 479 | class = "IntervalTier" 480 | name = "utterances" 481 | xmin = 0 482 | xmax = 2045.144149659864 483 | intervals: size = 5 484 | intervals [1]: 485 | xmin = 0 486 | xmax = 2041.4217474125382 487 | text = "" 488 | intervals [2]: 489 | xmin = 2041.4217474125382 490 | xmax = 2041.968276643991 491 | text = "this" 492 | intervals [3]: 493 | xmin = 2041.968276643991 494 | xmax = 2042.5281632653062 495 | text = "is" 496 | intervals [4]: 497 | xmin = 2042.5281632653062 498 | xmax = 2044.0487352585324 499 | text = "a" 500 | intervals [5]: 501 | xmin = 2044.0487352585324 502 | xmax = 2045.144149659864 503 | text = "demo" 504 | item [2]: 505 | class = "TextTier" 506 | name = "notes" 507 | xmin = 0 508 | xmax = 2045.144149659864 509 | points: size = 3 510 | points [1]: 511 | time = 2041.4217474125382 512 | mark = ".begin_demo" 513 | points [2]: 514 | time = 2043.8338291031832 515 | mark = "voice gets quiet here" 516 | points [3]: 517 | time = 2045.144149659864 518 | mark = ".end_demo" 519 | item [3]: 520 | class = "IntervalTier" 521 | name = "phones" 522 | xmin = 0 523 | xmax = 2045.144149659864 524 | intervals: size = 12 525 | intervals [1]: 526 | xmin = 0 527 | xmax = 2041.4217474125382 528 | text = "" 529 | intervals [2]: 530 | xmin = 2041.4217474125382 531 | xmax = 2041.5438290324326 532 | text = "D" 533 | intervals [3]: 534 | xmin = 2041.5438290324326 535 | xmax = 2041.7321032910372 536 | text = "I" 537 | intervals [4]: 538 | xmin = 2041.7321032910372 539 | xmax = 2041.968276643991 540 | text = "s" 541 | intervals [5]: 542 | xmin = 2041.968276643991 543 | xmax = 2042.232189031843 544 | text = "I" 545 | intervals [6]: 546 | xmin = 2042.232189031843 547 | xmax = 2042.5281632653062 548 | text = "z" 549 | intervals [7]: 550 | xmin = 2042.5281632653062 551 | xmax = 2044.0487352585324 552 | text = "eI" 553 | intervals [8]: 554 | xmin = 2044.0487352585324 555 | xmax = 2044.2487352585324 556 | text = "dc" 557 | intervals [9]: 558 
| xmin = 2044.2487352585324 559 | xmax = 2044.3102321849011 560 | text = "d" 561 | intervals [10]: 562 | xmin = 2044.3102321849011 563 | xmax = 2044.5748932104329 564 | text = "E" 565 | intervals [11]: 566 | xmin = 2044.5748932104329 567 | xmax = 2044.8329108578437 568 | text = "m" 569 | intervals [12]: 570 | xmin = 2044.8329108578437 571 | xmax = 2045.144149659864 572 | text = "oU" 573 | """ 574 | 575 | demo_data2 = """File type = "ooTextFile" 576 | Object class = "TextGrid" 577 | 578 | 0 579 | 2.8 580 | 581 | 2 582 | "IntervalTier" 583 | "utterances" 584 | 0 585 | 2.8 586 | 3 587 | 0 588 | 1.6229213249309031 589 | "" 590 | 1.6229213249309031 591 | 2.341428074708195 592 | "demo" 593 | 2.341428074708195 594 | 2.8 595 | "" 596 | "IntervalTier" 597 | "phones" 598 | 0 599 | 2.8 600 | 6 601 | 0 602 | 1.6229213249309031 603 | "" 604 | 1.6229213249309031 605 | 1.6428291382019483 606 | "dc" 607 | 1.6428291382019483 608 | 1.65372183721983721 609 | "d" 610 | 1.65372183721983721 611 | 1.94372874328943728 612 | "E" 613 | 1.94372874328943728 614 | 2.13821938291038210 615 | "m" 616 | 2.13821938291038210 617 | 2.341428074708195 618 | "oU" 619 | 2.341428074708195 620 | 2.8 621 | "" 622 | """ 623 | 624 | demo_data3 = """"Praat chronological TextGrid text file" 625 | 0 2.8 ! Time domain. 626 | 2 ! Number of tiers. 
627 | "IntervalTier" "utterances" 0 2.8 628 | "IntervalTier" "utterances" 0 2.8 629 | 1 0 1.6229213249309031 630 | "" 631 | 2 0 1.6229213249309031 632 | "" 633 | 2 1.6229213249309031 1.6428291382019483 634 | "dc" 635 | 2 1.6428291382019483 1.65372183721983721 636 | "d" 637 | 2 1.65372183721983721 1.94372874328943728 638 | "E" 639 | 2 1.94372874328943728 2.13821938291038210 640 | "m" 641 | 2 2.13821938291038210 2.341428074708195 642 | "oU" 643 | 1 1.6229213249309031 2.341428074708195 644 | "demo" 645 | 1 2.341428074708195 2.8 646 | "" 647 | 2 2.341428074708195 2.8 648 | "" 649 | """ 650 | 651 | if __name__ == "__main__": 652 | demo() 653 | 654 | -------------------------------------------------------------------------------- /ridge_utils/tokenization_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import sys 4 | import joblib 5 | import matplotlib.pyplot as plt 6 | import torch 7 | from ridge_utils.DataSequence import DataSequence 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | 10 | 11 | ### Warning, you are entering tokenization hell. 
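The tokenization helpers in this file mark word boundaries by interleaving a dummy token id (27 for OPT) into the token stream, then count those markers to pull out the tokens that belong to a window of words. The block below is a minimal, simplified sketch of that idea, not the repository's implementation. The names `annotate_word_boundaries` and `tokens_for_word_window` are hypothetical, the token ids are made up, and the sketch assumes the tokens are already grouped by word, whereas the real code has to infer that grouping from the tokenizer output.

```python
from typing import List

SENTINEL = 27  # assumed dummy id that never appears in the real token stream


def annotate_word_boundaries(word_token_ids: List[List[int]]) -> List[int]:
    """Flatten per-word token ids, with SENTINEL before each word and once at the end."""
    acc: List[int] = []
    for ids in word_token_ids:
        acc.append(SENTINEL)
        acc.extend(ids)
    acc.append(SENTINEL)
    return acc


def tokens_for_word_window(acc: List[int], first_word: int, last_word: int,
                           start_token: int = 2) -> List[int]:
    """Return start_token followed by the tokens of words first_word..last_word-1."""
    out = [start_token]
    word_idx = -1  # becomes 0 at the first sentinel
    for tok in acc:
        if tok == SENTINEL:
            word_idx += 1
            continue
        if first_word <= word_idx < last_word:
            out.append(tok)
    return out


# Three "words" with made-up token ids; the second and third form the context window.
acc = annotate_word_boundaries([[10, 11], [12], [13, 14, 15]])
window = tokens_for_word_window(acc, first_word=1, last_word=3)
```

Counting sentinels rather than tokens lets the context window be expressed in words even when a word splits into several tokens, which is why the helpers assert that the sentinel id does not occur in the tokenized text.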
12 | 13 | def compute_correct_tokens_opt(acc, acc_lookback, acc_offset, total_len): 14 | #print(acc) 15 | new_tokens = [] 16 | new_tokens.append(2) # Special OPT start token 17 | acc_count_all = 0 18 | first_word = max(0,acc_offset-acc_lookback) 19 | last_word = min(acc_offset+1, total_len) 20 | acc_start = 0 21 | while acc_start != first_word + 1: 22 | if acc[acc_count_all] == 27: 23 | acc_start += 1 24 | acc_count_all += 1 25 | else: 26 | acc_count_all += 1 27 | 28 | acc2 = acc[acc_count_all:] 29 | acc_count8 = 0 30 | acc_count_all = 0 31 | while acc_count8 != (last_word - first_word): 32 | if acc2[acc_count_all] == 27: 33 | acc_count8 += 1 34 | acc_count_all += 1 35 | else: 36 | new_tokens.append(acc2[acc_count_all]) 37 | acc_count_all += 1 38 | return new_tokens 39 | 40 | 41 | 42 | def generate_efficient_feat_dicts_opt(wordseqs, tokenizer, lookback1, lookback2): 43 | text_dict = {} 44 | text_dict2 = {} 45 | text_dict3 = {} 46 | for story in wordseqs.keys(): 47 | ds = wordseqs[story] 48 | newdata = [] 49 | total_len = len(ds.data) 50 | acc = [] 51 | acc8 = 0 52 | text = [" ".join(ds.data)] 53 | text_len = len(text[0]) 54 | inputs = tokenizer(text, return_tensors="pt") 55 | tokens = np.array(inputs['input_ids'][0]) 56 | assert (27 not in tokens) 57 | # Annotate word boundaries 58 | for ei,i in enumerate(tokens): 59 | # A lot of tokenization edge cases 60 | if (tokenizer.decode(torch.tensor([i]))[0] == ' ' and tokenizer.decode(torch.tensor([i])).strip() != '') or (tokenizer.decode(torch.tensor([i])) != '' and ei == 1): 61 | acc.append(27) 62 | acc.append(i) 63 | acc8 += 1 64 | elif (ei==1860 and i == 2836) or (ei==349 and i == 1437) or (ei==365 and i == 1437) or (ei==1914 and i == 1437) or (ei==1305 and i == 1437) or (ei==300 and i==1437 and story=='beneaththemushroomcloud') or (ei==202 and i == 3432) or (ei==1316 and i==4514) or (ei==656 and i==2550) or (ei==1358 and i==6355) or (ei==2160 and i==8629) or (i==24929 and ei != 2): 65 | acc.append(27) 66 | 
acc.append(i) 67 | acc8 += 1 68 | else: 69 | acc.append(i) 70 | acc.append(27) 71 | #print(acc) 72 | lookback1 = 256 73 | lookback2 = 512 74 | acc_lookback = 0 75 | misc_offset = 0 76 | new_tokens = [2] 77 | #print(tokenizer.decode(new_tokens)) 78 | for i, w in enumerate(ds.data): 79 | if w.strip() != '' and w != "'s": 80 | if acc_lookback < lookback1: 81 | new_tokens = compute_correct_tokens_opt(acc, acc_lookback, i + misc_offset, total_len) 82 | #print(tokenizer.decode(torch.tensor(new_tokens))) 83 | text_dict[(story, i)] = new_tokens 84 | text_dict2[(story, i)] = False 85 | text_dict3[tuple(new_tokens)] = False 86 | elif lookback2 > acc_lookback and acc_lookback >= lookback1: 87 | new_tokens = compute_correct_tokens_opt(acc, acc_lookback, i + misc_offset, total_len) 88 | #print(tokenizer.decode(torch.tensor(new_tokens))) 89 | text_dict[(story, i)] = new_tokens 90 | text_dict2[(story, i)] = False 91 | text_dict3[tuple(new_tokens)] = False 92 | elif acc_lookback == lookback2: 93 | new_tokens = compute_correct_tokens_opt(acc, acc_lookback, i + misc_offset, total_len) 94 | #print(tokenizer.decode(torch.tensor(new_tokens))) 95 | acc_lookback = lookback1 96 | text_dict[(story, i)] = new_tokens 97 | text_dict2[(story, i)] = True 98 | text_dict3[tuple(new_tokens)] = False 99 | else: 100 | print("WARNING, LOOKBACK EDGE CASE 1", acc_lookback, "\n") 101 | assert False 102 | #print(max(0, i-acc_lookback), min(i+1, total_len)) 103 | #text = [" ".join(ds.data[max(0,i-acc_lookback):min(i+1,total_len)])][0] 104 | #print(text) 105 | text_dict[(story, i)] = new_tokens 106 | text_dict2[(story, i)] = False 107 | text_dict3[tuple(new_tokens)] = False 108 | else: 109 | #hidden_states = np.zeros((1024,)) 110 | text_dict[(story, i)] = new_tokens 111 | text_dict2[(story, i)] = True 112 | text_dict3[tuple(new_tokens)] = False 113 | acc_lookback += 1 114 | misc_offset -= 1 115 | continue 116 | acc_lookback += 1 117 | if i == total_len - 1: 118 | text_dict2[(story, i)] = True 119 | return 
text_dict, text_dict2, text_dict3 120 | 121 | 122 | def convert_to_feature_mats_opt(wordseqs, tokenizer, lookback1, lookback2, text_dict3): 123 | text_dict = {} 124 | text_dict2 = {} 125 | featureseqs = {} 126 | for story in wordseqs.keys(): 127 | ds = wordseqs[story] 128 | newdata = [] 129 | total_len = len(ds.data) 130 | acc = [] 131 | acc8 = 0 132 | text = [" ".join(ds.data)] 133 | text_len = len(text[0]) 134 | inputs = tokenizer(text, return_tensors="pt") 135 | tokens = np.array(inputs['input_ids'][0]) 136 | assert (27 not in tokens) 137 | # Annotate word boundaries 138 | for ei,i in enumerate(tokens): 139 | # A lot of tokenization edge cases 140 | if (tokenizer.decode(torch.tensor([i]))[0] == ' ' and tokenizer.decode(torch.tensor([i])).strip() != '') or (tokenizer.decode(torch.tensor([i])) != '' and ei == 1): 141 | acc.append(27) 142 | acc.append(i) 143 | acc8 += 1 144 | elif (ei==1860 and i == 2836) or (ei==349 and i == 1437) or (ei==365 and i == 1437) or (ei==1914 and i == 1437) or (ei==1305 and i == 1437) or (ei==300 and i==1437 and story=='beneaththemushroomcloud') or (ei==202 and i == 3432) or (ei==1316 and i==4514) or (ei==656 and i==2550) or (ei==1358 and i==6355) or (ei==2160 and i==8629) or (i==24929 and ei != 2): 145 | acc.append(27) 146 | acc.append(i) 147 | acc8 += 1 148 | else: 149 | acc.append(i) 150 | acc.append(27) 151 | lookback1 = 256 152 | lookback2 = 512 153 | acc_lookback = 0 154 | misc_offset = 0 155 | new_tokens = [2] 156 | for i, w in enumerate(ds.data): 157 | if w.strip() != '' and w != "'s": 158 | if acc_lookback < lookback1: 159 | new_tokens = compute_correct_tokens_opt(acc, acc_lookback, i + misc_offset, total_len) 160 | text_dict[(story, i)] = new_tokens 161 | text_dict2[(story, i)] = False 162 | newdata.append(text_dict3[tuple(new_tokens)]) 163 | elif lookback2 > acc_lookback and acc_lookback >= lookback1: 164 | new_tokens = compute_correct_tokens_opt(acc, acc_lookback, i + misc_offset, total_len) 165 | text_dict[(story, i)] = 
new_tokens 166 | text_dict2[(story, i)] = False 167 | newdata.append(text_dict3[tuple(new_tokens)]) 168 | elif acc_lookback == lookback2: 169 | new_tokens = compute_correct_tokens_opt(acc, acc_lookback, i + misc_offset, total_len) 170 | acc_lookback = lookback1 171 | text_dict[(story, i)] = new_tokens 172 | text_dict2[(story, i)] = True 173 | newdata.append(text_dict3[tuple(new_tokens)]) 174 | else: 175 | print("WARNING, LOOKBACK EDGE CASE 1", acc_lookback, "\n") 176 | assert False 177 | text_dict[(story, i)] = new_tokens 178 | text_dict2[(story, i)] = False 179 | newdata.append(text_dict3[tuple(new_tokens)]) 180 | else: 181 | text_dict[(story, i)] = new_tokens 182 | text_dict2[(story, i)] = True 183 | newdata.append(text_dict3[tuple(new_tokens)]) 184 | acc_lookback += 1 185 | misc_offset -= 1 186 | continue 187 | acc_lookback += 1 188 | if i == total_len - 1: 189 | text_dict2[(story, i)] = True 190 | featureseqs[story] = DataSequence(np.array(newdata), ds.split_inds, ds.data_times, ds.tr_times) 191 | downsampled_featureseqs = {} 192 | for story in featureseqs: 193 | downsampled_featureseqs[story] = featureseqs[story].chunksums('lanczos', window=3) 194 | return downsampled_featureseqs 195 | 196 | 197 | def compute_correct_tokens_llama(acc, acc_lookback, acc_offset, total_len): 198 | new_tokens = [1] 199 | acc_count_all = 0 200 | first_word = max(0,acc_offset-acc_lookback) 201 | last_word = min(acc_offset+1, total_len) 202 | acc_start = 0 203 | while acc_start != first_word + 1: 204 | if acc[acc_count_all] == 29947: 205 | acc_start += 1 206 | acc_count_all += 1 207 | else: 208 | acc_count_all += 1 209 | acc2 = acc[acc_count_all:] 210 | acc_count8 = 0 211 | acc_count_all = 0 212 | while acc_count8 != (last_word - first_word): 213 | if acc2[acc_count_all] == 29947: 214 | acc_count8 += 1 215 | acc_count_all += 1 216 | else: 217 | new_tokens.append(acc2[acc_count_all]) 218 | acc_count_all += 1 219 | return new_tokens 220 | 221 | def 
generate_efficient_feat_dicts_llama(wordseqs, tokenizer, lookback1, lookback2): 222 | text_dict = {} 223 | text_dict2 = {} 224 | text_dict3 = {} 225 | for es, story in enumerate(wordseqs.keys()): 226 | #print(story) 227 | ds = wordseqs[story] 228 | total_len = len(ds.data) 229 | text = [" ".join(ds.data)] 230 | inputs = tokenizer(text, return_tensors="pt") 231 | tokens = np.array(inputs['input_ids'][0]) 232 | assert (29947 not in tokens) # Use a dummy token '8' for marking word cutoffs 233 | acc = [1] # Contexts should start with special START token 234 | acc8 = 0 235 | acc_words = 0 236 | for ei,i in enumerate(tokens): 237 | if tokenizer.convert_ids_to_tokens(torch.tensor([i]))[0][0] == '▁' and (tokenizer.decode(torch.tensor([i])).strip() != ''): 238 | acc.append(29947) 239 | acc.append(i) 240 | acc8 += 1 241 | elif ei != (len(tokens) - 1): 242 | if (i == 29871) and (tokenizer.convert_ids_to_tokens(torch.tensor([tokens[ei+1]]))[0][0] != '▁'): 243 | acc.append(29947) 244 | acc.append(i) 245 | acc8 += 1 246 | else: 247 | acc.append(i) 248 | else: 249 | acc.append(i) 250 | decoded = tokenizer.decode(torch.tensor(acc)) 251 | acc_words = 0 252 | for i in ds.data: 253 | if i.strip() != '': 254 | acc_words += 1 255 | #print(acc8, acc_words, story, es) 256 | assert acc8 == acc_words # Number of annotations should equal number of words 257 | acc.append(29947) 258 | acc_lookback = 0 259 | misc_offset = 0 260 | new_tokens = [1] 261 | for i, w in enumerate(ds.data): 262 | if w.strip() != '' and w != "'s": 263 | if acc_lookback < lookback1 or (lookback2 > acc_lookback and acc_lookback >= lookback1): 264 | new_tokens = compute_correct_tokens_llama(acc, acc_lookback, i + misc_offset, total_len) 265 | text_dict[(story, i)] = new_tokens 266 | text_dict2[(story, i)] = False 267 | text_dict3[tuple(new_tokens)] = False 268 | elif acc_lookback == lookback2: 269 | new_tokens = compute_correct_tokens_llama(acc, acc_lookback, i + misc_offset, total_len) 270 | acc_lookback = lookback1 271 
| text_dict[(story, i)] = new_tokens 272 | text_dict2[(story, i)] = True 273 | text_dict3[tuple(new_tokens)] = False 274 | else: 275 | assert False 276 | else: 277 | text_dict[(story, i)] = new_tokens 278 | text_dict2[(story, i)] = True 279 | text_dict3[tuple(new_tokens)] = False 280 | acc_lookback += 1 281 | misc_offset -= 1 282 | continue 283 | acc_lookback += 1 284 | if i == total_len - 1: 285 | text_dict2[(story, i)] = True 286 | return text_dict, text_dict2, text_dict3 287 | 288 | def convert_to_feature_mats_llama(wordseqs, tokenizer, lookback1, lookback2, text_dict3): 289 | text_dict = {} 290 | text_dict2 = {} 291 | featureseqs = {} 292 | for es, story in enumerate(wordseqs.keys()): 293 | #print(story) 294 | ds = wordseqs[story] 295 | newdata = [] 296 | total_len = len(ds.data) 297 | text = [" ".join(ds.data)] 298 | inputs = tokenizer(text, return_tensors="pt") 299 | tokens = np.array(inputs['input_ids'][0]) 300 | assert (29947 not in tokens) # Use a dummy token '8' for marking word cutoffs 301 | acc = [1] # Contexts should start with special START token 302 | acc8 = 0 303 | acc_words = 0 304 | for ei,i in enumerate(tokens): 305 | if tokenizer.convert_ids_to_tokens(torch.tensor([i]))[0][0] == '▁' and (tokenizer.decode(torch.tensor([i])).strip() != ''): 306 | acc.append(29947) 307 | acc.append(i) 308 | acc8 += 1 309 | elif ei != (len(tokens) - 1): 310 | if (i == 29871) and (tokenizer.convert_ids_to_tokens(torch.tensor([tokens[ei+1]]))[0][0] != '▁'): 311 | acc.append(29947) 312 | acc.append(i) 313 | acc8 += 1 314 | else: 315 | acc.append(i) 316 | else: 317 | acc.append(i) 318 | decoded = tokenizer.decode(torch.tensor(acc)) 319 | acc_words = 0 320 | for i in ds.data: 321 | if i.strip() != '': 322 | acc_words += 1 323 | #print(acc8, acc_words, story, es) 324 | assert acc8 == acc_words # Number of annotations should equal number of words 325 | acc.append(29947) 326 | acc_lookback = 0 327 | misc_offset = 0 328 | new_tokens = [1] 329 | for i, w in 
enumerate(ds.data): 330 | if w.strip() != '' and w != "'s": 331 | if acc_lookback < lookback1 or (lookback2 > acc_lookback and acc_lookback >= lookback1): 332 | new_tokens = compute_correct_tokens_llama(acc, acc_lookback, i + misc_offset, total_len) 333 | text_dict[(story, i)] = new_tokens 334 | text_dict2[(story, i)] = False 335 | newdata.append(text_dict3[tuple(new_tokens)]) 336 | elif acc_lookback == lookback2: 337 | new_tokens = compute_correct_tokens_llama(acc, acc_lookback, i + misc_offset, total_len) 338 | acc_lookback = lookback1 339 | text_dict[(story, i)] = new_tokens 340 | text_dict2[(story, i)] = True 341 | newdata.append(text_dict3[tuple(new_tokens)]) 342 | else: 343 | assert False 344 | else: 345 | text_dict[(story, i)] = new_tokens 346 | text_dict2[(story, i)] = True 347 | newdata.append(text_dict3[tuple(new_tokens)]) 348 | acc_lookback += 1 349 | misc_offset -= 1 350 | continue 351 | acc_lookback += 1 352 | if i == total_len - 1: 353 | text_dict2[(story, i)] = True 354 | featureseqs[story] = DataSequence(np.array(newdata), ds.split_inds, ds.data_times, ds.tr_times) 355 | downsampled_featureseqs = {} 356 | for story in featureseqs: 357 | downsampled_featureseqs[story] = featureseqs[story].chunksums('lanczos', window=3) 358 | return downsampled_featureseqs 359 | -------------------------------------------------------------------------------- /ridge_utils/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tables 3 | #from matplotlib.pyplot import figure, show 4 | import scipy.linalg 5 | 6 | def make_delayed(stim, delays, circpad=False): 7 | """Creates non-interpolated concatenated delayed versions of [stim] with the given [delays] 8 | (in samples). 9 | 10 | If [circpad], instead of being padded with zeros, [stim] will be circularly shifted. 
11 | """ 12 | nt,ndim = stim.shape 13 | dstims = [] 14 | for di,d in enumerate(delays): 15 | dstim = np.zeros((nt, ndim)) 16 | if d<0: ## negative delay 17 | dstim[:d,:] = stim[-d:,:] 18 | if circpad: 19 | dstim[d:,:] = stim[:-d,:] 20 | elif d>0: 21 | dstim[d:,:] = stim[:-d,:] 22 | if circpad: 23 | dstim[:d,:] = stim[-d:,:] 24 | else: ## d==0 25 | dstim = stim.copy() 26 | dstims.append(dstim) 27 | return np.hstack(dstims) 28 | 29 | def best_corr_vec(wvec, vocab, SU, n=10): 30 | """Returns the [n] words from [vocab] most similar to the given [wvec], where each word is represented 31 | as a row in [SU]. Similarity is computed using correlation.""" 32 | wvec = wvec - np.mean(wvec) 33 | nwords = len(vocab) 34 | corrs = np.nan_to_num([np.corrcoef(wvec, SU[wi,:]-np.mean(SU[wi,:]))[1,0] for wi in range(nwords-1)]) 35 | scorrs = np.argsort(corrs) 36 | words = list(reversed([(corrs[i],vocab[i]) for i in scorrs[-n:]])) 37 | return words 38 | 39 | def get_word_prob(): 40 | """Returns the probabilities of all the words in the mechanical turk video labels. 41 | """ 42 | import constants as c 43 | import pickle ## cPickle exists only in Python 2 44 | data = pickle.load(open(c.datafile, "rb")) # Read in the words from the labels 45 | wordcount = dict() 46 | totalcount = 0 47 | for label in data: 48 | for word in label: 49 | totalcount += 1 50 | if word in wordcount: 51 | wordcount[word] += 1 52 | else: 53 | wordcount[word] = 1 54 | 55 | wordprob = dict([(word, float(wc)/totalcount) for word, wc in wordcount.items()]) 56 | return wordprob 57 | 58 | def best_prob_vec(wvec, vocab, space, wordprobs): 59 | """Orders the words by correlation with the given [wvec], but also weights the correlations by the prior 60 | probability of the word appearing in the mechanical turk video labels. 
61 | """ 62 | words = best_corr_vec(wvec, vocab, space, n=len(vocab)) ## get correlations for all words 63 | ## weight correlations by the prior probability of the word in the labels 64 | weightwords = [] 65 | for wcorr,word in words: 66 | if word in wordprobs: 67 | weightwords.append((wordprobs[word]*wcorr, word)) 68 | 69 | return sorted(weightwords, key=lambda ww: ww[0]) 70 | 71 | def find_best_words(vectors, vocab, wordspace, actual, display=True, num=15): 72 | cwords = [] 73 | for si in range(len(vectors)): 74 | cw = best_corr_vec(vectors[si], vocab, wordspace, n=num) 75 | cwords.append(cw) 76 | if display: 77 | print("Closest words to scene %d:" % si) 78 | print([b[1] for b in cw]) 79 | print("Actual words:") 80 | print(actual[si]) 81 | print("") 82 | return cwords 83 | 84 | def find_best_stims_for_word(wordvector, decstims, n): 85 | """Returns a list of the indexes of the [n] stimuli in [decstims] (should be decoded stimuli) 86 | that lie closest to the vector [wordvector], which should be taken from the same space as the 87 | stimuli. 88 | """ 89 | scorrs = np.array([np.corrcoef(wordvector, ds)[0,1] for ds in decstims]) 90 | scorrs[np.isnan(scorrs)] = -1 91 | return np.argsort(scorrs)[-n:][::-1] 92 | 93 | def princomp(x, use_dgesvd=False): 94 | """Does principal components analysis on [x]. 95 | Returns coefficients, scores and latent variable values. 96 | Translated from MATLAB princomp function. Unlike the matlab princomp function, however, the 97 | rows of the returned value 'coeff' are the principal components, not the columns. 
98 | """ 99 | 100 | n,p = x.shape 101 | #cx = x-np.tile(x.mean(0), (n,1)) ## column-centered x 102 | cx = x-x.mean(0) 103 | r = np.min([n-1,p]) ## maximum possible rank of cx 104 | 105 | if use_dgesvd: 106 | from svd_dgesvd import svd_dgesvd 107 | U,sigma,coeff = svd_dgesvd(cx, full_matrices=False) 108 | else: 109 | U,sigma,coeff = np.linalg.svd(cx, full_matrices=False) 110 | 111 | sigma = np.diag(sigma) 112 | score = np.dot(cx, coeff.T) 113 | sigma = sigma/np.sqrt(n-1) 114 | 115 | latent = sigma**2 116 | 117 | return coeff, score, latent 118 | 119 | def eigprincomp(x, npcs=None, norm=False, weights=None, index=0): 120 | """Does principal components analysis on [x]. 121 | Returns coefficients (eigenvectors) and eigenvalues. 122 | If given, only the [npcs] greatest eigenvectors/values will be returned. 123 | If given, the covariance matrix will be computed using [weights] on the samples. 124 | """ 125 | n,p = x.shape 126 | #cx = x-np.tile(x.mean(0), (n,1)) ## column-centered x 127 | cx = x-x.mean(index) 128 | r = np.min([n-1,p]) ## maximum possible rank of cx 129 | 130 | xcov = np.cov(cx.T) 131 | if norm: 132 | xcov /= n 133 | 134 | if npcs is not None: 135 | latent,coeff = scipy.linalg.eigh(xcov, subset_by_index=(p-npcs,p-1)) ## replaces the deprecated eigvals= keyword 136 | else: 137 | latent,coeff = np.linalg.eigh(xcov) 138 | 139 | ## Transpose coeff, reverse its rows 140 | return coeff.T[::-1], latent[::-1] 141 | 142 | def weighted_cov(x, weights=None): 143 | """If given [weights], the covariance will be computed using those weights on the samples. 144 | Otherwise the simple covariance will be returned. 
145 | """ 146 | if weights is None: 147 | return np.cov(x) 148 | else: 149 | w = weights/weights.sum() ## Normalize the weights 150 | dmx = (x.T-(w*x).sum(1)).T ## Subtract the WEIGHTED mean 151 | wfact = 1/(1-(w**2).sum()) ## Compute the weighting factor 152 | return wfact*np.dot(w*dmx, dmx.T.conj()) ## Take the weighted inner product 153 | 154 | def test_weighted_cov(): 155 | """Runs a test on the weighted_cov function, creating a dataset for which the covariance is known 156 | for two different populations, and weights are used to reproduce the individual covariances. 157 | """ 158 | T = 1000 ## number of time points 159 | N = 100 ## A signals 160 | M = 100 ## B signals 161 | snr = 5 ## signal to noise ratio 162 | 163 | ## Create the two datasets 164 | siga = np.random.rand(T) 165 | noisea = np.random.rand(T, N) 166 | respa = (noisea.T+snr*siga).T 167 | 168 | sigb = np.random.rand(T) 169 | noiseb = np.random.rand(T, M) 170 | respb = (noiseb.T+snr*sigb).T 171 | 172 | ## Compute self-covariance matrixes 173 | cova = np.cov(respa) 174 | covb = np.cov(respb) 175 | 176 | ## Compute the full covariance matrix 177 | allresp = np.hstack([respa, respb]) 178 | fullcov = np.cov(allresp) 179 | 180 | ## Make weights that will recover individual covariances 181 | wta = np.ones([N+M,]) 182 | wta[N:] = 0 183 | 184 | wtb = np.ones([N+M,]) 185 | wtb[:N] = 0 186 | 187 | recova = weighted_cov(allresp, wta) 188 | recovb = weighted_cov(allresp, wtb) 189 | 190 | return locals() 191 | 192 | def fixPCs(orig, new): 193 | """Finds and fixes sign-flips in PCs by finding the coefficient with the greatest 194 | magnitude in the [orig] PCs, then negating the [new] PCs if that coefficient has 195 | a different sign. 
196 | """ 197 | flipped = [] 198 | for o,n in zip(orig, new): 199 | maxind = np.abs(o).argmax() 200 | if o[maxind]*n[maxind]>0: 201 | ## Same sign, no need to flip 202 | flipped.append(n) 203 | else: 204 | ## Different sign, flip 205 | flipped.append(-n) 206 | 207 | return np.vstack(flipped) 208 | 209 | 210 | def plot_model_comparison(corrs1, corrs2, name1, name2, thresh=0.35): 211 | fig = figure(figsize=(8,8)) 212 | ax = fig.add_subplot(1,1,1) 213 | 214 | good1 = corrs1>thresh 215 | good2 = corrs2>thresh 216 | better1 = corrs1>corrs2 217 | #both = np.logical_and(good1, good2) 218 | neither = np.logical_not(np.logical_or(good1, good2)) 219 | only1 = np.logical_and(good1, better1) 220 | only2 = np.logical_and(good2, np.logical_not(better1)) 221 | 222 | ptalpha = 0.3 223 | ax.plot(corrs1[neither], corrs2[neither], 'ko', alpha=ptalpha) 224 | #ax.plot(corrs1[both], corrs2[both], 'go', alpha=ptalpha) 225 | ax.plot(corrs1[only1], corrs2[only1], 'ro', alpha=ptalpha) 226 | ax.plot(corrs1[only2], corrs2[only2], 'bo', alpha=ptalpha) 227 | 228 | lims = [-0.5, 1.0] 229 | 230 | ax.plot([thresh, thresh], [lims[0], thresh], 'r-') 231 | ax.plot([lims[0], thresh], [thresh,thresh], 'b-') 232 | 233 | ax.text(lims[0]+0.05, thresh, "$n=%d$"%np.sum(good2), horizontalalignment="left", verticalalignment="bottom") 234 | ax.text(thresh, lims[0]+0.05, "$n=%d$"%np.sum(good1), horizontalalignment="left", verticalalignment="bottom") 235 | 236 | ax.plot(lims, lims, '-', color="gray") 237 | ax.set_xlim(lims) 238 | ax.set_ylim(lims) 239 | ax.set_xlabel(name1) 240 | ax.set_ylabel(name2) 241 | 242 | show() 243 | return fig 244 | 245 | import matplotlib.colors 246 | bwr = matplotlib.colors.LinearSegmentedColormap.from_list("bwr", ((0.0, 0.0, 1.0), (1.0, 1.0, 1.0), (1.0, 0.0, 0.0))) 247 | bkr = matplotlib.colors.LinearSegmentedColormap.from_list("bkr", ((0.0, 0.0, 1.0), (0.0, 0.0, 0.0), (1.0, 0.0, 0.0))) 248 | bgr = matplotlib.colors.LinearSegmentedColormap.from_list("bgr", ((0.0, 0.0, 1.0), (0.5, 
0.5, 0.5), (1.0, 0.0, 0.0))) 249 | 250 | def plot_model_comparison2(corrFile1, corrFile2, name1, name2, thresh=0.35): 251 | fig = figure(figsize=(9,10)) 252 | #ax = fig.add_subplot(3,1,[1,2], aspect="equal") 253 | ax = fig.add_axes([0.25, 0.4, 0.6, 0.5], aspect="equal") 254 | 255 | corrs1 = tables.open_file(corrFile1).root.semcorr.read() 256 | corrs2 = tables.open_file(corrFile2).root.semcorr.read() 257 | maxcorr = np.clip(np.vstack([corrs1, corrs2]).max(0), 0, thresh)/thresh 258 | corrdiff = (corrs1-corrs2) + 0.5 259 | colors = (bgr(corrdiff).T*maxcorr).T 260 | colors[:,3] = 1.0 ## Don't scale alpha 261 | 262 | ptalpha = 0.8 263 | ax.scatter(corrs1, corrs2, s=10, c=colors, alpha=ptalpha, edgecolors="none") 264 | lims = [-0.5, 1.0] 265 | 266 | ax.plot([thresh, thresh], [lims[0], thresh], color="gray") 267 | ax.plot([lims[0], thresh], [thresh,thresh], color="gray") 268 | 269 | good1 = corrs1>thresh 270 | good2 = corrs2>thresh 271 | ax.text(lims[0]+0.05, thresh, "$n=%d$"%np.sum(good2), horizontalalignment="left", verticalalignment="bottom") 272 | ax.text(thresh, lims[0]+0.05, "$n=%d$"%np.sum(good1), horizontalalignment="left", verticalalignment="bottom") 273 | 274 | ax.plot(lims, lims, '-', color="gray") 275 | ax.set_xlim(lims) 276 | ax.set_ylim(lims) 277 | ax.set_xlabel(name1+" model") 278 | ax.set_ylabel(name2+" model") 279 | 280 | fig.canvas.draw() 281 | show() 282 | ## Add over-under comparison 283 | #ax_left = ax.get_window_extent()._bbox.x0 284 | #ax_right = ax.get_window_extent()._bbox.x1 285 | #ax_width = ax_right-ax_left 286 | #print(ax_left, ax_right 287 | #ax2 = fig.add_axes([ax_left, 0.1, ax_width, 0.2]) 288 | ax2 = fig.add_axes([0.25, 0.1, 0.6, 0.25])#, sharex=ax) 289 | #ax2 = fig.add_subplot(3, 1, 3) 290 | #plot_model_overunder_comparison(corrs1, corrs2, name1, name2, thresh=thresh, ax=ax2) 291 | plot_model_histogram_comparison(corrs1, corrs2, name1, name2, thresh=thresh, ax=ax2) 292 | 293 | fig.suptitle("Model comparison: %s vs. 
%s"%(name1, name2)) 294 | show() 295 | return fig 296 | 297 | 298 | def plot_model_overunder_comparison(corrs1, corrs2, name1, name2, thresh=0.35, ax=None): 299 | """Plots over-under difference between two models. 300 | """ 301 | if ax is None: 302 | fig = figure(figsize=(8,8)) 303 | ax = fig.add_subplot(1,1,1) 304 | 305 | maxcorr = max(corrs1.max(), corrs2.max()) 306 | vals = np.linspace(0, maxcorr, 500) 307 | overunder = lambda c: np.array([np.sum(c>v)-np.sum(c<-v) for v in vals]) 308 | 309 | ou1 = overunder(corrs1) 310 | ou2 = overunder(corrs2) 311 | 312 | oud = ou2-ou1 313 | 314 | ax.fill_between(vals, 0, np.clip(oud, 0, 1e9), facecolor="blue") 315 | ax.fill_between(vals, 0, np.clip(oud, -1e9, 0), facecolor="red") 316 | 317 | yl = np.max(np.abs(np.array(ax.get_ylim()))) 318 | ax.plot([thresh, thresh], [-yl, yl], '-', color="gray") 319 | ax.set_ylim(-yl, yl) 320 | ax.set_xlim(0, maxcorr) 321 | ax.set_xlabel("Voxel correlation") 322 | ax.set_ylabel("%s better %s better"%(name1, name2)) 323 | 324 | show() 325 | return ax 326 | 327 | def plot_model_histogram_comparison(corrs1, corrs2, name1, name2, thresh=0.35, ax=None): 328 | """Plots over-under difference between two models. 
329 | """ 330 | if ax is None: 331 | fig = figure(figsize=(8,8)) 332 | ax = fig.add_subplot(1,1,1) 333 | 334 | maxcorr = max(corrs1.max(), corrs2.max()) 335 | nbins = 100 336 | hist1 = np.histogram(corrs1, nbins, range=(-1,1)) 337 | hist2 = np.histogram(corrs2, nbins, range=(-1,1)) 338 | 339 | ouhist1 = hist1[0][nbins//2:]-hist1[0][:nbins//2][::-1] ## integer division: slice indices must be ints 340 | ouhist2 = hist2[0][nbins//2:]-hist2[0][:nbins//2][::-1] 341 | 342 | oud = ouhist2-ouhist1 343 | bwidth = 2.0/nbins 344 | barlefts = hist1[1][nbins//2:-1] 345 | 346 | #ax.fill_between(vals, 0, np.clip(oud, 0, 1e9), facecolor="blue") 347 | #ax.fill_between(vals, 0, np.clip(oud, -1e9, 0), facecolor="red") 348 | 349 | ax.bar(barlefts, np.clip(oud, 0, 1e9), bwidth, facecolor="blue") 350 | ax.bar(barlefts, np.clip(oud, -1e9, 0), bwidth, facecolor="red") 351 | 352 | yl = np.max(np.abs(np.array(ax.get_ylim()))) 353 | ax.plot([thresh, thresh], [-yl, yl], '-', color="gray") 354 | ax.set_ylim(-yl, yl) 355 | ax.set_xlim(0, maxcorr) 356 | ax.set_xlabel("Voxel correlation") 357 | ax.set_ylabel("%s better %s better"%(name1, name2)) 358 | 359 | show() 360 | return ax 361 | 362 | 363 | def plot_model_comparison_rois(corrs1, corrs2, name1, name2, roivoxels, roinames, thresh=0.35): 364 | """Plots model correlation comparisons per ROI. 365 | """ 366 | fig = figure() 367 | ptalpha = 0.3 368 | 369 | for ri in range(len(roinames)): 370 | ax = fig.add_subplot(4, 4, ri+1) 371 | ax.plot(corrs1[roivoxels[ri]], corrs2[roivoxels[ri]], 'bo', alpha=ptalpha) 372 | lims = [-0.3, 1.0] 373 | ax.plot(lims, lims, '-', color="gray") 374 | ax.set_xlim(lims) 375 | ax.set_ylim(lims) 376 | ax.set_title(roinames[ri]) 377 | 378 | show() 379 | return fig 380 | 381 | def save_table_file(filename, filedict): 382 | """Saves the variables in [filedict] in an HDF5 table file at [filename]. 
383 | """ 384 | hf = tables.open_file(filename, mode="w", title="save_file") 385 | for vname, var in filedict.items(): 386 | hf.create_array("/", vname, var) 387 | hf.close() 388 | 389 | -------------------------------------------------------------------------------- /ridge_utils/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import scipy.stats ## needed by gaussianize() 4 | import random 5 | import sys 6 | 7 | def zscore(mat, return_unzvals=False): 8 | """Z-scores the rows of [mat] by subtracting off the mean and dividing 9 | by the standard deviation. 10 | If [return_unzvals] is True, a matrix will be returned that can be used 11 | to return the z-scored values to their original state. 12 | """ 13 | zmat = np.empty(mat.shape, mat.dtype) 14 | unzvals = np.zeros((zmat.shape[0], 2), mat.dtype) 15 | for ri in range(mat.shape[0]): 16 | unzvals[ri,0] = np.std(mat[ri,:]) 17 | unzvals[ri,1] = np.mean(mat[ri,:]) 18 | zmat[ri,:] = (mat[ri,:]-unzvals[ri,1]) / (1e-10+unzvals[ri,0]) 19 | 20 | if return_unzvals: 21 | return zmat, unzvals 22 | 23 | return zmat 24 | 25 | def center(mat, return_uncvals=False): 26 | """Centers the rows of [mat] by subtracting off the mean, but doesn't 27 | divide by the SD. 28 | Can be undone like zscore. 29 | """ 30 | cmat = np.empty(mat.shape) 31 | uncvals = np.ones((mat.shape[0], 2)) 32 | for ri in range(mat.shape[0]): 33 | uncvals[ri,1] = np.mean(mat[ri,:]) 34 | cmat[ri,:] = mat[ri,:]-uncvals[ri,1] 35 | 36 | if return_uncvals: 37 | return cmat, uncvals 38 | 39 | return cmat 40 | 41 | def unzscore(mat, unzvals): 42 | """Un-Z-scores the rows of [mat] by multiplying by unzvals[:,0] (the standard deviations) 43 | and then adding unzvals[:,1] (the row means). 
44 | """ 45 | unzmat = np.empty(mat.shape) 46 | for ri in range(mat.shape[0]): 47 | unzmat[ri,:] = mat[ri,:]*(1e-10+unzvals[ri,0])+unzvals[ri,1] 48 | return unzmat 49 | 50 | def gaussianize(vec): 51 | """Uses a look-up table to force the values in [vec] to be gaussian.""" 52 | ranks = np.argsort(np.argsort(vec)) 53 | cranks = (ranks+1).astype(float)/(ranks.max()+2) 54 | vals = scipy.stats.norm.isf(1-cranks) 55 | zvals = vals/vals.std() 56 | return zvals 57 | 58 | def gaussianize_mat(mat): 59 | """Gaussianizes each column of [mat].""" 60 | gmat = np.empty(mat.shape) 61 | for ri in range(mat.shape[1]): 62 | gmat[:,ri] = gaussianize(mat[:,ri]) 63 | return gmat 64 | 65 | def make_delayed(stim, delays, circpad=False): 66 | """Creates non-interpolated concatenated delayed versions of [stim] with the given [delays] 67 | (in samples). 68 | 69 | If [circpad], instead of being padded with zeros, [stim] will be circularly shifted. 70 | """ 71 | nt,ndim = stim.shape 72 | dstims = [] 73 | for di,d in enumerate(delays): 74 | dstim = np.zeros((nt, ndim)) 75 | if d<0: ## negative delay 76 | dstim[:d,:] = stim[-d:,:] 77 | if circpad: 78 | dstim[d:,:] = stim[:-d,:] 79 | elif d>0: 80 | dstim[d:,:] = stim[:-d,:] 81 | if circpad: 82 | dstim[:d,:] = stim[-d:,:] 83 | else: ## d==0 84 | dstim = stim.copy() 85 | dstims.append(dstim) 86 | return np.hstack(dstims) 87 | 88 | def mult_diag(d, mtx, left=True): 89 | """Multiply a full matrix by a diagonal matrix. 90 | This function should always be faster than dot. 
91 | 92 | Input: 93 | d -- 1D (N,) array (contains the diagonal elements) 94 | mtx -- 2D (N,N) array 95 | 96 | Output: 97 | mult_diag(d, mtx, left=True) == dot(diag(d), mtx) 98 | mult_diag(d, mtx, left=False) == dot(mtx, diag(d)) 99 | 100 | By Pietro Berkes 101 | From http://mail.scipy.org/pipermail/numpy-discussion/2007-March/026807.html 102 | """ 103 | if left: 104 | return (d*mtx.T).T 105 | else: 106 | return d*mtx 107 | 108 | import time 109 | import logging 110 | def counter(iterable, countevery=100, total=None, logger=logging.getLogger("counter")): 111 | """Logs a status and timing update to [logger] every [countevery] draws from [iterable]. 112 | If [total] is given, log messages will include the estimated time remaining. 113 | """ 114 | start_time = time.time() 115 | 116 | ## Check if the iterable has a __len__ function, use it if no total length is supplied 117 | if total is None: 118 | if hasattr(iterable, "__len__"): 119 | total = len(iterable) 120 | 121 | for count, thing in enumerate(iterable): 122 | yield thing 123 | 124 | if not count%countevery: 125 | current_time = time.time() 126 | rate = float(count+1)/(current_time-start_time) 127 | 128 | if rate>1: ## more than 1 item/second 129 | ratestr = "%0.2f items/second"%rate 130 | else: ## less than 1 item/second 131 | ratestr = "%0.2f seconds/item"%(rate**-1) 132 | 133 | if total is not None: 134 | remitems = total-(count+1) 135 | remtime = remitems/rate 136 | timestr = ", %s remaining" % time.strftime('%H:%M:%S', time.gmtime(remtime)) 137 | itemstr = "%d/%d"%(count+1, total) 138 | else: 139 | timestr = "" 140 | itemstr = "%d"%(count+1) 141 | 142 | formatted_str = "%s items complete (%s%s)"%(itemstr,ratestr,timestr) 143 | if logger is None: 144 | print(formatted_str) 145 | else: 146 | logger.info(formatted_str) 147 | -------------------------------------------------------------------------------- /speech_model_configs.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "hubert-base": { 3 | "huggingface_hub": "facebook/hubert-base-ls960", 4 | "stride": 320, 5 | "min_input_length": 400 6 | }, 7 | "hubert-large": { 8 | "huggingface_hub": "facebook/hubert-large-ll60k", 9 | "stride": 320, 10 | "min_input_length": 400 11 | }, 12 | "hubert-xlarge": { 13 | "huggingface_hub": "facebook/hubert-xlarge-ll60k", 14 | "stride": 320, 15 | "min_input_length": 400 16 | }, 17 | "wavlm-base": { 18 | "huggingface_hub": "microsoft/wavlm-base", 19 | "stride": 320, 20 | "min_input_length": 400 21 | }, 22 | "wavlm-base-plus": { 23 | "huggingface_hub": "microsoft/wavlm-base-plus", 24 | "stride": 320, 25 | "min_input_length": 400 26 | }, 27 | "wavlm-large": { 28 | "huggingface_hub": "microsoft/wavlm-large", 29 | "stride": 320, 30 | "min_input_length": 400 31 | }, 32 | "whisper-tiny": { 33 | "huggingface_hub": "openai/whisper-tiny", 34 | "stride": 320, 35 | "min_input_length": 400 36 | }, 37 | "whisper-base": { 38 | "huggingface_hub": "openai/whisper-base", 39 | "stride": 320, 40 | "min_input_length": 400 41 | }, 42 | "whisper-small": { 43 | "huggingface_hub": "openai/whisper-small", 44 | "stride": 320, 45 | "min_input_length": 400 46 | }, 47 | "whisper-medium": { 48 | "huggingface_hub": "openai/whisper-medium", 49 | "stride": 320, 50 | "min_input_length": 400 51 | }, 52 | "whisper-large": { 53 | "huggingface_hub": "openai/whisper-large", 54 | "stride": 320, 55 | "min_input_length": 400 56 | } 57 | } 58 | --------------------------------------------------------------------------------
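The entries in `speech_model_configs.json` above can be read as a small lookup table. The following is a minimal sketch, not part of the repository, of how one might interpret the `stride` and `min_input_length` fields. It assumes both are counted in audio samples and that the audio is sampled at 16 kHz; neither assumption is stated in the config itself. The helper names `frame_duration_ms` and `approx_num_frames` are hypothetical, and the frame-count formula is only an approximation for a convolutional front end, so it may not match every model listed.

```python
import json

# Two entries copied from speech_model_configs.json, inlined so the sketch is self-contained.
CONFIG_JSON = """
{
  "hubert-base": {"huggingface_hub": "facebook/hubert-base-ls960", "stride": 320, "min_input_length": 400},
  "whisper-tiny": {"huggingface_hub": "openai/whisper-tiny", "stride": 320, "min_input_length": 400}
}
"""

# ASSUMPTION: the audio sampling rate is 16 kHz (not stated in the config file).
SAMPLE_RATE_HZ = 16000


def frame_duration_ms(cfg, sample_rate_hz=SAMPLE_RATE_HZ):
    """Time between consecutive output frames, assuming `stride` is in samples."""
    return 1000.0 * cfg["stride"] / sample_rate_hz


def approx_num_frames(n_samples, cfg):
    """Rough number of output frames for an input of n_samples samples.

    ASSUMPTION: the model behaves like an unpadded convolutional front end with
    receptive field `min_input_length` and hop `stride`, both in samples.
    """
    if n_samples < cfg["min_input_length"]:
        return 0
    return (n_samples - cfg["min_input_length"]) // cfg["stride"] + 1


configs = json.loads(CONFIG_JSON)
for name, cfg in configs.items():
    print(name, cfg["huggingface_hub"],
          frame_duration_ms(cfg), approx_num_frames(SAMPLE_RATE_HZ, cfg))
```

Under these assumptions, a stride of 320 samples corresponds to one output frame every 20 ms, and the 400-sample minimum corresponds to 25 ms of audio.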