├── .gitignore ├── README.md ├── notebooks └── formatCompetitionData.ipynb ├── pyproject.toml ├── scripts ├── eval_competition.py ├── eval_competition.sh └── train_model.py ├── setup.cfg ├── setup.py └── src └── neural_decoder ├── augmentations.py ├── conf ├── config.yaml └── hydra │ └── launcher │ ├── big_gpu_slurm_med_time.yaml │ └── gpu_slurm_med_time.yaml ├── dataset.py ├── model.py └── neural_decoder_trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Pytorch implementation of [Neural Sequence Decoder](https://github.com/fwillett/speechBCI/tree/main/NeuralDecoder) 2 | 3 | ## Requirements 4 | - python >= 3.9 5 | 6 | ## Installation 7 | 8 | pip install -e . 9 | 10 | ## How to run 11 | 12 | 1. Convert the speech BCI dataset using [formatCompetitionData.ipynb](./notebooks/formatCompetitionData.ipynb) 13 | 2. Train model: `python ./scripts/train_model.py` 14 | 15 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=46.1.0"] -------------------------------------------------------------------------------- /scripts/eval_competition.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | import pickle 4 | import numpy as np 5 | 6 | from edit_distance import SequenceMatcher 7 | import torch 8 | from dataset import SpeechDataset 9 | 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | from nnDecoderModel import getDatasetLoaders 14 | from nnDecoderModel import loadModel 15 | import neuralDecoder.utils.lmDecoderUtils as lmDecoderUtils 16 | import pickle 17 | import argparse 18 | 19 | parser = argparse.ArgumentParser(description="") 20 | parser.add_argument("--modelPath", type=str, default=None, help="Path to model") 21 | input_args = parser.parse_args() 22 | 23 | 24 | with open(input_args.modelPath + "/args", "rb") as handle: 25 | args = pickle.load(handle) 26 | 27 | args["datasetPath"] = "/oak/stanford/groups/henderj/stfan/data/ptDecoder_ctc" 28 | trainLoaders, testLoaders, loadedData = getDatasetLoaders( 29 | args["datasetPath"], args["seqLen"], args["maxTimeSeriesLen"], args["batchSize"] 30 | ) 31 | 32 | model = loadModel(input_args.modelPath, device="cpu") 33 | 34 | device = "cpu" 35 | 36 | model.eval() 37 | 38 | rnn_outputs = { 39 | "logits": [], 40 | "logitLengths": [], 41 | "trueSeqs": [], 42 | "transcriptions": [], 43 | } 44 | 45 | partition = "competition" # "test" 46 | if partition == "competition": 47 | testDayIdxs = [4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 20] 48 | elif partition == "test": 49 | testDayIdxs = range(len(loadedData[partition])) 50 | 51 | for i, testDayIdx in testDayIdxs: 52 | test_ds = SpeechDataset([loadedData[partition][i]]) 53 | test_loader = torch.utils.data.DataLoader( 54 | test_ds, batch_size=1, shuffle=False, num_workers=0 55 | ) 56 | for j, (X, y, X_len, y_len, _) in enumerate(test_loader): 57 | X, y, X_len, y_len, dayIdx = ( 58 | X.to(device), 59 | y.to(device), 60 | X_len.to(device), 61 | y_len.to(device), 62 | torch.tensor([testDayIdx], dtype=torch.int64).to(device), 63 | ) 64 | pred = model.forward(X, dayIdx) 65 | adjustedLens = ((X_len - model.kernelLen) / model.strideLen).to(torch.int32) 66 | 67 | for iterIdx in range(pred.shape[0]): 68 | trueSeq = np.array(y[iterIdx][0 : y_len[iterIdx]].cpu().detach()) 69 | 70 | rnn_outputs["logits"].append(pred[iterIdx].cpu().detach().numpy()) 71 | rnn_outputs["logitLengths"].append( 72 | adjustedLens[iterIdx].cpu().detach().item() 73 | ) 74 | rnn_outputs["trueSeqs"].append(trueSeq) 75 | 76 | transcript = loadedData[partition][i]["transcriptions"][j].strip() 77 | transcript = re.sub(r"[^a-zA-Z\- \']", "", transcript) 78 | transcript = transcript.replace("--", "").lower() 79 | rnn_outputs["transcriptions"].append(transcript) 80 | 81 | 82 | MODEL_CACHE_DIR = "/scratch/users/stfan/huggingface" 83 | # Load OPT 6B model 84 | llm, llm_tokenizer = lmDecoderUtils.build_opt( 85 | cacheDir=MODEL_CACHE_DIR, device="auto", load_in_8bit=True 86 | ) 87 | 88 | lmDir = "/oak/stanford/groups/henderj/stfan/code/nptlrig2/LanguageModelDecoder/examples/speech/s0/lm_order_exp/5gram/data/lang_test" 89 | ngramDecoder = lmDecoderUtils.build_lm_decoder( 90 | lmDir, acoustic_scale=0.5, nbest=100, beam=18 91 | ) 92 | 93 | 94 | 95 | # LM decoding hyperparameters 96 | acoustic_scale = 0.5 97 | blank_penalty = np.log(7) 98 | llm_weight = 0.5 99 | 100 | llm_outputs = [] 101 | # Generate nbest outputs from 5gram LM 102 | start_t = time.time() 103 | nbest_outputs = [] 104 | for j in range(len(rnn_outputs["logits"])): 105 | logits = rnn_outputs["logits"][j] 106 | logits = np.concatenate( 107 | [logits[:, 1:], logits[:, 0:1]], axis=-1 108 | ) # Blank is last token 109 | logits = lmDecoderUtils.rearrange_speech_logits(logits[None, :, :], has_sil=True) 110 | nbest = lmDecoderUtils.lm_decode( 111 | ngramDecoder, 112 | logits[0], 113 | blankPenalty=blank_penalty, 114 | returnNBest=True, 115 | rescore=True, 116 | ) 117 | nbest_outputs.append(nbest) 118 | time_per_sample = (time.time() - start_t) / len(rnn_outputs["logits"]) 119 | print(f"5gram decoding took {time_per_sample} seconds per sample") 120 | 121 | for i in range(len(rnn_outputs["transcriptions"])): 122 | new_trans = [ord(c) for c in rnn_outputs["transcriptions"][i]] + [0] 123 | rnn_outputs["transcriptions"][i] = np.array(new_trans) 124 | 125 | # Rescore nbest outputs with LLM 126 | start_t = time.time() 127 | llm_out = lmDecoderUtils.cer_with_gpt2_decoder( 128 | llm, 129 | llm_tokenizer, 130 | nbest_outputs[:], 131 | acoustic_scale, 132 | rnn_outputs, 133 | outputType="speech_sil", 134 | returnCI=True, 135 | lengthPenalty=0, 136 | alpha=llm_weight, 137 | ) 138 | # time_per_sample = (time.time() - start_t) / len(logits) 139 | print(f"LLM decoding took {time_per_sample} seconds per sample") 140 | 141 | print(llm_out["cer"], llm_out["wer"]) 142 | with open(input_args.modelPath + "/llm_out", "wb") as handle: 143 | pickle.dump(llm_out, handle) 144 | 145 | decodedTranscriptions = llm_out["decoded_transcripts"] 146 | with open(input_args.modelPath + "/5gramLLMCompetitionSubmission.txt", "w") as f: 147 | for x in range(len(decodedTranscriptions)): 148 | f.write(decodedTranscriptions[x] + "\n") 149 | -------------------------------------------------------------------------------- /scripts/eval_competition.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Parameters 4 | #SBATCH --cpus-per-task=8 5 | #SBATCH --gpus-per-task=1 6 | #SBATCH --job-name=rescore 7 | #SBATCH --mail-type=ALL 8 | #SBATCH --mem=400GB 9 | #SBATCH --nodes=1 10 | #SBATCH --ntasks-per-node=1 11 | #SBATCH --open-mode=append 12 | #SBATCH --partition=henderj,owners 13 | #SBATCH --signal=USR1@120 14 | #SBATCH --time=2880 15 | #SBATCH --constraint=[GPU_MEM:32GB|GPU_MEM:40GB|GPU_MEM:80GB] 16 | 17 | ml gcc/10.1.0 18 | ml load cudnn/8.6.0.163 19 | ml load cuda/11.7.1 20 | 21 | python eval_competition.py --modelPath=$1 22 | -------------------------------------------------------------------------------- /scripts/train_model.py: -------------------------------------------------------------------------------- 1 | 2 | modelName = 'speechBaseline4' 3 | 4 | args = {} 5 | args['outputDir'] = '/oak/stanford/groups/henderj/stfan/logs/speech_logs/' + modelName 6 | args['datasetPath'] = '/oak/stanford/groups/henderj/fwillett/speech/ptDecoder_ctc' 7 | args['seqLen'] = 150 8 | args['maxTimeSeriesLen'] = 1200 9 | args['batchSize'] = 64 10 | args['lrStart'] = 0.02 11 | args['lrEnd'] = 0.02 12 | args['nUnits'] = 1024 13 | args['nBatch'] = 10000 #3000 14 | args['nLayers'] = 5 15 | args['seed'] = 0 16 | args['nClasses'] = 40 17 | args['nInputFeatures'] = 256 18 | args['dropout'] = 0.4 19 | args['whiteNoiseSD'] = 0.8 20 | args['constantOffsetSD'] = 0.2 21 | args['gaussianSmoothWidth'] = 2.0 22 | args['strideLen'] = 4 23 | args['kernelLen'] = 32 24 | args['bidirectional'] = True 25 | args['l2_decay'] = 1e-5 26 | 27 | from neural_decoder.neural_decoder_trainer import trainModel 28 | 29 | trainModel(args) -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # This file is used to configure your project. 2 | # Read more about the various options under: 3 | # https://setuptools.pypa.io/en/latest/userguide/declarative_config.html 4 | # https://setuptools.pypa.io/en/latest/references/keywords.html 5 | 6 | [metadata] 7 | name = neural_decoder 8 | description = PyTorch neural sequence decoder for speech BCI (https://github.com/fwillett/speechBCI/tree/main/NeuralDecoder) 9 | version = 0.0.1 10 | author = Chaofei Fan, Frank Willett 11 | author_email = stfan@stanford.edu 12 | license = MIT 13 | license_files = LICENSE.txt 14 | # Add here related links, for example: 15 | project_urls = 16 | 17 | # Change if running only on Windows, Mac or Linux (comma-separated) 18 | platforms = Linux 19 | 20 | # Add here all kinds of additional classifiers as defined under 21 | # https://pypi.org/classifiers/ 22 | classifiers = 23 | Development Status :: 4 - Beta 24 | Programming Language :: Python 25 | 26 | 27 | [options] 28 | zip_safe = False 29 | packages = find_namespace: 30 | include_package_data = True 31 | package_dir = 32 | =src 33 | 34 | # Require a min/specific Python version (comma-separated conditions) 35 | python_requires = >=3.9 36 | 37 | # Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0. 38 | # Version specifiers like >=2.2,<3.0 avoid problems due to API changes in 39 | # new major versions. This works if the required packages follow Semantic Versioning. 40 | # For more information, check out https://semver.org/. 41 | install_requires = 42 | importlib-metadata; python_version<"3.8" 43 | torch==1.13.1 44 | hydra-core==1.3.2 45 | hydra-submitit-launcher==1.1.5 46 | hydra-optuna-sweeper==1.2.0 47 | numpy==1.25.0 48 | scipy==1.11.1 49 | numba==0.58.1 50 | scikit-learn==1.3.2 51 | g2p_en==2.1.0 52 | edit_distance==1.0.6 53 | 54 | 55 | [options.packages.find] 56 | where = src 57 | exclude = 58 | tests 59 | examples 60 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Setup file for OnlineRecalibrator. 3 | Use setup.cfg to configure your project. 4 | 5 | This file was generated with PyScaffold 4.5. 6 | PyScaffold helps you to put up the scaffold of your new Python project. 7 | Learn more under: https://pyscaffold.org/ 8 | """ 9 | from setuptools import setup 10 | 11 | if __name__ == "__main__": 12 | try: 13 | setup( 14 | ) 15 | except: # noqa 16 | print( 17 | "\n\nAn error occurred while building the project, " 18 | "please ensure you have the most updated version of setuptools, " 19 | "setuptools_scm and wheel with:\n" 20 | " pip install -U setuptools setuptools_scm wheel\n\n" 21 | ) 22 | raise -------------------------------------------------------------------------------- /src/neural_decoder/augmentations.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class WhiteNoise(nn.Module): 9 | def __init__(self, std=0.1): 10 | super().__init__() 11 | self.std = std 12 | 13 | def forward(self, x): 14 | noise = torch.randn_like(x) * self.std 15 | return x + noise 16 | 17 | class MeanDriftNoise(nn.Module): 18 | def __init__(self, std=0.1): 19 | super().__init__() 20 | self.std = std 21 | 22 | def forward(self, x): 23 | _, C = x.shape 24 | noise = torch.randn(1, C) * self.std 25 | return x + noise 26 | 27 | class GaussianSmoothing(nn.Module): 28 | """ 29 | Apply gaussian smoothing on a 30 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 31 | in the input using a depthwise convolution. 32 | Arguments: 33 | channels (int, sequence): Number of channels of the input tensors. Output will 34 | have this number of channels as well. 35 | kernel_size (int, sequence): Size of the gaussian kernel. 36 | sigma (float, sequence): Standard deviation of the gaussian kernel. 37 | dim (int, optional): The number of dimensions of the data. 38 | Default value is 2 (spatial). 39 | """ 40 | 41 | def __init__(self, channels, kernel_size, sigma, dim=2): 42 | super(GaussianSmoothing, self).__init__() 43 | if isinstance(kernel_size, numbers.Number): 44 | kernel_size = [kernel_size] * dim 45 | if isinstance(sigma, numbers.Number): 46 | sigma = [sigma] * dim 47 | 48 | # The gaussian kernel is the product of the 49 | # gaussian function of each dimension. 50 | kernel = 1 51 | meshgrids = torch.meshgrid( 52 | [torch.arange(size, dtype=torch.float32) for size in kernel_size] 53 | ) 54 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 55 | mean = (size - 1) / 2 56 | kernel *= ( 57 | 1 58 | / (std * math.sqrt(2 * math.pi)) 59 | * torch.exp(-(((mgrid - mean) / std) ** 2) / 2) 60 | ) 61 | 62 | # Make sure sum of values in gaussian kernel equals 1. 63 | kernel = kernel / torch.sum(kernel) 64 | 65 | # Reshape to depthwise convolutional weight 66 | kernel = kernel.view(1, 1, *kernel.size()) 67 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 68 | 69 | self.register_buffer("weight", kernel) 70 | self.groups = channels 71 | 72 | if dim == 1: 73 | self.conv = F.conv1d 74 | elif dim == 2: 75 | self.conv = F.conv2d 76 | elif dim == 3: 77 | self.conv = F.conv3d 78 | else: 79 | raise RuntimeError( 80 | "Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim) 81 | ) 82 | 83 | def forward(self, input): 84 | """ 85 | Apply gaussian filter to input. 86 | Arguments: 87 | input (torch.Tensor): Input to apply gaussian filter on. 88 | Returns: 89 | filtered (torch.Tensor): Filtered output. 90 | """ 91 | return self.conv(input, weight=self.weight, groups=self.groups, padding="same") 92 | -------------------------------------------------------------------------------- /src/neural_decoder/conf/config.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ${outputDir} 4 | sweep: 5 | dir: ${outputDir} 6 | subdir: ${hydra.job.override_dirname} 7 | job: 8 | config: 9 | override_dirname: 10 | exclude_keys: 11 | - outputDir 12 | - datasetPath 13 | 14 | outputDir: /oak/stanford/groups/henderj/stfan/logs/speech_logs/pt_neural_decoder 15 | datasetPath: /oak/stanford/groups/henderj/stfan/data/ptDecoder_ctc 16 | 17 | seed: 0 18 | batchSize: 64 19 | lrStart: 0.02 20 | lrEnd: 0.02 21 | l2_decay: 1e-5 22 | nBatch: 10000 23 | 24 | whiteNoiseSD: 0.8 25 | constantOffsetSD: 0.2 26 | gaussianSmoothWidth: 2.0 27 | 28 | nUnits: 1024 29 | nLayers: 5 30 | nInputFeatures: 256 31 | nClasses: 40 32 | dropout: 0.4 33 | strideLen: 4 34 | kernelLen: 32 35 | bidirectional: True -------------------------------------------------------------------------------- /src/neural_decoder/conf/hydra/launcher/big_gpu_slurm_med_time.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - submitit_slurm 3 | 4 | timeout_min: 600 5 | cpus_per_task: 3 6 | mem_gb: 96 7 | partition: owners,henderj 8 | gpus_per_node: 1 9 | setup: 10 | - ml load python/3.9.0 cuda/11.7.1 cudnn/8.6.0.163 11 | additional_parameters: 12 | constraint: '[GPU_MEM:40GB]' 13 | mail-type: ALL 14 | array_parallelism: 50 15 | -------------------------------------------------------------------------------- /src/neural_decoder/conf/hydra/launcher/gpu_slurm_med_time.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - submitit_slurm 3 | 4 | timeout_min: 600 5 | cpus_per_task: 3 6 | mem_gb: 64 7 | partition: owners,henderj 8 | gpus_per_node: 1 9 | setup: 10 | - ml load python/3.9.0 cuda/11.7.1 cudnn/8.6.0.163 11 | additional_parameters: 12 | constraint: '[GPU_MEM:24GB|GPU_MEM:32GB|GPU_MEM:40GB]' 13 | mail-type: ALL 14 | array_parallelism: 50 15 | -------------------------------------------------------------------------------- /src/neural_decoder/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class SpeechDataset(Dataset): 6 | def __init__(self, data, transform=None): 7 | self.data = data 8 | self.transform = transform 9 | self.n_days = len(data) 10 | self.n_trials = sum([len(d["sentenceDat"]) for d in data]) 11 | 12 | self.neural_feats = [] 13 | self.phone_seqs = [] 14 | self.neural_time_bins = [] 15 | self.phone_seq_lens = [] 16 | self.days = [] 17 | for day in range(self.n_days): 18 | for trial in range(len(data[day]["sentenceDat"])): 19 | self.neural_feats.append(data[day]["sentenceDat"][trial]) 20 | self.phone_seqs.append(data[day]["phonemes"][trial]) 21 | self.neural_time_bins.append(data[day]["sentenceDat"][trial].shape[0]) 22 | self.phone_seq_lens.append(data[day]["phoneLens"][trial]) 23 | self.days.append(day) 24 | 25 | def __len__(self): 26 | return self.n_trials 27 | 28 | def __getitem__(self, idx): 29 | neural_feats = torch.tensor(self.neural_feats[idx], dtype=torch.float32) 30 | 31 | if self.transform: 32 | neural_feats = self.transform(neural_feats) 33 | 34 | return ( 35 | neural_feats, 36 | torch.tensor(self.phone_seqs[idx], dtype=torch.int32), 37 | torch.tensor(self.neural_time_bins[idx], dtype=torch.int32), 38 | torch.tensor(self.phone_seq_lens[idx], dtype=torch.int32), 39 | torch.tensor(self.days[idx], dtype=torch.int64), 40 | ) 41 | -------------------------------------------------------------------------------- /src/neural_decoder/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .augmentations import GaussianSmoothing 5 | 6 | 7 | class GRUDecoder(nn.Module): 8 | def __init__( 9 | self, 10 | neural_dim, 11 | n_classes, 12 | hidden_dim, 13 | layer_dim, 14 | nDays=24, 15 | dropout=0, 16 | device="cuda", 17 | strideLen=4, 18 | kernelLen=14, 19 | gaussianSmoothWidth=0, 20 | bidirectional=False, 21 | ): 22 | super(GRUDecoder, self).__init__() 23 | 24 | # Defining the number of layers and the nodes in each layer 25 | self.layer_dim = layer_dim 26 | self.hidden_dim = hidden_dim 27 | self.neural_dim = neural_dim 28 | self.n_classes = n_classes 29 | self.nDays = nDays 30 | self.device = device 31 | self.dropout = dropout 32 | self.strideLen = strideLen 33 | self.kernelLen = kernelLen 34 | self.gaussianSmoothWidth = gaussianSmoothWidth 35 | self.bidirectional = bidirectional 36 | self.inputLayerNonlinearity = torch.nn.Softsign() 37 | self.unfolder = torch.nn.Unfold( 38 | (self.kernelLen, 1), dilation=1, padding=0, stride=self.strideLen 39 | ) 40 | self.gaussianSmoother = GaussianSmoothing( 41 | neural_dim, 20, self.gaussianSmoothWidth, dim=1 42 | ) 43 | self.dayWeights = torch.nn.Parameter(torch.randn(nDays, neural_dim, neural_dim)) 44 | self.dayBias = torch.nn.Parameter(torch.zeros(nDays, 1, neural_dim)) 45 | 46 | for x in range(nDays): 47 | self.dayWeights.data[x, :, :] = torch.eye(neural_dim) 48 | 49 | # GRU layers 50 | self.gru_decoder = nn.GRU( 51 | (neural_dim) * self.kernelLen, 52 | hidden_dim, 53 | layer_dim, 54 | batch_first=True, 55 | dropout=self.dropout, 56 | bidirectional=self.bidirectional, 57 | ) 58 | 59 | for name, param in self.gru_decoder.named_parameters(): 60 | if "weight_hh" in name: 61 | nn.init.orthogonal_(param) 62 | if "weight_ih" in name: 63 | nn.init.xavier_uniform_(param) 64 | 65 | # Input layers 66 | for x in range(nDays): 67 | setattr(self, "inpLayer" + str(x), nn.Linear(neural_dim, neural_dim)) 68 | 69 | for x in range(nDays): 70 | thisLayer = getattr(self, "inpLayer" + str(x)) 71 | thisLayer.weight = torch.nn.Parameter( 72 | thisLayer.weight + torch.eye(neural_dim) 73 | ) 74 | 75 | # rnn outputs 76 | if self.bidirectional: 77 | self.fc_decoder_out = nn.Linear( 78 | hidden_dim * 2, n_classes + 1 79 | ) # +1 for CTC blank 80 | else: 81 | self.fc_decoder_out = nn.Linear(hidden_dim, n_classes + 1) # +1 for CTC blank 82 | 83 | def forward(self, neuralInput, dayIdx): 84 | neuralInput = torch.permute(neuralInput, (0, 2, 1)) 85 | neuralInput = self.gaussianSmoother(neuralInput) 86 | neuralInput = torch.permute(neuralInput, (0, 2, 1)) 87 | 88 | # apply day layer 89 | dayWeights = torch.index_select(self.dayWeights, 0, dayIdx) 90 | transformedNeural = torch.einsum( 91 | "btd,bdk->btk", neuralInput, dayWeights 92 | ) + torch.index_select(self.dayBias, 0, dayIdx) 93 | transformedNeural = self.inputLayerNonlinearity(transformedNeural) 94 | 95 | # stride/kernel 96 | stridedInputs = torch.permute( 97 | self.unfolder( 98 | torch.unsqueeze(torch.permute(transformedNeural, (0, 2, 1)), 3) 99 | ), 100 | (0, 2, 1), 101 | ) 102 | 103 | # apply RNN layer 104 | if self.bidirectional: 105 | h0 = torch.zeros( 106 | self.layer_dim * 2, 107 | transformedNeural.size(0), 108 | self.hidden_dim, 109 | device=self.device, 110 | ).requires_grad_() 111 | else: 112 | h0 = torch.zeros( 113 | self.layer_dim, 114 | transformedNeural.size(0), 115 | self.hidden_dim, 116 | device=self.device, 117 | ).requires_grad_() 118 | 119 | hid, _ = self.gru_decoder(stridedInputs, h0.detach()) 120 | 121 | # get seq 122 | seq_out = self.fc_decoder_out(hid) 123 | return seq_out 124 | -------------------------------------------------------------------------------- /src/neural_decoder/neural_decoder_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import time 4 | 5 | from edit_distance import SequenceMatcher 6 | import hydra 7 | import numpy as np 8 | import torch 9 | from torch.nn.utils.rnn import pad_sequence 10 | from torch.utils.data import DataLoader 11 | 12 | from .model import GRUDecoder 13 | from .dataset import SpeechDataset 14 | 15 | 16 | def getDatasetLoaders( 17 | datasetName, 18 | batchSize, 19 | ): 20 | with open(datasetName, "rb") as handle: 21 | loadedData = pickle.load(handle) 22 | 23 | def _padding(batch): 24 | X, y, X_lens, y_lens, days = zip(*batch) 25 | X_padded = pad_sequence(X, batch_first=True, padding_value=0) 26 | y_padded = pad_sequence(y, batch_first=True, padding_value=0) 27 | 28 | return ( 29 | X_padded, 30 | y_padded, 31 | torch.stack(X_lens), 32 | torch.stack(y_lens), 33 | torch.stack(days), 34 | ) 35 | 36 | train_ds = SpeechDataset(loadedData["train"], transform=None) 37 | test_ds = SpeechDataset(loadedData["test"]) 38 | 39 | train_loader = DataLoader( 40 | train_ds, 41 | batch_size=batchSize, 42 | shuffle=True, 43 | num_workers=0, 44 | pin_memory=True, 45 | collate_fn=_padding, 46 | ) 47 | test_loader = DataLoader( 48 | test_ds, 49 | batch_size=batchSize, 50 | shuffle=False, 51 | num_workers=0, 52 | pin_memory=True, 53 | collate_fn=_padding, 54 | ) 55 | 56 | return train_loader, test_loader, loadedData 57 | 58 | def trainModel(args): 59 | os.makedirs(args["outputDir"], exist_ok=True) 60 | torch.manual_seed(args["seed"]) 61 | np.random.seed(args["seed"]) 62 | device = "cuda" 63 | 64 | with open(args["outputDir"] + "/args", "wb") as file: 65 | pickle.dump(args, file) 66 | 67 | trainLoader, testLoader, loadedData = getDatasetLoaders( 68 | args["datasetPath"], 69 | args["batchSize"], 70 | ) 71 | 72 | model = GRUDecoder( 73 | neural_dim=args["nInputFeatures"], 74 | n_classes=args["nClasses"], 75 | hidden_dim=args["nUnits"], 76 | layer_dim=args["nLayers"], 77 | nDays=len(loadedData["train"]), 78 | dropout=args["dropout"], 79 | device=device, 80 | strideLen=args["strideLen"], 81 | kernelLen=args["kernelLen"], 82 | gaussianSmoothWidth=args["gaussianSmoothWidth"], 83 | bidirectional=args["bidirectional"], 84 | ).to(device) 85 | 86 | loss_ctc = torch.nn.CTCLoss(blank=0, reduction="mean", zero_infinity=True) 87 | optimizer = torch.optim.Adam( 88 | model.parameters(), 89 | lr=args["lrStart"], 90 | betas=(0.9, 0.999), 91 | eps=0.1, 92 | weight_decay=args["l2_decay"], 93 | ) 94 | scheduler = torch.optim.lr_scheduler.LinearLR( 95 | optimizer, 96 | start_factor=1.0, 97 | end_factor=args["lrEnd"] / args["lrStart"], 98 | total_iters=args["nBatch"], 99 | ) 100 | 101 | # --train-- 102 | testLoss = [] 103 | testCER = [] 104 | startTime = time.time() 105 | for batch in range(args["nBatch"]): 106 | model.train() 107 | 108 | X, y, X_len, y_len, dayIdx = next(iter(trainLoader)) 109 | X, y, X_len, y_len, dayIdx = ( 110 | X.to(device), 111 | y.to(device), 112 | X_len.to(device), 113 | y_len.to(device), 114 | dayIdx.to(device), 115 | ) 116 | 117 | # Noise augmentation is faster on GPU 118 | if args["whiteNoiseSD"] > 0: 119 | X += torch.randn(X.shape, device=device) * args["whiteNoiseSD"] 120 | 121 | if args["constantOffsetSD"] > 0: 122 | X += ( 123 | torch.randn([X.shape[0], 1, X.shape[2]], device=device) 124 | * args["constantOffsetSD"] 125 | ) 126 | 127 | # Compute prediction error 128 | pred = model.forward(X, dayIdx) 129 | 130 | loss = loss_ctc( 131 | torch.permute(pred.log_softmax(2), [1, 0, 2]), 132 | y, 133 | ((X_len - model.kernelLen) / model.strideLen).to(torch.int32), 134 | y_len, 135 | ) 136 | loss = torch.sum(loss) 137 | 138 | # Backpropagation 139 | optimizer.zero_grad() 140 | loss.backward() 141 | optimizer.step() 142 | scheduler.step() 143 | 144 | # print(endTime - startTime) 145 | 146 | # Eval 147 | if batch % 100 == 0: 148 | with torch.no_grad(): 149 | model.eval() 150 | allLoss = [] 151 | total_edit_distance = 0 152 | total_seq_length = 0 153 | for X, y, X_len, y_len, testDayIdx in testLoader: 154 | X, y, X_len, y_len, testDayIdx = ( 155 | X.to(device), 156 | y.to(device), 157 | X_len.to(device), 158 | y_len.to(device), 159 | testDayIdx.to(device), 160 | ) 161 | 162 | pred = model.forward(X, testDayIdx) 163 | loss = loss_ctc( 164 | torch.permute(pred.log_softmax(2), [1, 0, 2]), 165 | y, 166 | ((X_len - model.kernelLen) / model.strideLen).to(torch.int32), 167 | y_len, 168 | ) 169 | loss = torch.sum(loss) 170 | allLoss.append(loss.cpu().detach().numpy()) 171 | 172 | adjustedLens = ((X_len - model.kernelLen) / model.strideLen).to( 173 | torch.int32 174 | ) 175 | for iterIdx in range(pred.shape[0]): 176 | decodedSeq = torch.argmax( 177 | torch.tensor(pred[iterIdx, 0 : adjustedLens[iterIdx], :]), 178 | dim=-1, 179 | ) # [num_seq,] 180 | decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1) 181 | decodedSeq = decodedSeq.cpu().detach().numpy() 182 | decodedSeq = np.array([i for i in decodedSeq if i != 0]) 183 | 184 | trueSeq = np.array( 185 | y[iterIdx][0 : y_len[iterIdx]].cpu().detach() 186 | ) 187 | 188 | matcher = SequenceMatcher( 189 | a=trueSeq.tolist(), b=decodedSeq.tolist() 190 | ) 191 | total_edit_distance += matcher.distance() 192 | total_seq_length += len(trueSeq) 193 | 194 | avgDayLoss = np.sum(allLoss) / len(testLoader) 195 | cer = total_edit_distance / total_seq_length 196 | 197 | endTime = time.time() 198 | print( 199 | f"batch {batch}, ctc loss: {avgDayLoss:>7f}, cer: {cer:>7f}, time/batch: {(endTime - startTime)/100:>7.3f}" 200 | ) 201 | startTime = time.time() 202 | 203 | if len(testCER) > 0 and cer < np.min(testCER): 204 | torch.save(model.state_dict(), args["outputDir"] + "/modelWeights") 205 | testLoss.append(avgDayLoss) 206 | testCER.append(cer) 207 | 208 | tStats = {} 209 | tStats["testLoss"] = np.array(testLoss) 210 | tStats["testCER"] = np.array(testCER) 211 | 212 | with open(args["outputDir"] + "/trainingStats", "wb") as file: 213 | pickle.dump(tStats, file) 214 | 215 | 216 | def loadModel(modelDir, nInputLayers=24, device="cuda"): 217 | modelWeightPath = modelDir + "/modelWeights" 218 | with open(modelDir + "/args", "rb") as handle: 219 | args = pickle.load(handle) 220 | 221 | model = GRUDecoder( 222 | neural_dim=args["nInputFeatures"], 223 | n_classes=args["nClasses"], 224 | hidden_dim=args["nUnits"], 225 | layer_dim=args["nLayers"], 226 | nDays=nInputLayers, 227 | dropout=args["dropout"], 228 | device=device, 229 | strideLen=args["strideLen"], 230 | kernelLen=args["kernelLen"], 231 | gaussianSmoothWidth=args["gaussianSmoothWidth"], 232 | bidirectional=args["bidirectional"], 233 | ).to(device) 234 | 235 | model.load_state_dict(torch.load(modelWeightPath, map_location=device)) 236 | return model 237 | 238 | 239 | @hydra.main(version_base="1.1", config_path="conf", config_name="config") 240 | def main(cfg): 241 | cfg.outputDir = os.getcwd() 242 | trainModel(cfg) 243 | 244 | if __name__ == "__main__": 245 | main() --------------------------------------------------------------------------------