├── INSTALL.md
├── LICENSE
├── README.md
├── examples
│   ├── README.md
│   ├── demo_calc_logprobs.sh
│   ├── demo_generate.sh
│   ├── demo_generate_fast.sh
│   ├── demo_train.sh
│   ├── download_effect_predictions.sh
│   ├── download_example_data.sh
│   ├── download_generated_nanobodies.sh
│   ├── download_sequences.sh
│   └── run_example.sh
├── linux_setup.sh
├── pyproject.toml
├── requirements.txt
├── requirements_gpu.txt
├── setup.cfg
├── setup.py
└── src
    └── seqdesign_pt
        ├── __init__.py
        ├── autoregressive_model.py
        ├── autoregressive_train.py
        ├── aws_utils.py
        ├── data_loaders.py
        ├── functions.py
        ├── layers.py
        ├── model_logging.py
        ├── scripts
        │   ├── __init__.py
        │   ├── calc_logprobs_seqs_fr.py
        │   ├── generate_sample_seqs_fr.py
        │   ├── run_autoregressive_fr.py
        │   └── run_autoregressive_vae_fr.py
        ├── tf_reader.py
        ├── utils.py
        └── version.py

/INSTALL.md:
--------------------------------------------------------------------------------
## Installation

We recommend using SeqDesign with a GPU that supports CUDA, especially for training.
If a GPU is available, install the [TensorFlow GPU dependencies](https://www.tensorflow.org/install/gpu),
then install the SeqDesign dependencies with:
```shell script
pip install -r requirements_gpu.txt
```

Using the [linux_setup.sh](linux_setup.sh) script,
installation on a fresh Ubuntu 18.04 LTS machine took 5 minutes.
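
After a GPU install, a quick way to confirm that PyTorch can actually see the card is the same check `linux_setup.sh` runs (`torch.cuda.is_available()` plus `torch.cuda.get_device_name`). The helper below is a minimal sketch of that check; `gpu_report` is a hypothetical name, not part of SeqDesign, and it reports no GPU when PyTorch is not installed rather than raising.

```python
def gpu_report():
    """Report whether PyTorch is importable and which CUDA devices it can see."""
    try:
        import torch
    except ImportError:
        # PyTorch is missing, so there is nothing to query
        return {"torch_installed": False, "cuda_available": False, "devices": []}
    available = torch.cuda.is_available()
    devices = (
        [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
        if available else []
    )
    return {"torch_installed": True, "cuda_available": available, "devices": devices}


if __name__ == "__main__":
    print(gpu_report())
```

An empty `devices` list on a GPU machine usually points to a driver or CUDA problem rather than to SeqDesign itself.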

If no GPU is available, use:
```shell script
pip install -r requirements.txt
```

Then install SeqDesign:
```shell script
python setup.py install
```

### Used software and versions tested:
- python - 3.7
- tensorflow - 1.15
- numpy - 1.15
- scipy - 0.19
- sklearn - 0.18

Tested on Ubuntu 18.04 LTS
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Aaron Kollasch

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SeqDesign

SeqDesign is a generative, unsupervised model for biological sequences.
It is capable of learning functional constraints from unaligned sequences
in order to predict the effects of mutations and generate novel sequences,
including insertions and deletions. For more information,
check out the [bioRxiv preprint](https://doi.org/10.1101/757252).

This version of the codebase is compatible with Python 3 and PyTorch.
It also implements [Fast Wavenet](https://github.com/tomlepaine/fast-wavenet) generation.
A TensorFlow version is available [here](https://github.com/debbiemarkslab/SeqDesign).

## Installation

See [INSTALL.md](INSTALL.md).

## Examples

See the [examples](examples) directory for examples of
training, mutation effect prediction, and generation.

## Usage
Run each script with the `-h` argument to see additional arguments:
### Training

Given a fasta file of training sequences, run:
```shell script
run_autoregressive_fr --dataset .fa
```
Sequences are uniformly weighted by default. To set sequence
weights, append `:` and a weight to each fasta header, e.g. `:1.0`.

### Mutation effect prediction
Deterministic:
```shell script
calc_logprobs_seqs_fr --sess --dropout-p 1.0 --num-samples 1 --input .fa --output .csv
```

Average of 500 samples:
```shell script
calc_logprobs_seqs_fr --sess --dropout-p 0.5 --num-samples 500 --input .fa --output .csv
```

### Sequence generation
```shell script
generate_sample_seqs_fr --sess
```
Use the `--fast-generation` argument for Fast Wavenet.
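
The Training section notes that a sequence weight is set by appending `:` and a number to the fasta header. The parser below is a hypothetical sketch of how such headers can be read, assuming the weight is whatever follows the last `:` and that headers without a numeric suffix keep the default weight of 1.0. It is not SeqDesign's own data loader, and `parse_weighted_fasta` is an illustrative name.

```python
def parse_weighted_fasta(lines):
    """Return (name, sequence, weight) tuples from FASTA lines.

    Hypothetical helper: a header such as '>seq1:0.5' gives weight 0.5;
    a header without a numeric ':<weight>' suffix gets the default weight 1.0.
    """
    records = []
    name, weight, seq_parts = None, 1.0, []
    for raw in lines:
        line = raw.strip()
        if not line:
            continue
        if line.startswith(">"):
            if name is not None:  # flush the previous record
                records.append((name, "".join(seq_parts), weight))
            header = line[1:]
            head, sep, tail = header.rpartition(":")
            try:
                # assumption: the weight is whatever follows the last ':'
                weight = float(tail) if sep else 1.0
                name = head if sep else header
            except ValueError:
                # the ':' belongs to the identifier, not to a weight
                name, weight = header, 1.0
            seq_parts = []
        else:
            seq_parts.append(line)  # sequences may span several lines
    if name is not None:
        records.append((name, "".join(seq_parts), weight))
    return records
```

Because it accepts any iterable of lines, an open file handle can be passed directly, e.g. `parse_weighted_fasta(open("train.fa"))` with a hypothetical `train.fa`.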

## Data availability
See the [examples](examples) directory to download training sequences,
mutation effect predictions, and generated sequences.
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
# SeqDesign Examples

## Downloading data
### Example data
```shell script
./download_example_data.sh
```

This script will download the following files for
training, mutation prediction, and sequence generation:

- `datasets/sequences/BLAT_ECOLX_1_b0.5_lc_weights.fa`
- `datasets/nanobodies/Manglik_filt_seq_id80_id90.fa`
- `datasets/nanobodies/Manglik_labelled_nanobodies.txt`
- `calc_logprobs/input/BLAT_ECOLX_r24-286_Ranganathan2015.fa`
- `sess/BLAT_ECOLX_v2_channels-48_rseed-11_19Aug16_0626PM.ckpt-250000*`
- `sess/nanobody.ckpt-250000*`

### Training sequences
```shell script
./download_sequences.sh
```
This script will download all training sequences from the paper
to `datasets/sequences/` and `datasets/nanobodies/`.

### Effect predictions
```shell script
./download_effect_predictions.sh
```
This script will download all mutation effect predictions
from the paper to `calc_logprobs/output/`.

### Generated sequences
```shell script
./download_generated_nanobodies.sh
```
This script will download the designed nanobody library to `generated/`.

## Running examples

### Training the model
```shell script
./demo_train.sh
```

This script will run 100 training iterations on the β-Lactamase sequence dataset
(the full model runs 250,000 iterations).
The final model checkpoint will appear as three files in
`sess/BLAT_ECOLX_elu_channels-48_rseed-11_.ckpt-100*`

On an AWS p2.xlarge instance, this demonstration took 2 minutes.

### Predicting mutation effects
```shell script
./demo_calc_logprobs.sh
```

This script will use the pretrained model weights in
`sess/BLAT_ECOLX_v2_channels-48_rseed-11_19Aug16_0626PM.ckpt-250000*`
to make mutation effect predictions for the β-Lactamase mutational scan from
[Stiffler et al., Cell, 2015](https://doi.org/10.1016/j.cell.2015.01.035).

The final predictions are the average of 10 dropout samples
(500 samples are used in the full test).
These predictions will appear in
`calc_logprobs/output/demo_BLAT_ECOLX_r24-286_Ranganathan2015_rseed-11_channels-48_dropoutp-0.5.csv`

On an AWS p2.xlarge instance, this demonstration took 3.5 minutes.

### Generating nanobody libraries
```shell script
./demo_generate.sh
./demo_generate_fast.sh
```

This will generate nanobody CDR3 and FRA4 sequences given a preceding VH sequence.
The full nanobody sequences will be output in
`generated/nanobody.ckpt-250000_temp-1.0_rseed-42.fa`

On an AWS p2.xlarge instance, `demo_generate.sh` took 2.5 minutes and
`demo_generate_fast.sh` took 30 seconds.
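
The speedup of `demo_generate_fast.sh` comes from Fast Wavenet generation: instead of re-running every dilated causal convolution over the whole prefix for each new residue, each layer caches the past activations it will still need. (In this codebase, `Autoregressive.forward` in generation mode keeps only the last input position, `inputs[:, :, :, -1:]`, after the first call.) The toy below sketches only the caching idea, assuming single-channel layers with kernel width 2 and made-up weights; `DILATIONS`, `WEIGHTS`, `full_forward`, and `FastGenerator` are hypothetical names, not SeqDesign's implementation.

```python
import math
from collections import deque

# Hypothetical toy network: each layer is a single-channel causal convolution
# with kernel width 2 and dilation d: y[t] = tanh(w_past * x[t - d] + w_now * x[t]),
# with x[t] = 0 for t < 0. The weights are made up for illustration.
DILATIONS = [1, 2, 4, 8]
WEIGHTS = [(0.5, 1.0), (-0.3, 0.8), (0.7, 0.2), (0.1, 0.9)]  # (w_past, w_now) per layer


def full_forward(xs):
    """Naive generation step: recompute every layer over the whole prefix, return the last output."""
    h = list(xs)
    for d, (w_past, w_now) in zip(DILATIONS, WEIGHTS):
        h = [math.tanh(w_past * (h[t - d] if t >= d else 0.0) + w_now * h[t])
             for t in range(len(h))]
    return h[-1]


class FastGenerator:
    """Fast-Wavenet-style generation: each layer keeps a queue of its last d inputs."""

    def __init__(self):
        # Queue i holds layer i's inputs at times t-d..t-1 (oldest first);
        # it starts as d zeros, matching the causal zero padding.
        self.queues = [deque([0.0] * d, maxlen=d) for d in DILATIONS]

    def step(self, x):
        """Feed one new position and return the top layer's output for it."""
        h = x
        for queue, (w_past, w_now) in zip(self.queues, WEIGHTS):
            past = queue[0]   # this layer's input d steps ago
            queue.append(h)   # maxlen drops the oldest entry automatically
            h = math.tanh(w_past * past + w_now * h)
        return h
```

Both paths return the same value at every position, but `full_forward` does work proportional to the prefix length at each step (quadratic over a whole sequence), while `FastGenerator.step` does a constant amount of work per layer. In a real sampler, each new `x` would be the residue drawn from the previous step's output.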
--------------------------------------------------------------------------------
/examples/demo_calc_logprobs.sh:
--------------------------------------------------------------------------------
#!/bin/bash
cd calc_logprobs || exit
calc_logprobs_seqs_fr --sess BLAT_ECOLX_v2_channels-48_rseed-11_19Aug16_0626PM.ckpt-250000 \
    --channels 48 --r-seed 11 --dropout-p 0.5 --num-samples 10 --from-tf \
    --input input/BLAT_ECOLX_r24-286_Ranganathan2015.fa \
    --output output/demo_BLAT_ECOLX_r24-286_Ranganathan2015_rseed-11_channels-48_dropoutp-0.5.csv
--------------------------------------------------------------------------------
/examples/demo_generate.sh:
--------------------------------------------------------------------------------
#!/bin/bash
generate_sample_seqs_fr --sess nanobody.ckpt-250000 --r-seed 42 --temp 1.0 --batch-size 500 --num-batches 10 --channels 48 --from-tf
--------------------------------------------------------------------------------
/examples/demo_generate_fast.sh:
--------------------------------------------------------------------------------
#!/bin/bash
generate_sample_seqs_fr --sess nanobody.ckpt-250000 --r-seed 42 --temp 1.0 --batch-size 500 --num-batches 10 --channels 48 --from-tf --fast-generation
--------------------------------------------------------------------------------
/examples/demo_train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
run_autoregressive_fr --dataset BLAT_ECOLX --channels 48 --r-seed 11 --num-iterations 100 --snapshot-interval 100
--------------------------------------------------------------------------------
/examples/download_effect_predictions.sh:
--------------------------------------------------------------------------------
#!/bin/bash
curl -o predictions.tar.gz https://marks.hms.harvard.edu/seqdesign/predictions.tar.gz
tar -xzvf predictions.tar.gz
rm predictions.tar.gz
--------------------------------------------------------------------------------
/examples/download_example_data.sh:
--------------------------------------------------------------------------------
#!/bin/bash
curl -o example.tar.gz https://marks.hms.harvard.edu/seqdesign/example.tar.gz
tar -xzvf example.tar.gz
rm example.tar.gz
--------------------------------------------------------------------------------
/examples/download_generated_nanobodies.sh:
--------------------------------------------------------------------------------
#!/bin/bash
curl -o generated_nbs.tar.gz https://marks.hms.harvard.edu/seqdesign/generated_nbs.tar.gz
tar -xzvf generated_nbs.tar.gz
rm generated_nbs.tar.gz
--------------------------------------------------------------------------------
/examples/download_sequences.sh:
--------------------------------------------------------------------------------
#!/bin/bash
curl -o sequences.tar.gz https://marks.hms.harvard.edu/seqdesign/sequences.tar.gz
tar -xzvf sequences.tar.gz
rm sequences.tar.gz
--------------------------------------------------------------------------------
/examples/run_example.sh:
--------------------------------------------------------------------------------
#!/bin/bash
run_autoregressive_fr --dataset BLAT_ECOLX_1 --channels 48 --r-seed 11
--------------------------------------------------------------------------------
/linux_setup.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Example installation script of SeqDesign for Tensorflow-GPU from scratch
# Tested on Ubuntu 18.04 LTS, runtime ~5 minutes including a reboot.
# Miniconda and Tensorflow 1.12 are installed here, but a working Tensorflow 1 environment can substitute.
# Before running this script, first run `git clone -b v3 https://github.com/aaronkollasch/seqdesign-pytorch.git`
# and then `cd seqdesign-pytorch`
# If NVIDIA drivers have not been installed before, this script must be run twice, rebooting the system in between.

if [ ! -f "/proc/driver/nvidia/version" ]; then
    echo "NVIDIA driver not found; installing."
    sudo apt update
    sudo apt install -y --no-install-recommends nvidia-driver-430
    echo "
NVIDIA drivers installed.
Please reboot your system, then run linux_setup.sh a second time."
    exit
fi

# set up conda and the SeqDesign environment
if [ ! -d "$HOME/miniconda3" ]; then
    echo "miniconda3 not found; installing."
    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
    sh Miniconda3-latest-Linux-x86_64.sh -b -p "$HOME"/miniconda3
    rm Miniconda3-latest-Linux-x86_64.sh
fi
"$HOME"/miniconda3/bin/conda init
"$HOME"/miniconda3/bin/conda install mamba -n base -c conda-forge
"$HOME"/miniconda3/bin/mamba create -n seqdesign -y -c pytorch python=3.7 pip pytorch scipy scikit-learn gitpython pandas biopython pillow
"$HOME"/miniconda3/envs/seqdesign/bin/python -c "import torch; print(torch.cuda.is_available()); print([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])"  # test GPU install
"$HOME"/miniconda3/bin/conda install -y -n seqdesign "tensorflow>1.12,<2"  # necessary to read tensorflow model files
#"$HOME"/miniconda3/envs/seqdesign/bin/python -c "from tensorflow.python.client import device_lib; print(device_lib.list_local_devices())"  # test GPU install

# download SeqDesign code:
# git clone -b v3 https://github.com/aaronkollasch/seqdesign-pytorch.git
# cd seqdesign-pytorch || exit
"$HOME"/miniconda3/envs/seqdesign/bin/pip install .  # use setup.py develop if you want to modify the code files

# download demo/example data
if [ ! -f examples/datasets/sequences/BLAT_ECOLX_1_b0.5_lc_weights.fa ]; then
    echo "examples not found; downloading."
    cd examples || exit
    ./download_example_data.sh
fi

echo "
SeqDesign installed.
Run 'source ~/.bashrc; conda activate seqdesign' before using."

# # to run training demo:
# ./demo_train.sh

# # to run calc_logprobs using trained weights:
# ./demo_calc_logprobs.sh

# # to generate sequences:
# ./demo_generate.sh
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = [
    "setuptools>=42",
    "wheel",
]
build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch>=1.2
scipy
numpy
scikit-learn
gitpython
pandas
biopython
pillow
--------------------------------------------------------------------------------
/requirements_gpu.txt:
--------------------------------------------------------------------------------
torch>=1.2
scipy
numpy
scikit-learn
gitpython
pandas
biopython
pillow
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[metadata]
name = seqdesign-pt
version = 1.0.1
author = Aaron Kollasch
author_email = aaron@kollasch.dev
description = Protein design and variant prediction using autoregressive generative models
long_description = file: README.md
long_description_content_type = text/markdown; charset=UTF-8
url = https://github.com/aaronkollasch/seqdesign-pytorch
license = MIT
license_file = LICENSE
keywords = protein design, variant prediction, autoregressive, generative model, pytorch
platforms = any
classifiers =
    Development Status :: 5 - Production/Stable
    Intended Audience :: Science/Research
    Environment :: Console
    Topic :: Utilities
    License :: OSI Approved :: MIT License
    Programming Language :: Python :: 3
    Programming Language :: Python :: 3.6
    Programming Language :: Python :: 3.7
    Programming Language :: Python :: 3.8
    Programming Language :: Python :: 3.9

[options]
setup_requires =
install_requires =
    torch>=1.2
    scipy
    numpy
    scikit-learn
    gitpython
    pandas
    biopython
    pillow
python_requires = >=3.6
package_dir =
    = src
packages = find:
include_package_data = True

[options.packages.find]
where = src
exclude =
    tests

[options.entry_points]
console_scripts =
    calc_logprobs_seqs_fr = seqdesign_pt.scripts.calc_logprobs_seqs_fr:main
    generate_sample_seqs_fr = seqdesign_pt.scripts.generate_sample_seqs_fr:main
    run_autoregressive_fr = seqdesign_pt.scripts.run_autoregressive_fr:main
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools

setuptools.setup()
--------------------------------------------------------------------------------
/src/seqdesign_pt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aaronkollasch/seqdesign-pytorch/f9391b2f4689827d7d2993d00830fbe2f406676a/src/seqdesign_pt/__init__.py
--------------------------------------------------------------------------------
/src/seqdesign_pt/autoregressive_model.py:
--------------------------------------------------------------------------------
import math
import itertools

import torch
import torch.nn as nn
import torch.distributions as dist
import torch.nn.functional as F

from seqdesign_pt import layers
from seqdesign_pt.utils import recursive_update
from seqdesign_pt.functions import nonlinearity, comb_losses, clamp


class Autoregressive(nn.Module):
    """An autoregressive model

    """
    model_type = 'autoregressive'

    def __init__(
            self,
            dims=None,
            hyperparams=None,
            channels=None,
            r_seed=None,
            dropout_p=None,
    ):
        super(Autoregressive, self).__init__()

        self.dims = {
            "batch": 10,
            "alphabet": 21,
            "length": 256,
            "embedding_size": 1
        }
        if dims is not None:
            self.dims.update(dims)
        self.dims.setdefault('input', self.dims['alphabet'])

        self.hyperparams = {
            # For purely dilated conv network
            "encoder": {
                "channels": 48,
                "nonlinearity": "elu",
                "num_dilation_blocks": 6,
                "num_layers": 9,
                "dilation_schedule": None,
                "transformer": False,
                "inverse_temperature": False,
                "dropout_loc": "inter",  # options = "final", "inter", "gaussian"
                "dropout_p": 0.5,  # probability of zeroing out value, not the keep probability
                "dropout_type": "independent",
                "config": "original",  # options = "original", "updated", "standard"
            },
            "sampler_hyperparams": {
                'warm_up': 1,
                'annealing_type': 'linear',
                'anneal_kl': True,
                'anneal_noise': True
            },
            "embedding_hyperparams": {
                'warm_up': 1,
                'annealing_type': 'linear',
                'anneal_kl': True,
                'anneal_noise': False
            },
            "random_seed": 42,
            "optimization": {
                "l2_regularization": True,
                "bayesian": False,  # TODO implement bayesian
                "l2_lambda": 1.,
                "bayesian_logits": False,
                "mle_logits": False,
            }
        }
        if hyperparams is not None:
            recursive_update(self.hyperparams, hyperparams)
        if channels is not None:
            self.hyperparams['encoder']['channels'] = channels
        if dropout_p is not None:
            self.hyperparams['encoder']['dropout_p'] = dropout_p
        if r_seed is not None:
            self.hyperparams['random_seed'] = r_seed

        # initialize encoder modules
        enc_params = self.hyperparams['encoder']
        nonlin = nonlinearity(enc_params['nonlinearity'])

        self.start_conv = layers.Conv2d(
            self.dims['input'],
            enc_params['channels'],
            kernel_width=(1, 1),
            activation=None,
        )

        self.dilation_blocks = nn.ModuleList()
        for block in range(enc_params['num_dilation_blocks']):
            self.dilation_blocks.append(layers.ConvNet1D(
                channels=enc_params['channels'],
                layers=enc_params['num_layers'],
                dropout_p=enc_params['dropout_p'],
                dropout_type=enc_params['dropout_type'],
                causal=True,
                config=enc_params['config'],
                dilation_schedule=enc_params['dilation_schedule'],
                transpose=False,
                nonlinearity=nonlin,
            ))

        if enc_params['dropout_loc'] == "final":
            self.final_dropout = nn.Dropout(max(enc_params['dropout_p']-0.3, 0.))
        else:
            self.register_parameter('final_dropout', None)

        self.end_conv = layers.Conv2d(
            enc_params['channels'],
            self.dims['alphabet'],
            kernel_width=(1, 1),
            g_init=0.1,
            activation=None,
        )

        self.step = 0
        self.image_summaries = {}

        self.generating = False
        self.generating_reset = True

    @staticmethod
    def _log_gaussian(z, prior_mu, prior_sigma):
        prior = dist.Normal(prior_mu, prior_sigma)
        return prior.log_prob(z)

    def _kl_mixture_gaussians(
            self, z, log_sigma, p=0.1,
            mu_one=0., mu_two=0., sigma_one=1., sigma_two=1.
    ):
        gauss_one = self._log_gaussian(z, mu_one, sigma_one)
        gauss_two = self._log_gaussian(z, mu_two, sigma_two)
        entropy = 0.5 * torch.log(2.0 * math.pi * math.e * torch.exp(2. * log_sigma))
        return (p * gauss_one) + ((1. - p) * gauss_two) + entropy

    def _mle_mixture_gaussians(
            self, z, p=0.1,
            mu_one=0., mu_two=0., sigma_one=1., sigma_two=1.
    ):
        gauss_one = self._log_gaussian(z, mu_one, sigma_one)
        gauss_two = self._log_gaussian(z, mu_two, sigma_two)
        return (p * gauss_one) + ((1. - p) * gauss_two)

    def generate(self, mode=True):
        self.generating = mode
        self.generating_reset = True
        for module in itertools.chain(self.children(), self.dilation_blocks):
            if hasattr(module, "generate") and callable(module.generate):
                module.generate(mode)
        return self

    def weight_costs(self):
        return (
            self.start_conv.weight_costs() +
            tuple(cost for layer in self.dilation_blocks for cost in layer.weight_costs()) +
            self.end_conv.weight_costs()
        )

    def parameter_count(self):
        return sum(param.numel() for param in self.parameters())

    def forward(self, inputs, input_masks):
        """
        :param inputs: (N, C_in, 1, L)
        :param input_masks: (N, 1, 1, L)
        :return:
        """
        enc_params = self.hyperparams['encoder']

        if self.generating:
            if self.generating_reset:
                self.generating_reset = False
            else:
                inputs = inputs[:, :, :, -1:]

        up_val_1d = self.start_conv(inputs)
        for convnet in self.dilation_blocks:
            up_val_1d = convnet(up_val_1d, input_masks)
        self.image_summaries['LayerFeatures'] = dict(img=up_val_1d.permute(0, 1, 3, 2).detach(), max_outputs=3)
        if enc_params['dropout_loc'] == "final":
            up_val_1d = self.final_dropout(up_val_1d)
        up_val_1d = self.end_conv(up_val_1d)
        return up_val_1d

    @staticmethod
    def reconstruction_loss(seq_logits, target_seqs, mask):
        seq_reconstruct = F.log_softmax(seq_logits, 1)
        # cross_entropy = F.cross_entropy(seq_logits, target_seqs.argmax(1), reduction='none')
        cross_entropy = F.nll_loss(seq_reconstruct, target_seqs.argmax(1), reduction='none')
        cross_entropy = cross_entropy * mask.squeeze(1)
        reconstruction_per_seq = cross_entropy.sum([1, 2])
        bitperchar_per_seq = reconstruction_per_seq / mask.sum([1, 2, 3])
        reconstruction_loss = reconstruction_per_seq.mean()
        bitperchar = bitperchar_per_seq.mean()
        return {
            'seq_reconstruct_per_char': seq_reconstruct,
            'ce_loss': reconstruction_loss,
            'ce_loss_per_seq': reconstruction_per_seq,
            'ce_loss_per_char': cross_entropy.squeeze(1),
            'bitperchar': bitperchar,
            'bitperchar_per_seq': bitperchar_per_seq
        }

    def calculate_loss(
            self, seq_logits, target_seqs, mask, n_eff
    ):
        """

        :param seq_logits: (N, C, 1, L)
        :param target_seqs: (N, C, 1, L) as one-hot
        :param mask: (N, 1, 1, L)
        :param n_eff:
        :return:
        """
        hyperparams = self.hyperparams

        # cross-entropy
        reconstruction_loss = self.reconstruction_loss(
            seq_logits, target_seqs, mask
        )
        loss_per_seq = reconstruction_loss['ce_loss_per_seq']
        loss = reconstruction_loss['ce_loss']

        if hyperparams["optimization"]["l2_regularization"] or hyperparams["optimization"]["bayesian"]:

            # regularization
            weight_cost = torch.stack(self.weight_costs()).sum() / n_eff
            weight_cost = weight_cost * hyperparams["optimization"]["l2_lambda"]
            kl_weight_loss = weight_cost
            kl_loss = weight_cost

            # merge losses
            loss = loss + weight_cost

            # KL loss
            if hyperparams["optimization"]["bayesian_logits"] or hyperparams["optimization"]["mle_logits"]:
                if self.hyperparams["optimization"]["mle_logits"]:
                    kl_logits = - self._mle_mixture_gaussians(
                        seq_logits, p=.6, mu_one=0., mu_two=0., sigma_one=1.25, sigma_two=3.
                    )
                else:
                    kl_logits = None

                kl_logits = kl_logits * mask
                kl_logits_per_seq = kl_logits.sum([1, 2, 3])
                loss_per_seq = loss_per_seq + kl_logits_per_seq
                kl_logits_loss = kl_logits_per_seq.mean()
                kl_loss += kl_logits_loss
                kl_embedding_loss = kl_logits_loss
                loss = loss + self._anneal_embedding(self.step) * kl_logits_loss
            else:
                kl_embedding_loss = kl_weight_loss

        else:
            weight_cost = None
            kl_embedding_loss = torch.zeros([])
            kl_loss = torch.zeros([])

        seq_reconstruct = reconstruction_loss.pop('seq_reconstruct_per_char')
        self.image_summaries['SeqReconstruct'] = dict(img=seq_reconstruct.permute(0, 1, 3, 2).detach(), max_outputs=3)
        self.image_summaries['SeqTarget'] = dict(img=target_seqs.permute(0, 1, 3, 2).detach(), max_outputs=3)
        self.image_summaries['SeqDelta'] = dict(
            img=(seq_reconstruct - target_seqs).permute(0, 1, 3, 2).detach(), max_outputs=3)

        output = {
            'loss': loss,
            'ce_loss': None,
            'bitperchar': None,
            'loss_per_seq': loss_per_seq,
            'bitperchar_per_seq': None,
            'ce_loss_per_seq': None,
            'kl_embedding_loss': kl_embedding_loss,
            'kl_loss': kl_loss,
            'weight_cost': weight_cost,
        }
        output.update(reconstruction_loss)
        return output


class AutoregressiveFR(nn.Module):
    sub_model_class = Autoregressive
    model_type = 'autoregressive_fr'

    def __init__(
            self,
            **kwargs
    ):
        super(AutoregressiveFR, self).__init__()
        self.model = nn.ModuleDict({
            'model_f': self.sub_model_class(**kwargs),
            'model_r': self.sub_model_class(**kwargs)
        })
        self.dims = self.model.model_f.dims
        self.hyperparams = self.model.model_f.hyperparams

        # make dictionaries the same in memory
        self.model.model_r.dims = self.model.model_f.dims
        self.model.model_r.hyperparams = self.model.model_f.hyperparams

    @property
    def step(self):
        return self.model.model_f.step

    @step.setter
    def step(self, new_step):
        self.model.model_f.step = new_step
        self.model.model_r.step = new_step

    @property
    def image_summaries(self):
        img_summaries_f = self.model.model_f.image_summaries
        img_summaries_r = self.model.model_r.image_summaries
        img_summaries = {}
        for key in img_summaries_f.keys():
            img_summaries[key + '_f'] = img_summaries_f[key]
            img_summaries[key + '_r'] = img_summaries_r[key]
        return img_summaries

    def generate(self, mode=True):
        for module in self.model.children():
            if hasattr(module, "generate") and callable(module.generate):
                module.generate(mode)
        return self

    def weight_costs(self):
        return tuple(cost for model in self.model.children() for cost in model.weight_costs())

    def parameter_count(self):
        return sum(model.parameter_count() for model in self.model.children())

    def forward(self, input_f, mask_f, input_r, mask_r):
        output_logits_f = self.model.model_f(input_f, mask_f)
        output_logits_r = self.model.model_r(input_r, mask_r)
        return output_logits_f, output_logits_r

    def reconstruction_loss(
            self,
            seq_logits_f, target_seqs_f, mask_f,
            seq_logits_r, target_seqs_r, mask_r,
    ):
        losses_f = self.model.model_f.reconstruction_loss(
            seq_logits_f, target_seqs_f, mask_f
        )
        losses_r = self.model.model_r.reconstruction_loss(
            seq_logits_r, target_seqs_r, mask_r
        )
        return comb_losses(losses_f, losses_r)

    def calculate_loss(self, *args):
        losses_f = self.model.model_f.calculate_loss(*args[:len(args)//2])
        losses_r = self.model.model_r.calculate_loss(*args[len(args)//2:])
        return comb_losses(losses_f, losses_r)


class AutoregressiveVAE(nn.Module):
    """An autoregressive variational autoencoder

    """
    model_type = 'autoregressive_vae'

    def __init__(
            self,
            dims=None,
            hyperparams=None,
            channels=None,
            r_seed=None,
            dropout_p=None,
    ):
        super(AutoregressiveVAE, self).__init__()

        self.dims = {
            "batch": 10,
            "alphabet": 21,
            "length": 256,
            "embedding_size": 1
        }
        if dims is not None:
            self.dims.update(dims)
        self.dims.setdefault('input', self.dims['alphabet'])

        self.hyperparams = {
            "encoder": {
                "channels": 48,
                "nonlinearity": "elu",
                "num_dilation_blocks": 3,
                "num_layers": 9,
                "dilation_schedule": None,
                "transformer": False,
                "inverse_temperature": False,
                "embedding_nnet_nonlinearity": "elu",
                "embedding_nnet_size": 200,
                "latent_size": 30,
                "dropout_p": 0.1,
                "dropout_type": "2D",
                "config": "updated",
            },
            "decoder": {
                "channels": 48,
                "nonlinearity": "elu",
                "num_dilation_blocks": 3,
                "num_layers": 9,
                "dilation_schedule": None,
                "transformer": False,
                "inverse_temperature": False,
                "positional_embedding": True,
                "skip_connections": False,
                "pos_emb_max_len": 400,
                "pos_emb_step": 5,
                "config": "updated",
                "dropout_type": "2D",
                "dropout_p": 0.5,
            },
            "sampler_hyperparams": {
                'warm_up': 10000,
                'annealing_type': 'linear',
                'anneal_kl': True,
                'anneal_noise': True
            },
            "embedding_hyperparams": {
                'warm_up': 10000,
                'annealing_type': 'piecewise_linear',
                'anneal_kl': True,
                'anneal_noise': True
            },
            "random_seed": 42,
            "optimization": {
                "l2_regularization": True,
                "bayesian": True,
                "l2_lambda": 1.,
                "bayesian_logits": False,
                "mle_logits": False,
            }
        }
        if hyperparams is not None:
            recursive_update(self.hyperparams, hyperparams)
        if channels is not None:
            self.hyperparams['encoder']['channels'] = channels
        if dropout_p is not None:
            self.hyperparams['decoder']['dropout_p'] = dropout_p
        if r_seed is not None:
            self.hyperparams['random_seed'] = r_seed

        # initialize encoder modules
        enc_params = self.hyperparams['encoder']
        nonlin = nonlinearity(enc_params['nonlinearity'])

        self.encoder = nn.ModuleDict()
        self.encoder.start_conv = layers.Conv2d(
            self.dims['input'],
            enc_params['channels'],
            kernel_width=(1, 1),
            activation=nonlin,
        )

        self.encoder.dilation_blocks = nn.ModuleList()
        for block in range(enc_params['num_dilation_blocks']):
            self.encoder.dilation_blocks.append(layers.ConvNet1D(
                channels=enc_params['channels'],
                layers=enc_params['num_layers'],
                dropout_p=enc_params['dropout_p'],
                dropout_type=enc_params['dropout_type'],
                causal=False,
                config=enc_params['config'],
                dilation_schedule=enc_params['dilation_schedule'],
                transpose=False,
                nonlinearity=nonlin,
            ))

        self.encoder.emb_mu_one = nn.Linear(enc_params['channels'], enc_params['embedding_nnet_size'])
        self.encoder.emb_log_sigma_one = nn.Linear(enc_params['channels'], enc_params['embedding_nnet_size'])
        self.encoder.emb_mu_out = nn.Linear(enc_params['embedding_nnet_size'], enc_params['latent_size'])
        self.encoder.emb_log_sigma_out = nn.Linear(enc_params['embedding_nnet_size'], enc_params['latent_size'])

        # initialize decoder modules
        dec_params = self.hyperparams['decoder']
        nonlin = nonlinearity(dec_params['nonlinearity'])

        if dec_params['positional_embedding']:
            max_len = dec_params['pos_emb_max_len']
            step = dec_params['pos_emb_step']
            rbf_locations = torch.arange(step, max_len+1, step, dtype=torch.float32)
            rbf_locations = rbf_locations.view(1, dec_params['pos_emb_max_len'] // dec_params['pos_emb_step'], 1, 1)
            self.register_buffer('rbf_locations', rbf_locations)
        else:
            self.register_buffer('rbf_locations', None)

        self.decoder = nn.ModuleDict()
        self.decoder.start_conv = layers.Conv2d(
            (
                self.dims['input'] +
                (
                    dec_params['pos_emb_max_len'] // dec_params['pos_emb_step']
                    if dec_params['positional_embedding'] else 0
                ) +
                enc_params['latent_size']
            ),
            dec_params['channels'],
            kernel_width=(1, 1),
            activation=nonlin,
        )

        self.decoder.dilation_blocks = nn.ModuleList()
        for block in range(dec_params['num_dilation_blocks']):
            self.decoder.dilation_blocks.append(layers.ConvNet1D(
                channels=dec_params['channels'],
                layers=dec_params['num_layers'],
                add_input_channels=enc_params['channels'] if dec_params['skip_connections'] else 0,
                add_input_layer='all' if dec_params['skip_connections'] else None,
                dropout_p=dec_params['dropout_p'],
                dropout_type=dec_params['dropout_type'],
                causal=True,
                config=dec_params['config'],
                dilation_schedule=dec_params['dilation_schedule'],
                transpose=False,
                nonlinearity=nonlin,
            ))

        self.decoder.end_conv = layers.Conv2d(
            dec_params['channels'],
            self.dims['alphabet'],
            kernel_width=(1, 1),
            g_init=0.1,
            activation=None,
        )

        self.step = 0
        self.forward_state = {'kl_embedding': None}
        self.image_summaries = {}
        self._enable_gradient = 'ed'

    @property
    def enable_gradient(self):
        return self._enable_gradient

    @enable_gradient.setter
    def enable_gradient(self, value):
        if self._enable_gradient == value:
            return
        self._enable_gradient = value
for p in self.encoder.parameters(): 548 | p.requires_grad = 'e' in value 549 | if 'e' not in value: 550 | # p.grad = None 551 | if p.grad is not None: 552 | p.grad.detach_() 553 | p.grad.zero_() 554 | for p in self.decoder.parameters(): 555 | p.requires_grad = 'd' in value 556 | if 'd' not in value: 557 | # p.grad = None 558 | if p.grad is not None: 559 | p.grad.detach_() 560 | p.grad.zero_() 561 | 562 | @staticmethod 563 | def _kl_standard_normal(mu, log_sigma): 564 | """ Per-dimension KL divergence between the diagonal Gaussian N(mu, exp(log_sigma)^2) and the standard normal N(0, I) """ 565 | return 0.5 * (mu.pow(2) + (2.0 * log_sigma).exp() - 2.0 * log_sigma - 1) 566 | # return dist.kl_divergence(dist.Normal(mu, log_sigma.exp()), dist.Normal(0., 1.)) 567 | 568 | @staticmethod 569 | def _log_gaussian(z, prior_mu, prior_sigma): 570 | prior = dist.Normal(prior_mu, prior_sigma) 571 | return prior.log_prob(z) 572 | 573 | def _kl_mixture_gaussians( 574 | self, z, log_sigma, p=0.1, 575 | mu_one=0., mu_two=0., sigma_one=1., sigma_two=1. 576 | ): 577 | gauss_one = self._log_gaussian(z, mu_one, sigma_one) 578 | gauss_two = self._log_gaussian(z, mu_two, sigma_two) 579 | entropy = 0.5 * torch.log(2.0 * math.pi * math.e * torch.exp(2. * log_sigma)) 580 | return (p * gauss_one) + ((1. - p) * gauss_two) + entropy 581 | 582 | def _mle_mixture_gaussians( 583 | self, z, p=0.1, 584 | mu_one=0., mu_two=0., sigma_one=1., sigma_two=1. 585 | ): 586 | gauss_one = self._log_gaussian(z, mu_one, sigma_one) 587 | gauss_two = self._log_gaussian(z, mu_two, sigma_two) 588 | return (p * gauss_one) + ((1. - p) * gauss_two) 589 | 590 | def _anneal(self, step): 591 | warm_up = self.hyperparams["sampler_hyperparams"]["warm_up"] 592 | annealing_type = self.hyperparams["sampler_hyperparams"]["annealing_type"] 593 | if annealing_type == "linear": 594 | return min(step / warm_up, 1.) 595 | elif annealing_type == "piecewise_linear": 596 | return clamp(torch.sigmoid(torch.tensor(step-warm_up).float()).item() * ((step-warm_up)/warm_up), 0., 1.)
597 | elif annealing_type == "sigmoid": 598 | slope = self.hyperparams["sampler_hyperparams"]["sigmoid_slope"] 599 | return torch.sigmoid(torch.tensor(slope * (step - warm_up))).item() 600 | 601 | def _anneal_embedding(self, step): 602 | warm_up = self.hyperparams["embedding_hyperparams"]["warm_up"] 603 | annealing_type = self.hyperparams["embedding_hyperparams"]["annealing_type"] 604 | if annealing_type == "linear": 605 | return min(step / warm_up, 1.) 606 | elif annealing_type == "piecewise_linear": 607 | return clamp(torch.sigmoid(torch.tensor(step-warm_up).float()).item() * ((step-warm_up)/warm_up), 0., 1.) 608 | elif annealing_type == "sigmoid": 609 | slope = self.hyperparams["embedding_hyperparams"]["sigmoid_slope"] 610 | return torch.sigmoid(torch.tensor(slope * (step-warm_up))).item() 611 | 612 | def sampler(self, mu, log_sigma, stddev=1.): 613 | if self.hyperparams["embedding_hyperparams"]["anneal_noise"]: 614 | stddev = self._anneal_embedding(self.step) 615 | # return dist.Normal(mu, log_sigma.exp() * stddev).rsample() 616 | eps = torch.zeros_like(log_sigma).normal_(std=stddev) 617 | return mu + log_sigma.exp() * eps 618 | 619 | def generate(self, mode=True): 620 | for module in self.decoder.dilation_blocks: 621 | if hasattr(module, "generate") and callable(module.generate): 622 | module.generate(mode) 623 | return self 624 | 625 | def weight_costs(self): 626 | return ( 627 | self.decoder.start_conv.weight_costs() + 628 | tuple(cost for layer in self.decoder.dilation_blocks for cost in layer.weight_costs()) + 629 | self.decoder.end_conv.weight_costs() 630 | ) 631 | 632 | def parameter_count(self): 633 | return sum(param.numel() for param in self.parameters()) 634 | 635 | def encode(self, inputs, input_masks): 636 | enc_params = self.hyperparams['encoder'] 637 | nonlin = nonlinearity(enc_params['embedding_nnet_nonlinearity']) 638 | 639 | up_val_1d = self.encoder.start_conv(inputs) 640 | for convnet in self.encoder.dilation_blocks: 641 | up_val_1d =
convnet(up_val_1d, input_masks) 642 | 643 | up_val_1d = up_val_1d * input_masks 644 | up_val_mu_logsigma_2d = up_val_1d.sum(dim=[2, 3]) / input_masks.sum(dim=[2, 3]) 645 | 646 | up_val_mu_2d = nonlin(self.encoder.emb_mu_one(up_val_mu_logsigma_2d)) 647 | up_val_log_sigma_2d = nonlin(self.encoder.emb_log_sigma_one(up_val_mu_logsigma_2d)) 648 | 649 | mu_2d = self.encoder.emb_mu_out(up_val_mu_2d) 650 | log_sigma_2d = self.encoder.emb_log_sigma_out(up_val_log_sigma_2d) 651 | 652 | self.image_summaries['mu'] = dict( 653 | img=mu_2d.unsqueeze(-1).unsqueeze(-1).permute(2, 1, 0, 3).detach(), max_outputs=1) 654 | self.image_summaries['log_sigma'] = dict( 655 | img=log_sigma_2d.unsqueeze(-1).unsqueeze(-1).permute(2, 1, 0, 3).detach(), max_outputs=1) 656 | 657 | return mu_2d, log_sigma_2d 658 | 659 | def decode(self, inputs, input_masks, z): 660 | dec_params = self.hyperparams['decoder'] 661 | 662 | z = z.unsqueeze(-1).unsqueeze(-1).expand((-1, -1, 1, inputs.size(3))) 663 | if dec_params['positional_embedding']: 664 | number_range = torch.arange(0, inputs.size(3), dtype=inputs.dtype, device=inputs.device) 665 | number_range = number_range.view((1, 1, 1, inputs.size(3))) 666 | pos_embed = torch.exp(-0.5 * (number_range - self.rbf_locations).pow(2)) 667 | pos_embed = pos_embed.expand((inputs.size(0), -1, 1, -1)) 668 | else: 669 | pos_embed = torch.tensor([], dtype=inputs.dtype, device=inputs.device) 670 | input_1d = torch.cat((inputs, z, pos_embed), 1) 671 | 672 | up_val_1d = self.decoder.start_conv(input_1d) 673 | for convnet in self.decoder.dilation_blocks: 674 | up_val_1d = convnet(up_val_1d, input_masks, additional_input=z) 675 | self.image_summaries['LayerFeatures'] = dict(img=up_val_1d.permute(0, 1, 3, 2).detach(), max_outputs=3) 676 | 677 | up_val_1d = self.decoder.end_conv(up_val_1d) 678 | return up_val_1d 679 | 680 | def forward(self, inputs, input_masks): 681 | """ 682 | :param inputs: (N, C_in, 1, L) 683 | :param input_masks: (N, 1, 1, L) 684 | :return: up_val_1d: (N, 
C_out, 1, L), kl_embedding: (N, C_emb) 685 | """ 686 | mu_2d, log_sigma_2d = self.encode(inputs, input_masks) 687 | 688 | kl_embedding = self._kl_standard_normal(mu_2d, log_sigma_2d) 689 | z_2d = self.sampler(mu_2d, log_sigma_2d) 690 | self.image_summaries['z'] = dict( 691 | img=z_2d.unsqueeze(-1).unsqueeze(-1).permute(2, 1, 0, 3).detach(), max_outputs=1) 692 | 693 | up_val_1d = self.decode(inputs, input_masks, z_2d) 694 | 695 | return up_val_1d, kl_embedding 696 | 697 | @staticmethod 698 | def reconstruction_loss(seq_logits, target_seqs, mask): 699 | """ 700 | :param seq_logits: (N, C, 1, L) 701 | :param target_seqs: (N, C, 1, L) as one-hot 702 | :param mask: (N, 1, 1, L) 703 | """ 704 | seq_reconstruct = F.log_softmax(seq_logits, 1) 705 | cross_entropy = F.nll_loss(seq_reconstruct, target_seqs.argmax(1), reduction='none') 706 | cross_entropy = cross_entropy * mask.squeeze(1) 707 | reconstruction_per_seq = cross_entropy.sum([1, 2]) 708 | bitperchar_per_seq = reconstruction_per_seq / mask.sum([1, 2, 3]) 709 | reconstruction_loss = reconstruction_per_seq.mean() 710 | bitperchar = bitperchar_per_seq.mean() 711 | return { 712 | 'seq_reconstruct_per_char': seq_reconstruct, 713 | 'ce_loss': reconstruction_loss, 714 | 'ce_loss_per_seq': reconstruction_per_seq, 715 | 'ce_loss_per_char': cross_entropy.squeeze(1), 716 | 'bitperchar': bitperchar, 717 | 'bitperchar_per_seq': bitperchar_per_seq 718 | } 719 | 720 | def calculate_loss( 721 | self, seq_logits, kl_embedding, target_seqs, mask, n_eff 722 | ): 723 | """ 724 | 725 | :param seq_logits: (N, C, 1, L) 726 | :param kl_embedding: (N, C) 727 | :param target_seqs: (N, C, 1, L) as one-hot 728 | :param mask: (N, 1, 1, L) 729 | :param n_eff: 730 | :return: dict 731 | """ 732 | hyperparams = self.hyperparams 733 | 734 | # cross-entropy 735 | reconstruction_loss = self.reconstruction_loss( 736 | seq_logits, target_seqs, mask 737 | ) 738 | 739 | # regularization 740 | weight_cost = torch.stack(self.weight_costs()).sum() / n_eff 
741 | weight_cost = weight_cost * hyperparams["optimization"]["l2_lambda"] 742 | kl_weight_loss = weight_cost 743 | kl_loss = weight_cost 744 | 745 | # embedding calculation 746 | embed_cost_per_seq = kl_embedding.sum(1) 747 | kl_embedding_loss = embed_cost_per_seq.mean() 748 | 749 | # merge losses 750 | loss_per_seq = reconstruction_loss['ce_loss_per_seq'] + embed_cost_per_seq * self._anneal_embedding(self.step) 751 | loss = reconstruction_loss['ce_loss'] + weight_cost + kl_embedding_loss * self._anneal_embedding(self.step) 752 | 753 | # KL loss 754 | if hyperparams["optimization"]["bayesian_logits"] or hyperparams["optimization"]["mle_logits"]: 755 | if self.hyperparams["optimization"]["mle_logits"]: 756 | kl_logits = - self._mle_mixture_gaussians( 757 | seq_logits, p=.6, mu_one=0., mu_two=0., sigma_one=1.25, sigma_two=3. 758 | ) 759 | else: 760 | raise NotImplementedError("bayesian_logits is not implemented; use mle_logits") 761 | 762 | kl_logits = kl_logits * mask 763 | kl_logits_per_seq = kl_logits.sum([1, 2, 3]) 764 | loss_per_seq = loss_per_seq + kl_logits_per_seq 765 | kl_logits_loss = kl_logits_per_seq.mean() 766 | kl_loss = kl_loss + kl_logits_loss 767 | loss = loss + self._anneal_embedding(self.step) * kl_logits_loss 768 | 769 | seq_reconstruct = reconstruction_loss.pop('seq_reconstruct_per_char') 770 | self.image_summaries['SeqReconstruct'] = dict(img=seq_reconstruct.permute(0, 1, 3, 2).detach(), max_outputs=3) 771 | self.image_summaries['SeqTarget'] = dict(img=target_seqs.permute(0, 1, 3, 2).detach(), max_outputs=3) 772 | self.image_summaries['SeqDelta'] = dict( 773 | img=(seq_reconstruct - target_seqs).permute(0, 1, 3, 2).detach(), max_outputs=3) 774 | 775 | output = { 776 | 'loss': loss, 777 | 'ce_loss': None, 778 | 'bitperchar': None, 779 | 'loss_per_seq': loss_per_seq, 780 | 'bitperchar_per_seq': None, 781 | 'ce_loss_per_seq': None, 782 | 'kl_embedding_loss': kl_embedding_loss, 783 | 'kl_loss': kl_loss, 784 | 'weight_cost': weight_cost, 785 | } 786 | output.update(reconstruction_loss) 787 | return output 788 |
def calc_mi(self, x, x_mask): 790 | """Approximate the mutual information between x and z 791 | I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z)) 792 | 793 | Adapted from https://github.com/jxhe/vae-lagging-encoder 794 | return: Float 795 | """ 796 | # [x_batch, nz] 797 | mu, logstd = self.encode(x, x_mask) 798 | logvar = logstd * 2.0 799 | 800 | x_batch, nz = mu.size() 801 | 802 | # E_{q(z|x)}log(q(z|x)) = -0.5*nz*log(2*\pi) - 0.5*(1+logvar).sum(-1) 803 | neg_entropy = (-0.5 * nz * math.log(2 * math.pi) - 0.5 * (1 + logvar).sum(-1)).mean() 804 | 805 | # [z_batch, 1, nz] 806 | z_samples = dist.Normal(mu, logstd.exp()).rsample().unsqueeze(1) 807 | 808 | # [1, x_batch, nz] 809 | mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0) 810 | var = logvar.exp() 811 | 812 | # (z_batch, x_batch, nz) 813 | dev = z_samples - mu 814 | 815 | # (z_batch, x_batch) 816 | log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \ 817 | 0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1)) 818 | 819 | # log q(z): aggregate posterior 820 | # [z_batch] 821 | log_qz = torch.logsumexp(log_density, dim=1) - math.log(x_batch) 822 | 823 | return (neg_entropy - log_qz.mean(-1)).item() 824 | 825 | 826 | class AutoregressiveVAEFR(AutoregressiveFR): 827 | sub_model_class = AutoregressiveVAE 828 | model_type = 'autoregressive_vae_fr' 829 | 830 | def __init__( 831 | self, 832 | **kwargs 833 | ): 834 | super(AutoregressiveVAEFR, self).__init__(**kwargs) 835 | 836 | @property 837 | def enable_gradient(self): 838 | return self.model.model_f.enable_gradient 839 | 840 | @enable_gradient.setter 841 | def enable_gradient(self, value): 842 | self.model.model_f.enable_gradient = value 843 | self.model.model_r.enable_gradient = value 844 | 845 | def encoder_parameters(self): 846 | return list(self.model.model_f.encoder.parameters()) + list(self.model.model_r.encoder.parameters()) 847 | 848 | def decoder_parameters(self): 849 | return list(self.model.model_f.decoder.parameters()) + 
list(self.model.model_r.decoder.parameters()) 850 | 851 | def encode(self, input_f, mask_f, input_r, mask_r): 852 | mu_f, log_sigma_f = self.model.model_f.encode(input_f, mask_f) 853 | mu_r, log_sigma_r = self.model.model_r.encode(input_r, mask_r) 854 | return mu_f, log_sigma_f, mu_r, log_sigma_r 855 | 856 | def decode(self, input_f, mask_f, z_f, input_r, mask_r, z_r): 857 | up_val_1d_f = self.model.model_f.decode(input_f, mask_f, z_f) 858 | up_val_1d_r = self.model.model_r.decode(input_r, mask_r, z_r) 859 | return up_val_1d_f, up_val_1d_r 860 | 861 | def forward(self, input_f, mask_f, input_r, mask_r): 862 | output_logits_f, kl_embedding_f = self.model.model_f(input_f, mask_f) 863 | output_logits_r, kl_embedding_r = self.model.model_r(input_r, mask_r) 864 | return output_logits_f, kl_embedding_f, output_logits_r, kl_embedding_r 865 | 866 | def calc_mi(self, input_f, mask_f, input_r, mask_r): 867 | mi_f = self.model.model_f.calc_mi(input_f, mask_f) 868 | mi_r = self.model.model_r.calc_mi(input_r, mask_r) 869 | return (mi_f + mi_r) / 2.
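The latent-variable and loss pieces of `AutoregressiveVAE` above can be sanity-checked in isolation. The block below is a minimal standalone sketch, not part of `seqdesign_pt`. The helper names `kl_standard_normal`, `reparameterize`, `linear_anneal` and `masked_loss_per_char` are hypothetical; they only restate the formulas in `_kl_standard_normal`, `sampler`, the `"linear"` branch of `_anneal_embedding`, and `reconstruction_loss`. The sketch assumes PyTorch is installed and checks the closed-form KL against `torch.distributions.kl_divergence`.

```python
import math

import torch
import torch.distributions as dist
import torch.nn.functional as F


def kl_standard_normal(mu, log_sigma):
    # Per-dimension KL(N(mu, sigma^2) || N(0, 1)); same closed form as _kl_standard_normal
    return 0.5 * (mu.pow(2) + (2.0 * log_sigma).exp() - 2.0 * log_sigma - 1)


def reparameterize(mu, log_sigma, stddev=1.0):
    # z = mu + sigma * eps with eps ~ N(0, stddev^2); gradients flow back to mu and log_sigma
    eps = torch.randn_like(log_sigma) * stddev
    return mu + log_sigma.exp() * eps


def linear_anneal(step, warm_up):
    # 'linear' annealing_type: weight ramps from 0 to 1 over warm_up steps, then stays at 1
    return min(step / warm_up, 1.0)


def masked_loss_per_char(logits, onehot, mask):
    # logits, onehot: (N, C, 1, L); mask: (N, 1, 1, L) with 1 for real positions, 0 for padding
    log_probs = F.log_softmax(logits, dim=1)  # natural log, so the result is in nats
    ce = F.nll_loss(log_probs, onehot.argmax(1), reduction='none')  # (N, 1, L)
    ce = ce * mask.squeeze(1)  # zero out padded positions
    return ce.sum([1, 2]) / mask.sum([1, 2, 3])  # mean over unmasked positions, per sequence


torch.manual_seed(0)

# Closed-form KL agrees with torch.distributions
mu = torch.randn(4, 30)  # (batch, latent_size)
log_sigma = 0.1 * torch.randn(4, 30)
reference = dist.kl_divergence(dist.Normal(mu, log_sigma.exp()), dist.Normal(0.0, 1.0))
print(torch.allclose(kl_standard_normal(mu, log_sigma), reference, atol=1e-5))

# The reparameterized sample is differentiable with respect to mu
mu_leaf = mu.clone().requires_grad_()
reparameterize(mu_leaf, log_sigma).sum().backward()
print(mu_leaf.grad is not None)

# Warm-up weights at a few steps
print([linear_anneal(s, 10000) for s in (0, 5000, 10000, 20000)])

# Uniform logits give log(C) nats per unmasked character, regardless of padding
N, C, L = 2, 21, 8
onehot = F.one_hot(torch.randint(0, C, (N, 1, L)), C).permute(0, 3, 1, 2).float()
mask = torch.ones(N, 1, 1, L)
mask[1, ..., 5:] = 0  # second sequence has only 5 real positions
nats = masked_loss_per_char(torch.zeros(N, C, 1, L), onehot, mask)
print(nats / math.log(2))  # bits per character
```

`reconstruction_loss` applies `F.log_softmax`, which uses the natural logarithm. Its `bitperchar` values are therefore per-character averages in nats. Dividing by ln 2, as in the last line, converts them to bits.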
870 | -------------------------------------------------------------------------------- /src/seqdesign_pt/autoregressive_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | import torch.utils.data 10 | 11 | from seqdesign_pt.model_logging import Logger 12 | 13 | 14 | class AutoregressiveTrainer: 15 | default_params = { 16 | 'optimizer': 'Adam', 17 | 'lr': 0.001, 18 | 'weight_decay': 0, 19 | 'clip': 100.0, 20 | 'snapshot_path': None, 21 | 'snapshot_name': 'snapshot', 22 | 'snapshot_interval': 1000, 23 | 'fr_separate_batches': True, 24 | } 25 | 26 | def __init__( 27 | self, 28 | model, 29 | data_loader, 30 | optimizer=None, 31 | params=None, 32 | lr=None, 33 | weight_decay=None, 34 | gradient_clipping=None, 35 | logger=Logger(), 36 | snapshot_path=None, 37 | snapshot_name=None, 38 | snapshot_interval=None, 39 | snapshot_exec_template=None, 40 | device=torch.device('cpu') 41 | ): 42 | self.params = self.default_params.copy() 43 | if params is not None: 44 | self.params.update(params) 45 | if optimizer is not None: 46 | self.params['optimizer'] = optimizer 47 | if lr is not None: 48 | self.params['lr'] = lr 49 | if weight_decay is not None: 50 | self.params['weight_decay'] = weight_decay 51 | if gradient_clipping is not None: 52 | self.params['clip'] = gradient_clipping 53 | if snapshot_path is not None: 54 | self.params['snapshot_path'] = snapshot_path 55 | if snapshot_name is not None: 56 | self.params['snapshot_name'] = snapshot_name 57 | if snapshot_interval is not None: 58 | self.params['snapshot_interval'] = snapshot_interval 59 | if snapshot_exec_template is not None: 60 | self.params['snapshot_exec_template'] = snapshot_exec_template 61 | 62 | self.model = model 63 | self.loader = data_loader 64 | 65 | self.run_fr = 'fr' in model.model_type 66 | 
self.optimizer_type = getattr(optim, self.params['optimizer']) 67 | self.logger = logger 68 | self.logger.trainer = self 69 | self.device = device 70 | 71 | self.optimizer = self.optimizer_type( 72 | params=self.model.parameters(), 73 | lr=self.params['lr'], weight_decay=self.params['weight_decay']) 74 | 75 | def train(self, steps=1e8): 76 | self.model.train() 77 | 78 | data_iter = iter(self.loader) 79 | n_eff = self.loader.dataset.n_eff 80 | 81 | # print(' step step-t load-t loss CE-loss bitperchar l2-norm', flush=True) 82 | for step in range(int(self.model.step) + 1, int(steps) + 1): 83 | self.model.step = step 84 | start = time.time() 85 | 86 | if self.run_fr: 87 | if self.params['fr_separate_batches']: 88 | batch_f = next(data_iter) 89 | batch_r = next(data_iter) 90 | batch = { 91 | 'decoder_input': batch_f['decoder_input'], 'decoder_output': batch_f['decoder_output'], 92 | 'decoder_mask': batch_f['decoder_mask'], 93 | 'decoder_input_r': batch_r['decoder_input_r'], 'decoder_output_r': batch_r['decoder_output_r'], 94 | 'decoder_mask_r': batch_r['decoder_mask'], 95 | } 96 | else: 97 | batch = next(data_iter) 98 | batch['decoder_mask_r'] = batch['decoder_mask'] 99 | for key in batch.keys(): 100 | if isinstance(batch[key], torch.Tensor): 101 | batch[key] = batch[key].to(self.device, non_blocking=True) 102 | data_load_time = time.time() - start 103 | 104 | output_logits_f, output_logits_r = self.model( 105 | batch['decoder_input'], batch['decoder_mask'], 106 | batch['decoder_input_r'], batch['decoder_mask_r']) 107 | losses = self.model.calculate_loss( 108 | output_logits_f, batch['decoder_output'], batch['decoder_mask'], n_eff, 109 | output_logits_r, batch['decoder_output_r'], batch['decoder_mask_r'], n_eff) 110 | else: 111 | batch = next(data_iter) 112 | for key in batch.keys(): 113 | if isinstance(batch[key], torch.Tensor): 114 | batch[key] = batch[key].to(self.device, non_blocking=True) 115 | data_load_time = time.time() - start 116 | 117 | output_logits_f = 
self.model(batch['decoder_input'], batch['decoder_mask']) 118 | losses = self.model.calculate_loss( 119 | output_logits_f, batch['decoder_output'], batch['decoder_mask'], n_eff) 120 | 121 | if step in [1, 10, 100, 1000, 10000, 100000]: 122 | try: 123 | print(f'step {step:6d}: ' 124 | f'GPU Mem Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB, ' 125 | f'Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB', 126 | flush=True) 127 | except (AttributeError, RuntimeError): 128 | pass 129 | 130 | self.optimizer.zero_grad() 131 | losses['loss'].backward() 132 | 133 | if self.params['clip'] is not None: 134 | total_norm = nn.utils.clip_grad_norm_(self.model.parameters(), self.params['clip']) 135 | else: 136 | total_norm = 0.0 137 | 138 | nan_check = False 139 | if nan_check and ( 140 | any(torch.isnan(param).any() for param in self.model.parameters()) or 141 | any(torch.isnan(param.grad).any() for param in self.model.parameters()) 142 | ): 143 | print("nan detected:") 144 | print("{: 8d} {:6.3f} {:5.4f} {:11.6f} {:11.6f} {:11.8f} {:10.6f}".format( 145 | step, time.time() - start, data_load_time, 146 | losses['loss'].detach(), losses['ce_loss'].detach(), 147 | losses['bitperchar'].detach(), losses['kl_embedding_loss'].detach())) 148 | print('grad norm', total_norm) 149 | print('params', [name for name, param in self.model.named_parameters() if torch.isnan(param).any()]) 150 | print('grads', [name for name, param in self.model.named_parameters() if torch.isnan(param.grad).any()]) 151 | self.save_state(last_batch=batch) 152 | break 153 | 154 | self.optimizer.step() 155 | 156 | for key in losses: 157 | losses[key] = losses[key].detach() 158 | if self.run_fr and 'per_seq' not in key and '_f' not in key and '_r' not in key: 159 | losses[key] /= 2 160 | losses.update({'grad_norm': total_norm}) 161 | 162 | if step in [1000, 2500, 5000, 10000, 25000, 50000, 100000, 250000, 500000, 1000000, 2500000] or \ 163 | step % self.params['snapshot_interval'] == 
0: 164 | # save a snapshot only when a path is configured; logging below must still run 165 | if self.params['snapshot_path'] is not None: 166 | self.save_state() 167 | 168 | self.logger.log(step, losses, total_norm, load_time=data_load_time) 169 | # print("{: 8d} {:6.3f} {:5.4f} {:11.6f} {:11.6f} {:11.8f} {:10.6f}".format( 170 | # step, time.time()-start, data_load_time, 171 | # losses['loss'], losses['ce_loss'], losses['bitperchar'], losses['kl_embedding_loss']), flush=True) 172 | 173 | def validate(self, batch_size=48): 174 | return 0.0, 0.0 175 | # self.model.eval() 176 | # with torch.no_grad(): 177 | # ( 178 | # prot_decoder_input_f, prot_decoder_output_f, prot_mask_decoder, 179 | # prot_decoder_input_r, prot_decoder_output_r, 180 | # n_eff 181 | # ) = self.loader.dataset.generate_test_data(self, batch_size, matching=True) # TODO write generate_test_data 182 | # if self.run_fr: 183 | # output_logits_f, output_logits_r = self.model( 184 | # prot_decoder_input_f, prot_mask_decoder, prot_decoder_input_r, prot_mask_decoder) 185 | # output_logits = torch.cat((output_logits_f, output_logits_r), dim=0) 186 | # target_seqs = torch.cat((prot_decoder_output_f, prot_decoder_output_r), dim=0) 187 | # mask = torch.cat((prot_mask_decoder, prot_mask_decoder), dim=0) 188 | # else: 189 | # output_logits = self.model(prot_decoder_input_f, prot_mask_decoder) 190 | # target_seqs = prot_decoder_output_f 191 | # mask = prot_mask_decoder 192 | # 193 | # cross_entropy = F.cross_entropy(output_logits, target_seqs.argmax(1), reduction='none') 194 | # cross_entropy = cross_entropy * mask.squeeze(1) 195 | # reconstruction_per_seq = cross_entropy.sum([1, 2]) / mask.sum([1, 2, 3]) 196 | # reconstruction_loss = reconstruction_per_seq.mean() 197 | # accuracy_per_seq = target_seqs[output_logits.argmax(1, keepdim=True)].sum([1, 2]) / mask.sum([1, 2, 3]) 198 | # avg_accuracy = accuracy_per_seq.mean() 199 | # self.model.train() 200 | # return reconstruction_loss, avg_accuracy 201 | 202 | def test(self, data_loader, model_eval=True, num_samples=1, return_logits=False,
return_ce=False): 203 | if model_eval: 204 | self.model.eval() 205 | 206 | alphabet_size = len(data_loader.dataset.alphabet) 207 | batch_size = data_loader.dataset.batch_size 208 | max_seq_len = data_loader.dataset.max_seq_len 209 | 210 | ce_fill_val = np.nan 211 | logits_f, logits_r, ce_f, ce_r = None, None, None, None 212 | if return_logits: 213 | logits_f = np.zeros((num_samples, data_loader.dataset.n_eff, alphabet_size, max_seq_len + 1)) 214 | logits_r = np.zeros((num_samples, data_loader.dataset.n_eff, alphabet_size, max_seq_len + 1)) 215 | if return_ce: 216 | ce_f = np.full((num_samples, data_loader.dataset.n_eff, max_seq_len+1), ce_fill_val) 217 | ce_r = np.full((num_samples, data_loader.dataset.n_eff, max_seq_len+1), ce_fill_val) 218 | 219 | print('sample step step-t CE-loss bit-per-char', flush=True) 220 | output = { 221 | 'name': [], 222 | 'mean': [], 223 | 'forward': [], 224 | 'reverse': [], 225 | 'bitperchar': [], 226 | 'sequence': [] 227 | } 228 | if not self.run_fr: 229 | del output['forward'] 230 | del output['reverse'] 231 | 232 | for i_sample in range(num_samples): 233 | output_i = { 234 | 'name': [], 235 | 'mean': [], 236 | 'forward': [], 237 | 'reverse': [], 238 | 'bitperchar': [], 239 | 'sequence': [] 240 | } 241 | if not self.run_fr: 242 | del output_i['forward'] 243 | del output_i['reverse'] 244 | 245 | for i_batch, batch in enumerate(data_loader): 246 | start = time.time() 247 | for key in batch.keys(): 248 | if isinstance(batch[key], torch.Tensor): 249 | batch[key] = batch[key].to(self.device, non_blocking=True) 250 | 251 | with torch.no_grad(): 252 | if self.run_fr: 253 | output_logits_f, output_logits_r = self.model( 254 | batch['decoder_input'], batch['decoder_mask'], 255 | batch['decoder_input_r'], batch['decoder_mask']) 256 | losses = self.model.reconstruction_loss( 257 | output_logits_f, batch['decoder_output'], batch['decoder_mask'], 258 | output_logits_r, batch['decoder_output_r'], batch['decoder_mask']) 259 | else: 260 | 
output_logits_f = self.model(batch['decoder_input'], batch['decoder_mask']) 261 | output_logits_r = None 262 | losses = self.model.reconstruction_loss( 263 | output_logits_f, batch['decoder_output'], batch['decoder_mask']) 264 | 265 | if self.run_fr: 266 | ce_loss = torch.stack((losses['ce_loss_per_seq_f'], losses['ce_loss_per_seq_r'])).cpu() 267 | bitperchar_per_seq = torch.stack((losses['bitperchar_per_seq_f'], losses['bitperchar_per_seq_r'])).cpu().mean(0) 268 | ce_loss_per_seq = ce_loss.mean(0) 269 | else: 270 | # unidirectional models return the keys without the _f/_r suffix 271 | ce_loss_per_seq = losses['ce_loss_per_seq'].cpu() 272 | bitperchar_per_seq = losses['bitperchar_per_seq'].cpu() 273 | 274 | if return_logits or return_ce: 275 | ce_loss_per_char = torch.stack((losses['ce_loss_per_char_f'], losses['ce_loss_per_char_r'])) if self.run_fr else losses['ce_loss_per_char'] 276 | batch_max_len = ce_loss_per_char.size(-1) 277 | 278 | if self.run_fr: 279 | if return_logits: 280 | logits_f[i_sample, i_batch * batch_size:(i_batch + 1) * batch_size, :, 0:batch_max_len]\ 281 | = output_logits_f.squeeze(2).cpu().numpy() 282 | logits_r[i_sample, i_batch * batch_size:(i_batch + 1) * batch_size, :, 0:batch_max_len]\ 283 | = output_logits_r.squeeze(2).cpu().numpy() 284 | 285 | if return_ce: 286 | ce_mask = batch['decoder_mask'] == 0 287 | ce_mask = ce_mask.squeeze(1).squeeze(1).unsqueeze(0) 288 | ce_loss_per_char.masked_fill_(ce_mask, ce_fill_val) 289 | ce_f[i_sample, i_batch * batch_size:(i_batch + 1) * batch_size, 0:batch_max_len] = \ 290 | ce_loss_per_char[0].cpu().numpy() 291 | ce_r[i_sample, i_batch * batch_size:(i_batch + 1) * batch_size, 0:batch_max_len] = \ 292 | ce_loss_per_char[1].cpu().numpy() 293 | else: 294 | if return_logits: 295 | logits_f[i_sample, i_batch * batch_size:(i_batch + 1) * batch_size, :, 0:batch_max_len]\ 296 | = output_logits_f.squeeze(2).cpu().numpy() 297 | 298 | if return_ce: 299 | ce_mask = batch['decoder_mask'] == 0 300 | ce_mask = ce_mask.squeeze(1).squeeze(1) 301 | ce_loss_per_char.masked_fill_(ce_mask, ce_fill_val) 302 | ce_f[i_sample, i_batch * batch_size:(i_batch + 1) *
batch_size, 0:batch_max_len] = \ 303 | ce_loss_per_char.cpu().numpy() 304 | 305 | output_i['name'].extend(batch['names']) 306 | output_i['sequence'].extend(batch['sequences']) 307 | output_i['mean'].extend(ce_loss_per_seq.numpy()) 308 | output_i['bitperchar'].extend(bitperchar_per_seq.numpy()) 309 | 310 | if self.run_fr: 311 | ce_loss_f = ce_loss[0] 312 | ce_loss_r = ce_loss[1] 313 | output_i['forward'].extend(ce_loss_f.numpy()) 314 | output_i['reverse'].extend(ce_loss_r.numpy()) 315 | 316 | print("{: 4d} {: 8d} {:6.3f} {:11.6f} {:11.6f}".format( 317 | i_sample, i_batch, time.time()-start, ce_loss_per_seq.mean(), bitperchar_per_seq.mean()), 318 | flush=True) 319 | 320 | output['name'] = output_i['name'] 321 | output['sequence'] = output_i['sequence'] 322 | output['bitperchar'].append(output_i['bitperchar']) 323 | output['mean'].append(output_i['mean']) 324 | 325 | if self.run_fr: 326 | output['forward'].append(output_i['forward']) 327 | output['reverse'].append(output_i['reverse']) 328 | 329 | output['bitperchar'] = np.array(output['bitperchar']).mean(0) 330 | output['mean'] = np.array(output['mean']).mean(0) 331 | 332 | if self.run_fr: 333 | output['forward'] = np.array(output['forward']).mean(0) 334 | output['reverse'] = np.array(output['reverse']).mean(0) 335 | 336 | self.model.train() 337 | if return_logits or return_ce: 338 | logits = dict() 339 | if return_logits: 340 | logits_f = logits_f.mean(0) 341 | logits_r = logits_r.mean(0) 342 | logits = dict(logits_f=logits_f, logits_r=logits_r) 343 | if return_ce: 344 | ce_f = ce_f.mean(0) 345 | ce_r = ce_r.mean(0) 346 | logits.update(dict(ce_f=ce_f, ce_r=ce_r)) 347 | return output, logits 348 | else: 349 | return output 350 | 351 | def save_state(self, last_batch=None): 352 | snapshot = f"{self.params['snapshot_path']}/{self.params['snapshot_name']}/{self.params['snapshot_name']}.ckpt-{self.model.step}.pth" 353 | revive_exec = f"{self.params['snapshot_path']}/revive_executable/{self.params['snapshot_name']}.sh" 354 
| if not os.path.exists(os.path.dirname(snapshot)): 355 | os.makedirs(os.path.dirname(snapshot), exist_ok=True) 356 | torch.save( 357 | { 358 | 'step': self.model.step, 359 | 'model_type': self.model.model_type, 360 | 'model_state_dict': self.model.state_dict(), 361 | 'model_dims': self.model.dims, 362 | 'model_hyperparams': self.model.hyperparams, 363 | 'optimizer_state_dict': self.optimizer.state_dict(), 364 | 'train_params': self.params, 365 | 'dataset_params': self.loader.dataset.params, 366 | 'last_batch': last_batch 367 | }, 368 | snapshot 369 | ) 370 | if 'snapshot_exec_template' in self.params: 371 | if not os.path.exists(os.path.dirname(revive_exec)): 372 | os.makedirs(os.path.dirname(revive_exec), exist_ok=True) 373 | with open(revive_exec, "w") as f: 374 | snapshot_exec = self.params['snapshot_exec_template'].format( 375 | restore=os.path.abspath(snapshot) 376 | ) 377 | f.write(snapshot_exec) 378 | 379 | def load_state(self, checkpoint, map_location=None): 380 | if not isinstance(checkpoint, dict): 381 | checkpoint = torch.load(checkpoint, map_location=map_location) 382 | if self.model.model_type != checkpoint['model_type']: 383 | print("Warning: model type mismatch: loaded type {} for model type {}".format( 384 | checkpoint['model_type'], self.model.model_type 385 | )) 386 | if self.model.hyperparams != checkpoint['model_hyperparams']: 387 | print("Warning: model hyperparameter mismatch") 388 | self.model.load_state_dict(checkpoint['model_state_dict']) 389 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 390 | self.model.step = checkpoint['step'] 391 | self.params.update(checkpoint['train_params']) 392 | 393 | 394 | class AutoregressiveVAETrainer(AutoregressiveTrainer): 395 | default_params = { 396 | 'optimizer': 'Adam', 397 | 'lr': 0.001, 398 | 'weight_decay': 0, 399 | 'clip': 100.0, 400 | 'lagging_inference': True, 401 | 'lag_inf_aggressive': True, 402 | 'lag_inf_convergence': True, 403 | 'lag_inf_inner_loop_max_steps': 100, 404 | 
'snapshot_path': None, 405 | 'snapshot_name': 'snapshot', 406 | 'snapshot_interval': 1000, 407 | } 408 | 409 | def __init__( 410 | self, 411 | model, 412 | data_loader, 413 | optimizer=None, 414 | params=None, 415 | lr=None, 416 | weight_decay=None, 417 | gradient_clipping=None, 418 | logger=Logger(), 419 | snapshot_path=None, 420 | snapshot_name=None, 421 | snapshot_interval=None, 422 | snapshot_exec_template=None, 423 | device=torch.device('cpu') 424 | ): 425 | super(AutoregressiveVAETrainer, self).__init__( 426 | model, data_loader, 427 | optimizer=optimizer, params=params, lr=lr, weight_decay=weight_decay, gradient_clipping=gradient_clipping, 428 | logger=logger, snapshot_path=snapshot_path, snapshot_name=snapshot_name, 429 | snapshot_interval=snapshot_interval, snapshot_exec_template=snapshot_exec_template, 430 | device=device 431 | ) 432 | 433 | self.enc_optimizer = self.optimizer_type( 434 | params=self.model.encoder_parameters(), 435 | lr=self.params['lr'], weight_decay=self.params['weight_decay']) 436 | self.dec_optimizer = self.optimizer_type( 437 | params=self.model.decoder_parameters(), 438 | lr=self.params['lr'], weight_decay=self.params['weight_decay']) 439 | 440 | def train(self, steps=1e8): 441 | self.model.train() 442 | device = self.device 443 | params = self.params 444 | epoch = 1 445 | pre_mi = 0 446 | 447 | data_iter = iter(self.loader) 448 | n_eff = self.loader.dataset.n_eff 449 | 450 | test_batch = [next(data_iter) for _ in range(4)] 451 | 452 | print(' step step-t load-t opt loss CE-loss bitperchar l2-norm KL-loss', flush=True) 453 | for step in range(int(self.model.step) + 1, int(steps) + 1): 454 | self.model.step = step 455 | start = time.time() 456 | 457 | batch = next(data_iter) 458 | for key in batch.keys(): 459 | if isinstance(batch[key], torch.Tensor): 460 | batch[key] = batch[key].to(self.device, non_blocking=True) 461 | data_load_time = time.time()-start 462 | 463 | if params['lagging_inference'] and 
self.params['lag_inf_aggressive']: 464 | self.model.enable_gradient = 'e' 465 | 466 | # lagging inference variables 467 | sub_batch = batch 468 | burn_cur_loss = 0 469 | burn_total_chars = 0 470 | burn_pre_loss = 1e4 471 | 472 | for sub_iter in range(1, params['lag_inf_inner_loop_max_steps']+1): 473 | self.enc_optimizer.zero_grad() 474 | self.dec_optimizer.zero_grad() 475 | 476 | if self.run_fr: 477 | output_logits_f, kl_embedding_f, output_logits_r, kl_embedding_r = self.model( 478 | sub_batch['decoder_input'], sub_batch['decoder_mask'], 479 | sub_batch['decoder_input_r'], sub_batch['decoder_mask']) 480 | losses = self.model.calculate_loss( 481 | output_logits_f, kl_embedding_f, 482 | sub_batch['decoder_output'], sub_batch['decoder_mask'], n_eff, 483 | output_logits_r, kl_embedding_r, 484 | sub_batch['decoder_output_r'], sub_batch['decoder_mask'], n_eff) 485 | else: 486 | output_logits_f, kl_embedding_f = self.model( 487 | sub_batch['decoder_input'], sub_batch['decoder_mask']) 488 | losses = self.model.calculate_loss( 489 | output_logits_f, kl_embedding_f, 490 | sub_batch['decoder_output'], sub_batch['decoder_mask'], n_eff) 491 | 492 | burn_cur_loss += losses['loss_per_seq'].sum().item() 493 | burn_total_chars += sub_batch['decoder_mask'].sum().item() 494 | loss = losses['loss'] 495 | # ce_loss = losses['ce_loss'] 496 | # kl_loss = losses['kl_embedding_loss'] 497 | # weight_cost = losses['weight_cost'] 498 | # bitperchar = losses['bitperchar'] 499 | loss.backward() 500 | 501 | if params['clip'] is not None: 502 | nn.utils.clip_grad_norm_(self.model.parameters(), params['clip']) 503 | self.enc_optimizer.step() 504 | 505 | # print("{: 8d} {:6.3f} {:5.4f} {: >2} {:11.6f} {:11.6f} {:11.8f} {:10.6f} {:11.6g}".format( 506 | # sub_iter, time.time() - start, data_load_time, self.model.enable_gradient, 507 | # loss.detach(), ce_loss.detach(), bitperchar.detach(), weight_cost.detach(), kl_loss.detach()), 508 | # flush=True) 509 | 510 | if 
params['lag_inf_convergence']: 511
| if sub_iter % 15 == 0: 512 | burn_cur_loss = burn_cur_loss / burn_total_chars 513 | if burn_pre_loss < burn_cur_loss: 514 | break 515 | burn_pre_loss = burn_cur_loss 516 | burn_cur_loss = 0 517 | burn_total_chars = 0 518 | 519 | sub_batch = next(data_iter) 520 | for key, value in sub_batch.items(): 521 | sub_batch[key] = value.to(device, non_blocking=True) if isinstance(value, torch.Tensor) else value 522 | 523 | self.model.enable_gradient = 'd' 524 | else: 525 | self.model.enable_gradient = 'ed' 526 | 527 | if self.run_fr: 528 | output_logits_f, kl_embedding_f, output_logits_r, kl_embedding_r = self.model( 529 | batch['decoder_input'], batch['decoder_mask'], 530 | batch['decoder_input_r'], batch['decoder_mask']) 531 | losses = self.model.calculate_loss( 532 | output_logits_f, kl_embedding_f, batch['decoder_output'], batch['decoder_mask'], n_eff, 533 | output_logits_r, kl_embedding_r, batch['decoder_output_r'], batch['decoder_mask'], n_eff) 534 | else: 535 | output_logits_f, kl_embedding_f = self.model(batch['decoder_input'], batch['decoder_mask']) 536 | losses = self.model.calculate_loss( 537 | output_logits_f, kl_embedding_f, batch['decoder_output'], batch['decoder_mask'], n_eff) 538 | 539 | loss = losses['loss'] 540 | ce_loss = losses['ce_loss'] 541 | kl_loss = losses['kl_embedding_loss'] 542 | weight_cost = losses['weight_cost'] 543 | bitperchar = losses['bitperchar'] 544 | 545 | self.enc_optimizer.zero_grad() 546 | self.dec_optimizer.zero_grad() 547 | loss.backward() 548 | 549 | if params['clip'] is not None: 550 | total_norm = nn.utils.clip_grad_norm_(self.model.parameters(), params['clip']) 551 | else: 552 | total_norm = 0.0 553 | 554 | if 'e' in self.model.enable_gradient: 555 | self.enc_optimizer.step() 556 | if 'd' in self.model.enable_gradient: 557 | self.dec_optimizer.step() 558 | 559 | if step % self.loader.dataset.n_eff < 1: 560 | self.model.eval() 561 | with torch.no_grad(): 562 | cur_mi = calc_mi(self.model, test_batch, 
run_fr=self.run_fr, device=self.device) 563 | au, _, au_r, _ =
calc_au(self.model, test_batch, run_fr=self.run_fr, device=self.device) 564 | self.model.train() 565 | print(f"epoch: {epoch}, active units: {au}f {au_r}r") 566 | print(f"pre mi: {pre_mi:.4f}, cur mi: {cur_mi:.4f}") 567 | if self.params['lag_inf_aggressive'] and cur_mi < pre_mi: 568 | self.params['lag_inf_aggressive'] = False 569 | print("STOP BURNING") 570 | pre_mi = cur_mi 571 | epoch += 1 572 | if step % params['snapshot_interval'] == 0: 573 | if params['snapshot_path'] is not None: 574 | self.save_state() 575 | 576 | 577 | self.logger.log(step, losses, total_norm) 578 | print("{: 8d} {:6.3f} {:5.4f} {: >2} {:11.6f} {:11.6f} {:11.8f} {:10.6f} {:11.6g}".format( 579 | step, time.time()-start, data_load_time, self.model.enable_gradient, 580 | loss.detach(), ce_loss.detach(), bitperchar.detach(), weight_cost.detach(), kl_loss.detach()), 581 | flush=True) 582 | 583 | 584 | def calc_mi(model, test_data_batch, run_fr=False, device=torch.device('cpu')): 585 | mi = 0 586 | num_examples = 0 587 | for batch_data in test_data_batch: 588 | batch_size = batch_data['decoder_input'].size(0) 589 | num_examples += batch_size 590 | if run_fr: 591 | prot_decoder_input, prot_mask_decoder, prot_decoder_input_r = \ 592 | batch_data['decoder_input'].to(device), batch_data['decoder_mask'].to(device), \ 593 | batch_data['decoder_input_r'].to(device) 594 | mutual_info = model.calc_mi(prot_decoder_input, prot_mask_decoder, 595 | prot_decoder_input_r, prot_mask_decoder) 596 | else: 597 | prot_decoder_input, prot_mask_decoder = \ 598 | batch_data['decoder_input'].to(device), batch_data['decoder_mask'].to(device) 599 | mutual_info = model.calc_mi(prot_decoder_input, prot_mask_decoder) 600 | mi += mutual_info * batch_size 601 | return mi / num_examples 602 | 603 | 604 | def calc_au(model, test_data_batch, delta=0.01, run_fr=False, device=torch.device('cpu')): 605 | """Compute the number of active units: latent dimensions whose posterior mean has variance >= delta across the test batches. 606 | 607 | Adapted from 
https://github.com/jxhe/vae-lagging-encoder 608 | """ 609 | means_sum = 0
610 | cnt = 0 611 | means_sum_r = 0 612 | for batch_data in test_data_batch: 613 | if run_fr: 614 | prot_decoder_input, prot_mask_decoder, prot_decoder_input_r = \ 615 | batch_data['decoder_input'].to(device), batch_data['decoder_mask'].to(device), \ 616 | batch_data['decoder_input_r'].to(device) 617 | mu_f, _, mu_r, _ = model.encode(prot_decoder_input, prot_mask_decoder, 618 | prot_decoder_input_r, prot_mask_decoder) 619 | else: 620 | prot_decoder_input, prot_mask_decoder = \ 621 | batch_data['decoder_input'].to(device), batch_data['decoder_mask'].to(device) 622 | mu_f, _ = model.encode(prot_decoder_input, prot_mask_decoder) 623 | mu_r = None 624 | means_sum += mu_f.sum(dim=0, keepdim=True) 625 | cnt += mu_f.size(0) 626 | if run_fr: 627 | means_sum_r += mu_r.sum(dim=0, keepdim=True) 628 | 629 | # (1, nz) 630 | mean_mean = means_sum / cnt 631 | mean_mean_r = means_sum_r / cnt 632 | 633 | var_sum = 0 634 | cnt = 0 635 | var_sum_r = 0 636 | for batch_data in test_data_batch: 637 | if run_fr: 638 | prot_decoder_input, prot_mask_decoder, prot_decoder_input_r = \ 639 | batch_data['decoder_input'].to(device), batch_data['decoder_mask'].to(device), \ 640 | batch_data['decoder_input_r'].to(device) 641 | mu_f, _, mu_r, _ = model.encode(prot_decoder_input, prot_mask_decoder, 642 | prot_decoder_input_r, prot_mask_decoder) 643 | else: 644 | prot_decoder_input, prot_mask_decoder = \ 645 | batch_data['decoder_input'].to(device), batch_data['decoder_mask'].to(device) 646 | mu_f, _ = model.encode(prot_decoder_input, prot_mask_decoder) 647 | mu_r = None 648 | var_sum += ((mu_f - mean_mean) ** 2).sum(dim=0) 649 | cnt += mu_f.size(0) 650 | if run_fr: 651 | var_sum_r += ((mu_r - mean_mean_r) ** 2).sum(dim=0) 652 | 653 | # (nz) 654 | au_var = var_sum / (cnt - 1) 655 | au = (au_var >= delta).sum().item() 656 | if run_fr: 657 | au_var_r = var_sum_r / (cnt - 1) 658 | au_r = (au_var_r >= delta).sum().item() 659 | else: 660 | au_r, au_var_r = None, None 661 | return au, au_var, au_r, 
au_var_r 662 | -------------------------------------------------------------------------------- /src/seqdesign_pt/aws_utils.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import re 3 | import os 4 | 5 | if os.path.exists('/n/groups/marks/software/aws-cli/bin/aws'): 6 | AWS_BIN = '/n/groups/marks/software/aws-cli/bin/aws' 7 | else: 8 | AWS_BIN = 'aws' 9 | 10 | 11 | class AWSUtility: 12 | def __init__(self, s3_base_path, s3_project): 13 | """ 14 | :param s3_base_path: S3 URL for target S3 folder, e.g. "s3://my-bucket/my-folder" 15 | :param s3_project: Project name, used as sub-folder of the base path, e.g. "v3" 16 | """ 17 | self.s3_base_path = s3_base_path.rstrip('/') 18 | self.s3_project = s3_project 19 | 20 | @staticmethod 21 | def run_cmd(cmd): 22 | try: 23 | if cmd[0] not in ('aws', AWS_BIN): 24 | cmd = [AWS_BIN] + cmd 25 | else: 26 | cmd[0] = AWS_BIN 27 | pipes = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding='UTF-8') 28 | std_out, std_err = pipes.communicate() 29 | if pipes.returncode != 0: 30 | print(f"AWS CLI error: {pipes.returncode}") 31 | print(std_err.strip()) 32 | return pipes.returncode, None, None 33 | else: 34 | return 0, std_out, std_err 35 | except OSError: 36 | print("AWS CLI not found.") 37 | return 1, None, None 38 | 39 | def s3_cp(self, local_file, s3_file, destination='s3'): 40 | s3_file = f"{self.s3_base_path}/{self.s3_project}/{s3_file}" 41 | if destination == 's3': 42 | print("Copying file to AWS S3.") 43 | src_file, dest_file = local_file, s3_file 44 | else: 45 | print("Copying file from AWS S3.") 46 | src_file, dest_file = s3_file, local_file 47 | cmd = ['s3', 'cp', src_file, dest_file] 48 | code, std_out, std_err = self.run_cmd(cmd) 49 | if code == 0: 50 | print("Success.") 51 | 52 | def s3_sync(self, local_folder, s3_folder, destination='s3', args=()): 53 | local_folder = f"{local_folder.rstrip('/')}/" 54 | s3_folder = 
f"{self.s3_base_path}/{self.s3_project}/{s3_folder.rstrip('/')}/" 55 | if destination == 's3': 56 | print("Syncing data to AWS S3.") 57 | src_folder, dest_folder = local_folder, s3_folder 58 | else: 59 | print("Syncing data from AWS S3.") 60 | src_folder, dest_folder = s3_folder, local_folder 61 | cmd = ['s3', 'sync', src_folder, dest_folder, *args] 62 | code, std_out, std_err = self.run_cmd(cmd) 63 | if code == 0: 64 | print("Success.") 65 | 66 | def s3_get_file_grep(self, s3_folder, dest_folder, search_pattern): 67 | dest_folder = f"{dest_folder.rstrip('/')}/" 68 | s3_folder = f"{self.s3_base_path}/{self.s3_project}/{s3_folder.rstrip('/')}/" 69 | print(f"Finding files in {s3_folder} on AWS S3.") 70 | cmd = ['s3', 'ls', s3_folder] 71 | code, std_out, std_err = self.run_cmd(cmd) 72 | if code != 0: 73 | return False 74 | filenames = re.findall(search_pattern, std_out) 75 | if not filenames: 76 | print("No files found.") 77 | return False 78 | print(f"Found: {filenames}") 79 | for filename in filenames: 80 | filename = f"{s3_folder}{filename}" 81 | print(f"Copying file {filename} from AWS S3.") 82 | cmd = ['s3', 'cp', filename, dest_folder] 83 | code, std_out, std_err = self.run_cmd(cmd) 84 | if code != 0: 85 | return False 86 | print("Success.") 87 | return True 88 | -------------------------------------------------------------------------------- /src/seqdesign_pt/data_loaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import math 4 | import re 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | import torch.distributions as dist 10 | import torch.utils.data as data 11 | from Bio.SeqIO.FastaIO import SimpleFastaParser 12 | 13 | from seqdesign_pt.utils import temp_seed 14 | 15 | PROTEIN_ALPHABET = 'ACDEFGHIKLMNPQRSTVWY*' 16 | PROTEIN_REORDERED_ALPHABET = 'DEKRHNQSTPGAVILMCFYW*' 17 | RNA_ALPHABET = 'ACGU*' 18 | DNA_ALPHABET = 'ACGT*' 19 | START_END = "*" 20 | 21 | 22 | def 
get_alphabet(alphabet_type='protein'): 23 | if alphabet_type == 'protein': 24 | return PROTEIN_ALPHABET, PROTEIN_REORDERED_ALPHABET 25 | elif alphabet_type == 'RNA': 26 | return RNA_ALPHABET, RNA_ALPHABET 27 | elif alphabet_type == 'DNA': 28 | return DNA_ALPHABET, DNA_ALPHABET 29 | else: 30 | raise ValueError('unknown alphabet type') 31 | 32 | 33 | class GeneratorDataset(data.Dataset): 34 | """A Dataset that can be used as a generator.""" 35 | def __init__( 36 | self, 37 | batch_size=32, 38 | unlimited_epoch=True, 39 | ): 40 | self.batch_size = batch_size 41 | self.unlimited_epoch = unlimited_epoch 42 | 43 | @property 44 | def params(self): 45 | return {"batch_size": self.batch_size, "unlimited_epoch": self.unlimited_epoch} 46 | 47 | @params.setter 48 | def params(self, d): 49 | if 'batch_size' in d: 50 | self.batch_size = d['batch_size'] 51 | if 'unlimited_epoch' in d: 52 | self.unlimited_epoch = d['unlimited_epoch'] 53 | 54 | @property 55 | def n_eff(self): 56 | raise NotImplementedError 57 | 58 | def __getitem__(self, index): 59 | raise NotImplementedError 60 | 61 | def __len__(self): 62 | if self.unlimited_epoch: 63 | return 2 ** 62 64 | else: 65 | return math.ceil(self.n_eff / self.batch_size) 66 | 67 | @staticmethod 68 | def collate_fn(batch): 69 | return batch[0] 70 | 71 | 72 | class GeneratorDataLoader(data.DataLoader): 73 | """A DataLoader used with a GeneratorDataset""" 74 | def __init__(self, dataset: GeneratorDataset, **kwargs): 75 | kwargs.update(dict( 76 | batch_size=1, shuffle=False, sampler=None, batch_sampler=None, collate_fn=dataset.collate_fn, 77 | )) 78 | super(GeneratorDataLoader, self).__init__( 79 | dataset, 80 | **kwargs) 81 | 82 | 83 | class TrainTestDataset(data.Dataset): 84 | """A Dataset that has training and testing modes""" 85 | def __init__(self): 86 | self._training = True 87 | 88 | def train(self, training=True): 89 | self._training = training 90 | 91 | def test(self): 92 | self.train(False) 93 | 94 | def __getitem__(self, index): 
95 | raise NotImplementedError 96 | 97 | def __len__(self): 98 | raise NotImplementedError 99 | 100 | 101 | class TrainValTestDataset(data.Dataset): 102 | """A Dataset that has training, validation, and testing modes""" 103 | def __init__(self): 104 | self._mode = 'train' 105 | 106 | def train(self, mode='train'): 107 | self._mode = mode 108 | 109 | def val(self): 110 | self.train('val') 111 | 112 | def test(self): 113 | self.train('test') 114 | 115 | def __getitem__(self, index): 116 | raise NotImplementedError 117 | 118 | def __len__(self): 119 | raise NotImplementedError 120 | 121 | 122 | def sequences_to_decoder_onehot( 123 | sequences, input_char_map, output_char_map, reverse=False, matching=False, start_char='*', end_char='*'): 124 | num_seqs = len(sequences) 125 | max_seq_len = max([len(seq) for seq in sequences]) + 1 126 | decoder_input = np.zeros((num_seqs, len(input_char_map), 1, max_seq_len)) 127 | decoder_output = np.zeros((num_seqs, len(output_char_map), 1, max_seq_len)) 128 | decoder_mask = np.zeros((num_seqs, 1, 1, max_seq_len)) 129 | 130 | if matching: 131 | decoder_input_r = np.zeros((num_seqs, len(input_char_map), 1, max_seq_len)) 132 | decoder_output_r = np.zeros((num_seqs, len(output_char_map), 1, max_seq_len)) 133 | else: 134 | decoder_input_r = None 135 | decoder_output_r = None 136 | 137 | for i, sequence in enumerate(sequences): 138 | if reverse: 139 | sequence = sequence[::-1] 140 | 141 | decoder_input_seq = start_char + sequence 142 | decoder_output_seq = sequence + end_char 143 | 144 | if matching: 145 | sequence_r = sequence[::-1] 146 | decoder_input_seq_r = start_char + sequence_r 147 | decoder_output_seq_r = sequence_r + end_char 148 | else: 149 | decoder_input_seq_r = None 150 | decoder_output_seq_r = None 151 | 152 | for j in range(len(decoder_input_seq)): 153 | decoder_input[i, input_char_map[decoder_input_seq[j]], 0, j] = 1 154 | decoder_output[i, output_char_map[decoder_output_seq[j]], 0, j] = 1 155 | decoder_mask[i, 0, 0, j] = 1 
156 | 157 | if matching: 158 | decoder_input_r[i, input_char_map[decoder_input_seq_r[j]], 0, j] = 1 159 | decoder_output_r[i, output_char_map[decoder_output_seq_r[j]], 0, j] = 1 160 | 161 | return decoder_input, decoder_output, decoder_mask, decoder_input_r, decoder_output_r 162 | 163 | 164 | def sequences_to_encoder_onehot(sequences, char_map, start_char='', end_char=''): 165 | num_seqs = len(sequences) 166 | max_seq_len = max([len(seq) for seq in sequences]) + len(start_char) + len(end_char) 167 | encoder_input = np.zeros((num_seqs, len(char_map), 1, max_seq_len)) 168 | encoder_mask = np.zeros((num_seqs, 1, 1, max_seq_len)) 169 | 170 | for i, sequence in enumerate(sequences): 171 | encoder_input_seq = start_char + sequence + end_char 172 | 173 | for j in range(len(encoder_input_seq)): 174 | encoder_input[i, char_map[encoder_input_seq[j]], 0, j] = 1 175 | encoder_mask[i, 0, 0, j] = 1 176 | 177 | return encoder_input, encoder_mask 178 | 179 | 180 | class SequenceDataset(GeneratorDataset): 181 | """Abstract sequence dataset""" 182 | SUPPORTED_OUTPUT_SHAPES = ['NCHW', 'NHWC', 'NLC'] 183 | 184 | def __init__( 185 | self, 186 | batch_size=32, 187 | unlimited_epoch=True, 188 | alphabet_type='protein', 189 | reverse=False, 190 | matching=False, 191 | output_shape='NCHW', 192 | output_types='decoder,encoder', 193 | ): 194 | super(SequenceDataset, self).__init__(batch_size=batch_size, unlimited_epoch=unlimited_epoch) 195 | 196 | self.alphabet_type = alphabet_type 197 | self.reverse = reverse 198 | self.matching = matching 199 | self.output_shape = output_shape 200 | self.output_types = output_types 201 | self.max_seq_len = -1 202 | 203 | if output_shape not in self.SUPPORTED_OUTPUT_SHAPES: 204 | raise KeyError(f'Unsupported output shape: {output_shape}') 205 | 206 | self.aa_dict = self.idx_to_aa = self.output_aa_dict = self.output_idx_to_aa = None 207 | self.update_aa_dict() 208 | 209 | @property 210 | def params(self): 211 | params = super(SequenceDataset, self).params 
212 | params.update({ 213 | "alphabet_type": self.alphabet_type, 214 | "reverse": self.reverse, 215 | "matching": self.matching, 216 | "output_shape": self.output_shape, 217 | "output_types": self.output_types, 218 | }) 219 | return params 220 | 221 | @params.setter 222 | def params(self, d): 223 | GeneratorDataset.params.__set__(self, d) 224 | if 'alphabet_type' in d: 225 | self.alphabet_type = d['alphabet_type'] 226 | self.update_aa_dict() 227 | if 'reverse' in d: 228 | self.reverse = d['reverse'] 229 | if 'matching' in d: 230 | self.matching = d['matching'] 231 | if 'output_shape' in d: 232 | self.output_shape = d['output_shape'] 233 | if 'output_types' in d: 234 | self.output_types = d['output_types'] 235 | 236 | @property 237 | def alphabet(self): 238 | if self.alphabet_type == 'protein': 239 | return PROTEIN_ALPHABET 240 | elif self.alphabet_type == 'RNA': 241 | return RNA_ALPHABET 242 | elif self.alphabet_type == 'DNA': 243 | return DNA_ALPHABET 244 | 245 | @property 246 | def output_alphabet(self): 247 | return self.alphabet 248 | 249 | def update_aa_dict(self): 250 | self.aa_dict = {aa: i for i, aa in enumerate(self.alphabet)} 251 | self.idx_to_aa = {i: aa for i, aa in enumerate(self.alphabet)} 252 | self.output_aa_dict = {aa: i for i, aa in enumerate(self.output_alphabet)} 253 | self.output_idx_to_aa = {i: aa for i, aa in enumerate(self.output_alphabet)} 254 | 255 | @property 256 | def n_eff(self): 257 | raise NotImplementedError 258 | 259 | def __getitem__(self, index): 260 | raise NotImplementedError 261 | 262 | def sequences_to_onehot(self, sequences, reverse=None, matching=None): 263 | """Convert sequences to one-hot encoded tensors. 264 | 265 | :param sequences: list/iterable of strings 266 | :param reverse: reverse the sequences (defaults to self.reverse) 267 | :param matching: also output tensors for the reversed sequences (defaults to self.matching) 268 | :return: dictionary mapping output names (e.g. 
'decoder_input', 'decoder_mask') to float32 tensors 269 | """ 270 | reverse = self.reverse if reverse is None else reverse 271 | matching = self.matching if matching is None else matching 272 | output = {} 273 | 274 | if
'decoder' in self.output_types: 275 | decoder_input, decoder_output, decoder_mask, decoder_input_r, decoder_output_r = \ 276 | sequences_to_decoder_onehot(sequences, self.aa_dict, self.output_aa_dict, 277 | reverse=reverse, matching=matching) 278 | 279 | if matching: 280 | output.update({ 281 | 'decoder_input': decoder_input, 282 | 'decoder_output': decoder_output, 283 | 'decoder_mask': decoder_mask, 284 | 'decoder_input_r': decoder_input_r, 285 | 'decoder_output_r': decoder_output_r 286 | }) 287 | else: 288 | output.update({ 289 | 'decoder_input': decoder_input, 290 | 'decoder_output': decoder_output, 291 | 'decoder_mask': decoder_mask 292 | }) 293 | if 'encoder' in self.output_types: 294 | encoder_input, encoder_mask = sequences_to_encoder_onehot(sequences, self.aa_dict) 295 | output.update({ 296 | 'encoder_input': encoder_input, 297 | 'encoder_mask': encoder_mask 298 | }) 299 | 300 | for key in output.keys(): 301 | output[key] = torch.as_tensor(output[key], dtype=torch.float32) 302 | if self.output_shape == 'NHWC': 303 | output[key] = output[key].permute(0, 2, 3, 1).contiguous() 304 | elif self.output_shape == 'NLC': 305 | output[key] = output[key].squeeze(2).permute(0, 2, 1).contiguous() 306 | 307 | return output 308 | 309 | 310 | class FastaDataset(SequenceDataset): 311 | """Load batches of sequences from a fasta file, either sequentially or sampled isotropically""" 312 | 313 | def __init__( 314 | self, 315 | dataset='', 316 | working_dir='.', 317 | batch_size=32, 318 | unlimited_epoch=False, 319 | alphabet_type='protein', 320 | reverse=False, 321 | matching=False, 322 | output_shape='NCHW', 323 | output_types='decoder', 324 | # TODO add shuffle parameter: iterate through shuffled sequences 325 | ): 326 | super(FastaDataset, self).__init__( 327 | batch_size=batch_size, 328 | unlimited_epoch=unlimited_epoch, 329 | alphabet_type=alphabet_type, 330 | reverse=reverse, 331 | matching=matching, 332 | output_shape=output_shape, 333 | output_types=output_types, 334 | 
) 335 | self.dataset = dataset 336 | self.working_dir = working_dir 337 | 338 | self.names = None 339 | self.sequences = None 340 | 341 | self.load_data() 342 | 343 | def load_data(self): 344 | filename = os.path.join(self.working_dir, self.dataset) 345 | names_list = [] 346 | sequence_list = [] 347 | max_seq_len = 0 348 | 349 | with open(filename, 'r') as fa: 350 | for i, (title, seq) in enumerate(SimpleFastaParser(fa)): 351 | valid = True 352 | for letter in seq: 353 | if letter not in self.aa_dict: 354 | valid = False 355 | if not valid: 356 | continue 357 | 358 | names_list.append(title) 359 | sequence_list.append(seq) 360 | if len(seq) > max_seq_len: 361 | max_seq_len = len(seq) 362 | 363 | self.names = np.array(names_list) 364 | self.sequences = np.array(sequence_list) 365 | self.max_seq_len = max_seq_len 366 | 367 | print("Number of sequences:", self.n_eff) 368 | print("Max sequence length:", max_seq_len) 369 | 370 | @property 371 | def n_eff(self): 372 | return len(self.sequences) # not a true n_eff 373 | 374 | def __getitem__(self, index): 375 | """ 376 | :param index: batch index; ignored if unlimited_epoch 377 | :return: batch of size self.batch_size 378 | """ 379 | 380 | if self.unlimited_epoch: 381 | indices = np.random.randint(0, self.n_eff, self.batch_size) 382 | else: 383 | first_index = index * self.batch_size 384 | last_index = min((index+1) * self.batch_size, self.n_eff) 385 | indices = np.arange(first_index, last_index) 386 | 387 | seqs = self.sequences[indices] 388 | batch = self.sequences_to_onehot(seqs) 389 | batch['names'] = self.names[indices] 390 | batch['sequences'] = seqs 391 | return batch 392 | 393 | 394 | class SingleFamilyDataset(SequenceDataset): 395 | def __init__( 396 | self, 397 | dataset='', 398 | working_dir='.', 399 | batch_size=32, 400 | unlimited_epoch=True, 401 | alphabet_type='protein', 402 | reverse=False, 403 | matching=False, 404 | output_shape='NCHW', 405 | output_types='decoder', 406 | ): 407 | 
super(SingleFamilyDataset, self).__init__( 408 | batch_size=batch_size, 409 | unlimited_epoch=unlimited_epoch, 410 | alphabet_type=alphabet_type, 411 | reverse=reverse, 412 | matching=matching, 413 | output_shape=output_shape, 414 | output_types=output_types, 415 | ) 416 | self.dataset = dataset 417 | self.working_dir = working_dir 418 | 419 | self.family_name_to_sequence_list = {} 420 | self.family_name_to_sequence_weight_list = {} 421 | self.family_name_to_n_eff = {} 422 | self.family_name_list = [] 423 | self.family_idx_list = [] 424 | self.family_name = '' 425 | self.family_name_to_idx = {} 426 | self.idx_to_family_name = {} 427 | 428 | self.num_families = 0 429 | self.max_family_size = 0 430 | 431 | self.load_data() 432 | 433 | def load_data(self): 434 | max_seq_len = 0 435 | max_family_size = 0 436 | family_name = '' 437 | weight_list = [] 438 | 439 | if os.path.exists(self.dataset): 440 | f_names = [self.dataset] 441 | else: 442 | f_names = glob.glob(f'{self.working_dir}/datasets/sequences/{self.dataset}*.fa') 443 | if len(f_names) != 1: 444 | raise AssertionError('Wrong number of families: {}'.format(len(f_names))) 445 | 446 | for filename in f_names: 447 | sequence_list = [] 448 | weight_list = [] 449 | 450 | family_name = filename.rsplit('/', 1)[-1].rsplit('.', 1)[0] 451 | family_name_list = family_name.split('_') 452 | if len(family_name_list) >= 2: 453 | family_name = family_name_list[0] + '_' + family_name_list[1] 454 | print(f"Family: {family_name}") 455 | 456 | family_size = 0 457 | ind_family_idx_list = [] 458 | with open(filename, 'r') as fa: 459 | # check if first sequence header has a sequence weight 460 | line = 'start' 461 | for title, seq in SimpleFastaParser(fa): 462 | line = title 463 | break 464 | try: 465 | weight = float(line.rsplit(':', 1)[-1]) 466 | uniform_weights = False 467 | except ValueError: 468 | print(f"No sequence weights detected: {line}\nUsing uniform weights.") 469 | uniform_weights = True 470 | fa.seek(0) 471 | 472 | for i, 
(title, seq) in enumerate(SimpleFastaParser(fa)): 473 | weight = 1.0 if uniform_weights else float(title.rsplit(':', 1)[-1]) 474 | valid = True 475 | for letter in seq: 476 | if letter not in self.aa_dict: 477 | valid = False 478 | if not valid: 479 | continue 480 | 481 | sequence_list.append(seq) 482 | ind_family_idx_list.append(family_size) 483 | weight_list.append(weight) 484 | family_size += 1 485 | if len(seq) > max_seq_len: 486 | max_seq_len = len(seq) 487 | 488 | if family_size > max_family_size: 489 | max_family_size = family_size 490 | 491 | self.family_name_to_sequence_list[family_name] = sequence_list 492 | self.family_name_to_sequence_weight_list[family_name] = ( 493 | np.asarray(weight_list) / np.sum(weight_list) 494 | ).tolist() 495 | self.family_name_to_n_eff[family_name] = np.sum(weight_list) 496 | self.family_name = family_name 497 | self.family_name_list.append(family_name) 498 | self.family_idx_list.append(ind_family_idx_list) 499 | 500 | self.family_name = family_name 501 | self.max_seq_len = max_seq_len 502 | self.num_families = len(self.family_name_list) 503 | self.max_family_size = max_family_size 504 | 505 | print("Number of families:", self.num_families) 506 | print("Neff:", np.sum(weight_list)) 507 | print("Max family size:", max_family_size) 508 | print("Max sequence length:", max_seq_len) 509 | 510 | for i, family_name in enumerate(self.family_name_list): 511 | self.family_name_to_idx[family_name] = i 512 | self.idx_to_family_name[i] = family_name 513 | 514 | @property 515 | def n_eff(self): 516 | return self.family_name_to_n_eff[self.family_name] 517 | 518 | def __getitem__(self, index): 519 | """ 520 | :param index: ignored 521 | :return: batch of size self.batch_size 522 | """ 523 | family_name = self.family_name 524 | family_seqs = self.family_name_to_sequence_list[family_name] 525 | family_weights = self.family_name_to_sequence_weight_list[family_name] 526 | 527 | seq_idx = np.random.choice(len(family_seqs), self.batch_size, 
p=family_weights) 528 | seqs = [family_seqs[idx] for idx in seq_idx] 529 | 530 | batch = self.sequences_to_onehot(seqs) 531 | return batch 532 | 533 | 534 | class SingleClusteredSequenceDataset(SequenceDataset): 535 | def __init__( 536 | self, 537 | dataset='', 538 | working_dir='.', 539 | batch_size=32, 540 | unlimited_epoch=True, 541 | alphabet_type='protein', 542 | reverse=False, 543 | matching=False, 544 | output_shape='NCHW', 545 | output_types='decoder', 546 | ): 547 | super(SingleClusteredSequenceDataset, self).__init__( 548 | batch_size=batch_size, 549 | unlimited_epoch=unlimited_epoch, 550 | alphabet_type=alphabet_type, 551 | reverse=reverse, 552 | matching=matching, 553 | output_shape=output_shape, 554 | output_types=output_types, 555 | ) 556 | self.dataset = dataset 557 | self.working_dir = working_dir 558 | 559 | self.name_to_sequence = {} 560 | self.clu1_to_seq_names = {} 561 | self.clu1_list = [] 562 | 563 | self.load_data() 564 | 565 | def load_data(self): 566 | max_seq_len = 0 567 | num_seqs = 0 568 | filename = os.path.join(self.working_dir, self.dataset) 569 | with open(filename, 'r') as fa: 570 | for i, (title, seq) in enumerate(SimpleFastaParser(fa)): 571 | name, clu1 = title.split(':') # TODO handle different split lengths 572 | valid = True 573 | for letter in seq: 574 | if letter not in self.aa_dict: 575 | valid = False 576 | if not valid: 577 | continue 578 | 579 | self.name_to_sequence[name] = seq 580 | if clu1 in self.clu1_to_seq_names: 581 | self.clu1_to_seq_names[clu1].append(name) 582 | else: 583 | self.clu1_to_seq_names[clu1] = [name] 584 | 585 | if len(seq) > max_seq_len: 586 | max_seq_len = len(seq) 587 | num_seqs += 1 588 | 589 | self.clu1_list = list(self.clu1_to_seq_names.keys()) 590 | self.max_seq_len = max_seq_len 591 | 592 | print("Num clusters:", len(self.clu1_list)) 593 | print("Num sequences:", num_seqs) 594 | print("Max sequence length:", max_seq_len) 595 | 596 | @property 597 | def n_eff(self): 598 | return 
len(self.clu1_list) 599 | 600 | def __getitem__(self, index): 601 | """ 602 | :param index: ignored 603 | :return: batch of size self.batch_size 604 | """ 605 | names = [] 606 | seqs = [] 607 | for i in range(self.batch_size): 608 | # Pick a cluster id90 609 | clu1_idx = np.random.randint(0, len(self.clu1_list)) 610 | clu1 = self.clu1_list[clu1_idx] 611 | 612 | # Then pick a random sequence all in those clusters 613 | seq_name = np.random.choice(self.clu1_to_seq_names[clu1]) 614 | names.append(seq_name) 615 | 616 | # then grab the associated sequence 617 | seqs.append(self.name_to_sequence[seq_name]) 618 | 619 | batch = self.sequences_to_onehot(seqs) 620 | batch['names'] = names 621 | batch['sequences'] = seqs 622 | return batch 623 | 624 | 625 | class DoubleClusteredSequenceDataset(SequenceDataset): 626 | def __init__( 627 | self, 628 | dataset='', 629 | working_dir='.', 630 | batch_size=32, 631 | unlimited_epoch=True, 632 | alphabet_type='protein', 633 | reverse=False, 634 | matching=False, 635 | output_shape='NCHW', 636 | output_types='decoder', 637 | ): 638 | super(DoubleClusteredSequenceDataset, self).__init__( 639 | batch_size=batch_size, 640 | unlimited_epoch=unlimited_epoch, 641 | alphabet_type=alphabet_type, 642 | reverse=reverse, 643 | matching=matching, 644 | output_shape=output_shape, 645 | output_types=output_types, 646 | ) 647 | self.dataset = dataset 648 | self.working_dir = working_dir 649 | 650 | self.name_to_sequence = {} 651 | self.clu1_to_clu2_to_seq_names = {} 652 | self.clu1_list = [] 653 | 654 | self.load_data() 655 | 656 | def load_data(self): 657 | max_seq_len = 0 658 | num_subclusters = 0 659 | num_seqs = 0 660 | filename = os.path.join(self.working_dir, self.dataset) 661 | with open(filename, 'r') as fa: 662 | for i, (title, seq) in enumerate(SimpleFastaParser(fa)): 663 | name, clu1, clu2 = title.split(':') 664 | valid = True 665 | for letter in seq: 666 | if letter not in self.aa_dict: 667 | valid = False 668 | if not valid: 669 | 
continue 670 | 671 | self.name_to_sequence[name] = seq 672 | if clu1 in self.clu1_to_clu2_to_seq_names: 673 | if clu2 in self.clu1_to_clu2_to_seq_names[clu1]: 674 | self.clu1_to_clu2_to_seq_names[clu1][clu2].append(name) 675 | else: 676 | self.clu1_to_clu2_to_seq_names[clu1][clu2] = [name] 677 | num_subclusters += 1 678 | else: 679 | self.clu1_to_clu2_to_seq_names[clu1] = {clu2: [name]} 680 | num_subclusters += 1 681 | 682 | if len(seq) > max_seq_len: 683 | max_seq_len = len(seq) 684 | num_seqs += 1 685 | 686 | self.clu1_list = list(self.clu1_to_clu2_to_seq_names.keys()) 687 | self.max_seq_len = max_seq_len 688 | 689 | print("Num clusters:", len(self.clu1_list)) 690 | print("Num subclusters:", num_subclusters) 691 | print("Num sequences:", num_seqs) 692 | print("Max sequence length:", max_seq_len) 693 | 694 | @property 695 | def n_eff(self): 696 | return len(self.clu1_list) 697 | 698 | def __getitem__(self, index): 699 | """ 700 | :param index: ignored 701 | :return: batch of size self.batch_size 702 | """ 703 | names = [] 704 | seqs = [] 705 | for i in range(self.batch_size): 706 | # Pick a cluster id80 707 | clu1_idx = np.random.randint(0, len(self.clu1_list)) 708 | clu1 = self.clu1_list[clu1_idx] 709 | 710 | # Then pick a cluster id90 from the cluster id80s 711 | clu2 = np.random.choice(list(self.clu1_to_clu2_to_seq_names[clu1].keys())) 712 | 713 | # Then pick a random sequence all in those clusters 714 | seq_name = np.random.choice(self.clu1_to_clu2_to_seq_names[clu1][clu2]) 715 | names.append(seq_name) 716 | 717 | # then grab the associated sequence 718 | seqs.append(self.name_to_sequence[seq_name]) 719 | 720 | batch = self.sequences_to_onehot(seqs) 721 | batch['names'] = names 722 | batch['sequences'] = seqs 723 | return batch 724 | 725 | 726 | class IndexedFastaDataset(SequenceDataset): 727 | """Load batches of sequences from an indexed fasta file, either sequentially or sampled isotropically""" 728 | 729 | def __init__( 730 | self, 731 | dataset='', 732 | 
working_dir='.', 733 | batch_size=32, 734 | unlimited_epoch=False, 735 | alphabet_type='protein', 736 | reverse=False, 737 | matching=False, 738 | output_shape='NCHW', 739 | output_types='decoder', 740 | # TODO add shuffle parameter: iterate through shuffled sequences 741 | ): 742 | super(IndexedFastaDataset, self).__init__( 743 | batch_size=batch_size, 744 | unlimited_epoch=unlimited_epoch, 745 | alphabet_type=alphabet_type, 746 | reverse=reverse, 747 | matching=matching, 748 | output_shape=output_shape, 749 | output_types=output_types, 750 | ) 751 | self.dataset = dataset 752 | self.dataset_f = os.path.join(working_dir, dataset) 753 | self.dataset_idx = dataset + '.fai' 754 | self.working_dir = working_dir 755 | 756 | self.names = None 757 | self.offsets = None 758 | 759 | self.load_data() 760 | 761 | def load_data(self): 762 | filename = os.path.join(self.working_dir, self.dataset_idx) 763 | num_seqs = 0 764 | max_name_len = 0 765 | max_seq_len = 0 766 | with open(filename, 'rb') as f: 767 | for line in f: 768 | line = line.split(b'\t') 769 | name_len, seq_len = len(line[0]), int(line[1]) 770 | num_seqs += 1 771 | if name_len > max_name_len: 772 | max_name_len = name_len 773 | if seq_len > max_seq_len: 774 | max_seq_len = seq_len 775 | 776 | self.names = np.empty(num_seqs, dtype=f' max_seq_len: 878 | max_seq_len = length 879 | num_seqs += 1 880 | 881 | self.clu1_list = list(self.clu1_to_offset.keys()) 882 | self.max_seq_len = max_seq_len 883 | 884 | print("Num clusters:", len(self.clu1_list)) 885 | print("Num sequences:", num_seqs) 886 | print("Max sequence length:", max_seq_len) 887 | 888 | @property 889 | def n_eff(self): 890 | return len(self.clu1_list) 891 | 892 | def __getitem__(self, index): 893 | """ 894 | :param index: ignored 895 | :return: batch of size self.batch_size 896 | """ 897 | offsets = [] 898 | seqs = [] 899 | 900 | with open(self.dataset_f, 'rt', buffering=1) as dataset_f: 901 | for i in range(self.batch_size): 902 | # Pick a cluster id80 903 
| clu1_idx = np.random.randint(0, len(self.clu1_list)) 904 | clu1 = self.clu1_list[clu1_idx] 905 | 906 | # Then pick a cluster id90 from the cluster id80s 907 | clu2 = np.random.choice(list(self.clu1_to_offset[clu1].keys())) 908 | 909 | # Then pick a random sequence all in those clusters 910 | offset = np.random.choice(self.clu1_to_offset[clu1][clu2]) 911 | offsets.append(offset) 912 | 913 | # then grab the associated sequence 914 | dataset_f.seek(offset) 915 | seqs.append(dataset_f.readline().rstrip()) 916 | 917 | try: 918 | batch = self.sequences_to_onehot(seqs) 919 | return batch 920 | except KeyError as e: 921 | print(offsets, seqs, flush=True) 922 | raise e 923 | 924 | 925 | class DoubleClusteredIndexedSequenceDataset(SequenceDataset): 926 | """ 927 | Reads an indexed fasta dataset. 928 | Index filename is (fasta_file).fai 929 | Create index using `samtools faidx (fasta file)`. 930 | """ 931 | 932 | def __init__( 933 | self, 934 | dataset='', 935 | working_dir='.', 936 | batch_size=32, 937 | unlimited_epoch=True, 938 | alphabet_type='protein', 939 | reverse=False, 940 | matching=False, 941 | output_shape='NCHW', 942 | output_types='decoder', 943 | ): 944 | super(DoubleClusteredIndexedSequenceDataset, self).__init__( 945 | batch_size=batch_size, 946 | unlimited_epoch=unlimited_epoch, 947 | alphabet_type=alphabet_type, 948 | reverse=reverse, 949 | matching=matching, 950 | output_shape=output_shape, 951 | output_types=output_types, 952 | ) 953 | self.dataset = dataset 954 | self.dataset_f = os.path.join(working_dir, dataset) 955 | self.dataset_idx = dataset + '.fai' 956 | self.working_dir = working_dir 957 | 958 | self.clu1_to_clu2_to_offset = {} 959 | self.clu1_list = [] 960 | 961 | self.load_data() 962 | 963 | def load_data(self): 964 | filename = os.path.join(self.working_dir, self.dataset_idx) 965 | max_seq_len = 0 966 | num_subclusters = 0 967 | num_seqs = 0 968 | with open(filename, 'rb') as f: 969 | for line in f: 970 | line = line.split(b'\t') 971 | 
title, length, offset = line[0], int(line[1]), int(line[2]) 972 | name, clu1, clu2 = title.split(b':') 973 | 974 | if clu1 in self.clu1_to_clu2_to_offset: 975 | if clu2 in self.clu1_to_clu2_to_offset[clu1]: 976 | self.clu1_to_clu2_to_offset[clu1][clu2].append(offset) 977 | else: 978 | self.clu1_to_clu2_to_offset[clu1][clu2] = [offset] 979 | num_subclusters += 1 980 | else: 981 | self.clu1_to_clu2_to_offset[clu1] = {clu2: [offset]} 982 | num_subclusters += 1 983 | 984 | if length > max_seq_len: 985 | max_seq_len = length 986 | num_seqs += 1 987 | 988 | self.clu1_list = list(self.clu1_to_clu2_to_offset.keys()) 989 | self.max_seq_len = max_seq_len 990 | 991 | print("Num clusters:", len(self.clu1_list)) 992 | print("Num subclusters:", num_subclusters) 993 | print("Num sequences:", num_seqs) 994 | print("Max sequence length:", max_seq_len) 995 | 996 | @property 997 | def n_eff(self): 998 | return len(self.clu1_list) 999 | 1000 | def __getitem__(self, index): 1001 | """ 1002 | :param index: ignored 1003 | :return: batch of size self.batch_size 1004 | """ 1005 | offsets = [] 1006 | seqs = [] 1007 | 1008 | with open(self.dataset_f, 'rt', buffering=1) as dataset_f: 1009 | for i in range(self.batch_size): 1010 | # Pick a cluster id80 1011 | clu1_idx = np.random.randint(0, len(self.clu1_list)) 1012 | clu1 = self.clu1_list[clu1_idx] 1013 | 1014 | # Then pick a cluster id90 from the cluster id80s 1015 | clu2 = np.random.choice(list(self.clu1_to_clu2_to_offset[clu1].keys())) 1016 | 1017 | # Then pick a random sequence all in those clusters 1018 | offset = np.random.choice(self.clu1_to_clu2_to_offset[clu1][clu2]) 1019 | offsets.append(offset) 1020 | 1021 | # then grab the associated sequence 1022 | dataset_f.seek(offset) 1023 | seqs.append(dataset_f.readline().rstrip()) 1024 | 1025 | try: 1026 | batch = self.sequences_to_onehot(seqs) 1027 | return batch 1028 | except KeyError as e: 1029 | print(offsets, seqs, flush=True) 1030 | raise e 1031 | 1032 | 1033 | class 
AntibodySequenceDataset(SequenceDataset): 1034 | IPI_VL_SEQS = ['VK1-39', 'VL1-51', 'VK3-15'] 1035 | IPI_VH_SEQS = ['VH1-46', 'VH1-69', 'VH3-7', 'VH3-15', 'VH4-39', 'VH5-51'] 1036 | LABELED = False 1037 | 1038 | def __init__( 1039 | self, 1040 | dataset='', 1041 | working_dir='.', 1042 | batch_size=32, 1043 | unlimited_epoch=True, 1044 | alphabet_type='protein', 1045 | reverse=False, 1046 | matching=False, 1047 | output_shape='NCHW', 1048 | output_types='encoder', 1049 | include_vl=False, 1050 | include_vh=False, 1051 | ): 1052 | SequenceDataset.__init__( 1053 | self, 1054 | batch_size=batch_size, 1055 | unlimited_epoch=unlimited_epoch, 1056 | alphabet_type=alphabet_type, 1057 | reverse=reverse, 1058 | matching=matching, 1059 | output_shape=output_shape, 1060 | output_types=output_types, 1061 | ) 1062 | self.dataset = dataset 1063 | self.working_dir = working_dir 1064 | self.include_vl = include_vl 1065 | self.include_vh = include_vh 1066 | 1067 | self.vl_list = self.IPI_VL_SEQS.copy() 1068 | self.vh_list = self.IPI_VH_SEQS.copy() 1069 | 1070 | @property 1071 | def light_to_idx(self): 1072 | if self.vl_list is None: 1073 | raise RuntimeError("VL list not loaded.") 1074 | else: 1075 | return {vl: i for i, vl in enumerate(self.vl_list)} 1076 | 1077 | @property 1078 | def heavy_to_idx(self): 1079 | if self.vh_list is None: 1080 | raise RuntimeError("VH list not loaded.") 1081 | else: 1082 | return {vh: i for i, vh in enumerate(self.vh_list)} 1083 | 1084 | @property 1085 | def input_dim(self): 1086 | input_dim = len(self.alphabet) 1087 | if self.include_vl: 1088 | input_dim += len(self.light_to_idx) 1089 | if self.include_vh: 1090 | input_dim += len(self.heavy_to_idx) 1091 | return input_dim 1092 | 1093 | @property 1094 | def params(self): 1095 | params = super(AntibodySequenceDataset, self).params 1096 | params.update({ 1097 | "include_vl": self.include_vl, 1098 | "include_vh": self.include_vh, 1099 | "vl_seqs": self.vl_list, 1100 | "vh_seqs": self.vh_list, 1101 | })
1102 | return params 1103 | 1104 | @params.setter 1105 | def params(self, d): 1106 | if 'for_decoder' in d: 1107 | if d['for_decoder']: 1108 | d['output_types'] = 'decoder' 1109 | else: 1110 | d['output_types'] = 'encoder' 1111 | SequenceDataset.params.__set__(self, d) 1112 | if 'include_vl' in d: 1113 | self.include_vl = d['include_vl'] 1114 | if 'include_vh' in d: 1115 | self.include_vh = d['include_vh'] 1116 | if 'vl_seqs' in d: 1117 | self.vl_list = d['vl_seqs'] 1118 | if 'vh_seqs' in d: 1119 | self.vh_list = d['vh_seqs'] 1120 | 1121 | def sequences_to_onehot(self, sequences, vls=None, vhs=None, reverse=None, matching=None): 1122 | reverse = self.reverse if reverse is None else reverse 1123 | matching = self.matching if matching is None else matching 1124 | num_seqs = len(sequences) 1125 | for i in range(num_seqs): 1126 | # normalize CDR3 sequences to exclude constant characters 1127 | # TODO add strip_cw param 1128 | if sequences[i][0] == 'C': 1129 | sequences[i] = sequences[i][1:] 1130 | if sequences[i][-1] == 'W': 1131 | sequences[i] = sequences[i][:-1] 1132 | 1133 | if 'decoder' in self.output_types: 1134 | seq_arr, seq_output_arr, seq_mask, seq_arr_r, seq_output_arr_r = sequences_to_decoder_onehot( 1135 | sequences, self.aa_dict, self.output_aa_dict, reverse=reverse, matching=matching 1136 | ) 1137 | else: 1138 | seq_arr, seq_mask = sequences_to_encoder_onehot(sequences, self.aa_dict) 1139 | seq_output_arr = seq_arr_r = seq_output_arr_r = None 1140 | 1141 | light_arr = heavy_arr = None 1142 | if self.include_vl: 1143 | light_arr = np.zeros((num_seqs, len(self.light_to_idx), 1, seq_arr.shape[-1])) 1144 | if self.include_vh: 1145 | heavy_arr = np.zeros((num_seqs, len(self.heavy_to_idx), 1, seq_arr.shape[-1])) 1146 | 1147 | for i in range(num_seqs): 1148 | if self.include_vl: 1149 | light_arr[i, self.light_to_idx[vls[i]], 0, :] = 1. 1150 | if self.include_vh: 1151 | heavy_arr[i, self.heavy_to_idx[vhs[i]], 0, :] = 1.
1152 | 1153 | if self.include_vl: 1154 | seq_arr = np.concatenate((seq_arr, light_arr), axis=1) 1155 | if seq_arr_r is not None: 1156 | seq_arr_r = np.concatenate((seq_arr_r, light_arr), axis=1) 1157 | if self.include_vh: 1158 | seq_arr = np.concatenate((seq_arr, heavy_arr), axis=1) 1159 | if seq_arr_r is not None: 1160 | seq_arr_r = np.concatenate((seq_arr_r, heavy_arr), axis=1) 1161 | 1162 | if 'decoder' in self.output_types: 1163 | output = { 1164 | 'decoder_input': seq_arr, 1165 | 'decoder_mask': seq_mask, 1166 | 'decoder_output': seq_output_arr, 1167 | 'decoder_input_r': seq_arr_r, 1168 | 'decoder_output_r': seq_output_arr_r, 1169 | } 1170 | else: 1171 | output = {'encoder_input': seq_arr, 'encoder_mask': seq_mask} 1172 | for key in output.keys(): 1173 | output[key] = torch.as_tensor(output[key], dtype=torch.float32) 1174 | if self.output_shape == 'NHWC': 1175 | output[key] = output[key].permute(0, 2, 3, 1).contiguous() 1176 | elif self.output_shape == 'NLC': 1177 | output[key] = output[key].squeeze(2).permute(0, 2, 1).contiguous() 1178 | return output 1179 | 1180 | @property 1181 | def n_eff(self): 1182 | raise NotImplementedError 1183 | 1184 | def __getitem__(self, item): 1185 | raise NotImplementedError 1186 | 1187 | 1188 | class IPIFastaDataset(AntibodySequenceDataset): 1189 | """Unweighted antibody dataset. 
1190 | fasta: >*_heavy-{VH}_light-{VL}* 1191 | """ 1192 | 1193 | def __init__( 1194 | self, 1195 | dataset='', 1196 | working_dir='.', 1197 | batch_size=32, 1198 | unlimited_epoch=True, 1199 | alphabet_type='protein', 1200 | reverse=False, 1201 | matching=False, 1202 | output_shape='NCHW', 1203 | output_types='decoder', 1204 | include_vl=False, 1205 | include_vh=False, 1206 | ): 1207 | AntibodySequenceDataset.__init__( 1208 | self, 1209 | batch_size=batch_size, 1210 | unlimited_epoch=unlimited_epoch, 1211 | alphabet_type=alphabet_type, 1212 | reverse=reverse, 1213 | matching=matching, 1214 | output_shape=output_shape, 1215 | output_types=output_types, 1216 | include_vl=include_vl, 1217 | include_vh=include_vh, 1218 | ) 1219 | self.dataset = dataset 1220 | self.working_dir = working_dir 1221 | 1222 | self.names = None 1223 | self.sequences = None 1224 | self.vhs = None 1225 | self.vls = None 1226 | 1227 | self.load_data() 1228 | 1229 | def load_data(self): 1230 | filename = os.path.join(self.working_dir, self.dataset) 1231 | names_list = [] 1232 | sequence_list = [] 1233 | vhs = [] 1234 | vls = [] 1235 | max_seq_len = 0 1236 | skipped_seqs = 0 1237 | name_pat = re.compile(r'_heavy-([A-Z0-9\-]+)_light-([A-Z0-9\-]+)') 1238 | 1239 | with open(filename, 'r') as fa: 1240 | for i, (title, seq) in enumerate(SimpleFastaParser(fa)): 1241 | valid = True 1242 | for letter in seq: 1243 | if letter not in self.aa_dict: 1244 | valid = False 1245 | if not valid: 1246 | skipped_seqs += 1 1247 | continue 1248 | 1249 | match = name_pat.search(title) 1250 | if not match or match.group(1) not in self.vh_list or match.group(2) not in self.vl_list: 1251 | skipped_seqs += 1 1252 | continue 1253 | 1254 | names_list.append(title) 1255 | sequence_list.append(seq) 1256 | if len(seq) > max_seq_len: 1257 | max_seq_len = len(seq) 1258 | vhs.append(match.group(1)) 1259 | vls.append(match.group(2)) 1260 | 1261 | self.names = np.array(names_list) 1262 | self.sequences = np.array(sequence_list) 1263 
| self.max_seq_len = max_seq_len 1264 | self.vhs = np.array(vhs) 1265 | self.vls = np.array(vls) 1266 | 1267 | print("Number of sequences:", self.n_eff) 1268 | print("Number of sequences skipped:", skipped_seqs) 1269 | print("Max sequence length:", max_seq_len) 1270 | 1271 | @property 1272 | def n_eff(self): 1273 | return len(self.sequences) # not a true n_eff 1274 | 1275 | def __getitem__(self, index): 1276 | """ 1277 | :param index: batch index; ignored if unlimited_epoch 1278 | :return: batch of size self.batch_size 1279 | """ 1280 | 1281 | if self.unlimited_epoch: 1282 | indices = np.random.randint(0, self.n_eff, self.batch_size) 1283 | else: 1284 | first_index = index * self.batch_size 1285 | last_index = min((index + 1) * self.batch_size, self.n_eff) 1286 | indices = np.arange(first_index, last_index) 1287 | 1288 | seqs = self.sequences[indices] 1289 | names = self.names[indices] 1290 | vhs = self.vhs[indices] 1291 | vls = self.vls[indices] 1292 | batch = self.sequences_to_onehot(seqs, vhs=vhs, vls=vls) 1293 | batch['names'] = names 1294 | batch['sequences'] = seqs 1295 | return batch 1296 | 1297 | 1298 | class IPISingleClusteredSequenceDataset(AntibodySequenceDataset): 1299 | """Single-weighted antibody dataset. 
1300 | fasta: >seq:vh:vl:clu1 1301 | clu1: cluster id 1302 | """ 1303 | 1304 | def __init__( 1305 | self, 1306 | dataset='', 1307 | working_dir='.', 1308 | batch_size=32, 1309 | unlimited_epoch=True, 1310 | alphabet_type='protein', 1311 | reverse=False, 1312 | matching=False, 1313 | output_shape='NCHW', 1314 | output_types='decoder', 1315 | include_vl=False, 1316 | include_vh=False, 1317 | ): 1318 | AntibodySequenceDataset.__init__( 1319 | self, 1320 | batch_size=batch_size, 1321 | unlimited_epoch=unlimited_epoch, 1322 | alphabet_type=alphabet_type, 1323 | reverse=reverse, 1324 | matching=matching, 1325 | output_shape=output_shape, 1326 | output_types=output_types, 1327 | include_vl=include_vl, 1328 | include_vh=include_vh, 1329 | ) 1330 | self.dataset = dataset 1331 | self.working_dir = working_dir 1332 | 1333 | self.name_to_sequence = {} 1334 | self.clu1_to_seq_names = {} 1335 | self.clu1_list = [] 1336 | 1337 | self.load_data() 1338 | 1339 | def load_data(self): 1340 | max_seq_len = 0 1341 | num_seqs = 0 1342 | filename = os.path.join(self.working_dir, self.dataset) 1343 | with open(filename, 'r') as fa: 1344 | for i, (title, seq) in enumerate(SimpleFastaParser(fa)): 1345 | name, vh, vl, clu1 = title.split(':') 1346 | valid = True 1347 | for letter in seq: 1348 | if letter not in self.aa_dict: 1349 | valid = False 1350 | if not valid: 1351 | continue 1352 | elif vh not in self.vh_list: 1353 | print(f"Unrecognized VH gene: {vh}") 1354 | continue 1355 | elif vl not in self.vl_list: 1356 | print(f"Unrecognized VL gene: {vl}") 1357 | continue 1358 | 1359 | name = f"{name}:{vh}:{vl}" 1360 | if name in self.name_to_sequence: 1361 | print(f"Name collision: {name}") 1362 | self.name_to_sequence[name] = seq 1363 | if clu1 in self.clu1_to_seq_names: 1364 | self.clu1_to_seq_names[clu1].append(name) 1365 | else: 1366 | self.clu1_to_seq_names[clu1] = [name] 1367 | 1368 | if len(seq) > max_seq_len: 1369 | max_seq_len = len(seq) 1370 | num_seqs += 1 1371 | 1372 | 
self.clu1_list = list(self.clu1_to_seq_names.keys()) 1373 | self.max_seq_len = max_seq_len 1374 | 1375 | print("Num clusters:", len(self.clu1_list)) 1376 | print("Num sequences:", num_seqs) 1377 | print("Max sequence length:", max_seq_len) 1378 | 1379 | @property 1380 | def n_eff(self): 1381 | return len(self.clu1_list) 1382 | 1383 | def __getitem__(self, index): 1384 | """ 1385 | :param index: ignored 1386 | :return: batch of size self.batch_size 1387 | """ 1388 | names = [] 1389 | vhs = [] 1390 | vls = [] 1391 | seqs = [] 1392 | for i in range(self.batch_size): 1393 | # Pick a cluster id90 1394 | clu1_idx = np.random.randint(0, len(self.clu1_list)) 1395 | clu1 = self.clu1_list[clu1_idx] 1396 | 1397 | # Then pick a random sequence all in those clusters 1398 | seq_name = np.random.choice(self.clu1_to_seq_names[clu1]) 1399 | _, vh, vl = seq_name.split(':') 1400 | names.append(seq_name) 1401 | vhs.append(vh) 1402 | vls.append(vl) 1403 | 1404 | # then grab the associated sequence 1405 | seqs.append(self.name_to_sequence[seq_name]) 1406 | 1407 | batch = self.sequences_to_onehot(seqs, vhs=vhs, vls=vls) 1408 | batch['names'] = names 1409 | batch['sequences'] = seqs 1410 | return batch 1411 | 1412 | 1413 | class IPITrainTestDataset(AntibodySequenceDataset, TrainTestDataset): 1414 | LABELED = True 1415 | 1416 | def __init__( 1417 | self, 1418 | dataset='', 1419 | working_dir='.', 1420 | batch_size=32, 1421 | unlimited_epoch=True, 1422 | alphabet_type='protein', 1423 | reverse=False, 1424 | matching=False, 1425 | output_shape='NCHW', 1426 | output_types='encoder', 1427 | comparisons=(('Aff1', 'PSR1', 0., 0.),), # before, after, thresh_before, thresh_after 1428 | train_test_split=1.0, 1429 | split_seed=42, 1430 | include_vl=False, 1431 | include_vh=False, 1432 | ): 1433 | AntibodySequenceDataset.__init__( 1434 | self, 1435 | batch_size=batch_size, 1436 | unlimited_epoch=unlimited_epoch, 1437 | alphabet_type=alphabet_type, 1438 | reverse=reverse, 1439 | matching=matching, 
1440 | output_shape=output_shape, 1441 | output_types=output_types, 1442 | include_vl=include_vl, 1443 | include_vh=include_vh, 1444 | ) 1445 | TrainTestDataset.__init__(self) 1446 | self.dataset = dataset 1447 | self.working_dir = working_dir 1448 | self.comparisons = comparisons 1449 | self.train_test_split = train_test_split 1450 | self.split_seed = split_seed 1451 | self.include_vl = include_vl 1452 | self.include_vh = include_vh 1453 | 1454 | self.cdr_to_output = {} 1455 | self.cdr_to_heavy = {} 1456 | self.cdr_to_light = {} 1457 | self.all_cdr_seqs = [] 1458 | self.cdr_seqs_train = [] 1459 | self.cdr_seqs_test = [] 1460 | self.comparison_pos_weights = torch.ones(len(comparisons)) 1461 | 1462 | self.load_data() 1463 | 1464 | def load_data(self): 1465 | seq_col = 'CDR3' 1466 | heavy_col = 'heavy' 1467 | light_col = 'light' 1468 | count_cols = list({col: None for comparison in self.comparisons for col in comparison[:2]}.keys()) 1469 | 1470 | # load data file 1471 | filename = os.path.join(self.working_dir, self.dataset) 1472 | use_cols = [seq_col, heavy_col, light_col] + count_cols 1473 | df = pd.read_csv(filename, usecols=use_cols) 1474 | df[count_cols] = df[count_cols].fillna(0.) 
1475 | 1476 | # load output data 1477 | comparison_cdr_to_output = [] 1478 | for i_comparison, comparison in enumerate(self.comparisons): 1479 | before, after, before_threshold, after_threshold = comparison 1480 | comp_df = df.loc[(df[before] > before_threshold) | (df[after] > after_threshold), :] 1481 | 1482 | comp_out = pd.Series((comp_df[after] > after_threshold).astype(int).values, index=comp_df[seq_col]) 1483 | pos_weight = (len(comp_out) - comp_out.sum()) / comp_out.sum() 1484 | comparison_cdr_to_output.append(comp_out.to_dict()) 1485 | self.comparison_pos_weights[i_comparison] = pos_weight 1486 | print(f'comparison: {comparison}, {len(comp_out)} seqs, ' 1487 | f'{comp_out.mean() * 100:0.1f}% positive, {pos_weight:0.4f} pos_weight') 1488 | 1489 | # keep only sequences with all output information 1490 | all_cdrs = set.intersection(*(set(d.keys()) for d in comparison_cdr_to_output)) 1491 | df = df[df[seq_col].isin(all_cdrs)] 1492 | self.all_cdr_seqs = df[seq_col].values 1493 | print(f'total seqs after intersection: {len(self.all_cdr_seqs)}') 1494 | 1495 | # split data into train-test 1496 | with temp_seed(self.split_seed): 1497 | indices = np.random.permutation(len(self.all_cdr_seqs)) 1498 | partition = math.ceil(len(indices) * self.train_test_split) 1499 | training_idx, test_idx = indices[:partition], indices[partition:] 1500 | self.cdr_seqs_train, self.cdr_seqs_test = self.all_cdr_seqs[training_idx], self.all_cdr_seqs[test_idx] 1501 | print(f'train-test split: {self.train_test_split}') 1502 | print(f'num train, test seqs: {len(self.cdr_seqs_train)}, {len(self.cdr_seqs_test)}') 1503 | 1504 | # make table of output values 1505 | self.cdr_to_output = {} 1506 | for cdr in df[seq_col]: 1507 | output = [] 1508 | for d in comparison_cdr_to_output: 1509 | output.append(d.get(cdr, 0)) 1510 | self.cdr_to_output[cdr] = output 1511 | 1512 | df = df.set_index(seq_col) 1513 | self.cdr_to_heavy = df[heavy_col].to_dict() 1514 | self.cdr_to_light = df[light_col].to_dict() 
1515 | 1516 | @property 1517 | def n_eff(self): 1518 | return len(self.cdr_seqs) 1519 | 1520 | @property 1521 | def cdr_seqs(self): 1522 | if self._training: 1523 | return self.cdr_seqs_train 1524 | else: 1525 | return self.cdr_seqs_test 1526 | 1527 | def __getitem__(self, index): 1528 | if self.unlimited_epoch: 1529 | indices = np.random.randint(0, self.n_eff, self.batch_size) 1530 | else: 1531 | first_index = index * self.batch_size 1532 | last_index = min((index+1) * self.batch_size, self.n_eff) 1533 | indices = np.arange(first_index, last_index) 1534 | 1535 | seqs = self.cdr_seqs[indices].tolist() 1536 | label_arr = torch.zeros(len(indices), len(self.comparisons)) 1537 | for i, seq in enumerate(seqs): 1538 | for j, output in enumerate(self.cdr_to_output[seq]): 1539 | label_arr[i, j] = output 1540 | 1541 | if len(seqs) == 0: 1542 | return None 1543 | vls = [self.cdr_to_light[cdr] for cdr in seqs] 1544 | vhs = [self.cdr_to_heavy[cdr] for cdr in seqs] 1545 | batch = self.sequences_to_onehot(seqs, vls=vls, vhs=vhs) 1546 | batch['label'] = label_arr 1547 | return batch 1548 | 1549 | 1550 | class VHAntibodyDataset(AntibodySequenceDataset): 1551 | """Abstract antibody dataset""" 1552 | IPI_VH_SEQS = ['IGHV1-46', 'IGHV1-69', 'IGHV3-7', 'IGHV3-15', 'IGHV4-39', 'IGHV5-51'] # TODO IGHV1-69D? 
1553 | LABELED = False 1554 | 1555 | def __init__( 1556 | self, 1557 | batch_size=32, 1558 | unlimited_epoch=True, 1559 | alphabet_type='protein', 1560 | reverse=False, 1561 | matching=False, 1562 | output_shape='NCHW', 1563 | output_types='encoder', 1564 | include_vh=False, 1565 | vh_set_name='IPI', 1566 | ): 1567 | super(VHAntibodyDataset, self).__init__( 1568 | batch_size=batch_size, 1569 | unlimited_epoch=unlimited_epoch, 1570 | alphabet_type=alphabet_type, 1571 | reverse=reverse, 1572 | matching=matching, 1573 | output_shape=output_shape, 1574 | output_types=output_types, 1575 | include_vl=False, 1576 | include_vh=include_vh, 1577 | ) 1578 | self.vh_set_name = vh_set_name 1579 | 1580 | self._n_eff = 1 1581 | if self.vh_set_name == 'IPI': 1582 | self.vh_list = self.IPI_VH_SEQS.copy() 1583 | else: 1584 | self.vh_list = None 1585 | 1586 | @property 1587 | def input_dim(self): 1588 | input_dim = len(self.alphabet) 1589 | if self.include_vh: 1590 | input_dim += len(self.heavy_to_idx) 1591 | return input_dim 1592 | 1593 | @property 1594 | def params(self): 1595 | params = super(VHAntibodyDataset, self).params 1596 | params.pop('vl_seqs', None) 1597 | params.pop('include_vl', None) 1598 | params.update({ 1599 | "vh_set_name": self.vh_set_name, 1600 | "vh_seqs": self.vh_list, 1601 | }) 1602 | return params 1603 | 1604 | @params.setter 1605 | def params(self, d): 1606 | d.pop('vl_seqs', None) 1607 | d.pop('include_vl', None) 1608 | AntibodySequenceDataset.params.__set__(self, d) 1609 | if 'vh_set_name' in d: 1610 | self.vh_set_name = d['vh_set_name'] 1611 | if self.vh_set_name == 'IPI': 1612 | self.vh_list = self.IPI_VH_SEQS.copy() 1613 | else: 1614 | self.vh_list = None 1615 | if 'vh_seqs' in d: 1616 | self.vh_list = d['vh_seqs'] 1617 | 1618 | @property 1619 | def n_eff(self): 1620 | """Number of clusters across all VH genes""" 1621 | return self._n_eff 1622 | 1623 | def __getitem__(self, index): 1624 | raise NotImplementedError 1625 | 1626 | 1627 | class 
VHAntibodyFastaDataset(VHAntibodyDataset): 1628 | """Antibody dataset with VH sequences. 1629 | fasta: >seq(:.+)*:VH_gene 1630 | """ 1631 | 1632 | def __init__( 1633 | self, 1634 | dataset='', 1635 | working_dir='.', 1636 | batch_size=32, 1637 | unlimited_epoch=True, 1638 | alphabet_type='protein', 1639 | reverse=False, 1640 | matching=False, 1641 | output_shape='NCHW', 1642 | output_types='decoder', 1643 | include_vh=False, 1644 | vh_set_name='IPI', 1645 | ): 1646 | super(VHAntibodyFastaDataset, self).__init__( 1647 | batch_size=batch_size, 1648 | unlimited_epoch=unlimited_epoch, 1649 | alphabet_type=alphabet_type, 1650 | reverse=reverse, 1651 | matching=matching, 1652 | output_shape=output_shape, 1653 | output_types=output_types, 1654 | include_vh=include_vh, 1655 | vh_set_name=vh_set_name, 1656 | ) 1657 | self.dataset = dataset 1658 | self.working_dir = working_dir 1659 | 1660 | self.names = None 1661 | self.vh_genes = None 1662 | self.sequences = None 1663 | 1664 | self.load_data() 1665 | 1666 | def load_data(self): 1667 | filename = os.path.join(self.working_dir, self.dataset) 1668 | names_list = [] 1669 | vh_genes_list = [] 1670 | sequence_list = [] 1671 | 1672 | with open(filename, 'r') as fa: 1673 | for i, (title, seq) in enumerate(SimpleFastaParser(fa)): 1674 | valid = True 1675 | for letter in seq: 1676 | if letter not in self.aa_dict: 1677 | valid = False 1678 | if not valid: 1679 | continue 1680 | 1681 | names_list.append(title) 1682 | vh_genes_list.append(title.split(':')[-1]) 1683 | sequence_list.append(seq) 1684 | 1685 | self.names = np.array(names_list) 1686 | self.vh_genes = np.array(vh_genes_list) 1687 | self.sequences = np.array(sequence_list) 1688 | 1689 | print("Number of sequences:", self.n_eff) 1690 | 1691 | @property 1692 | def n_eff(self): 1693 | return len(self.sequences) # not a true n_eff 1694 | 1695 | def __getitem__(self, index): 1696 | """ 1697 | :param index: batch index; ignored if unlimited_epoch 1698 | :return: batch of size 
self.batch_size 1699 | """ 1700 | 1701 | if self.unlimited_epoch: 1702 | indices = np.random.randint(0, self.n_eff, self.batch_size) 1703 | else: 1704 | first_index = index * self.batch_size 1705 | last_index = min((index + 1) * self.batch_size, self.n_eff) 1706 | indices = np.arange(first_index, last_index) 1707 | 1708 | seqs = self.sequences[indices] 1709 | vhs = self.vh_genes[indices] 1710 | batch = self.sequences_to_onehot(seqs, vhs=vhs) 1711 | batch['names'] = self.names[indices] 1712 | batch['sequences'] = seqs.tolist() 1713 | return batch 1714 | 1715 | 1716 | class VHClusteredAntibodyDataset(VHAntibodyDataset): 1717 | """Double-weighted antibody dataset. 1718 | fasta: >seq:clu1:clu2 1719 | clu1: VH gene 1720 | clu2: cluster id 1721 | """ 1722 | 1723 | def __init__( 1724 | self, 1725 | dataset='', 1726 | working_dir='.', 1727 | batch_size=32, 1728 | unlimited_epoch=True, 1729 | alphabet_type='protein', 1730 | reverse=False, 1731 | matching=False, 1732 | output_shape='NCHW', 1733 | output_types='decoder', 1734 | include_vh=False, 1735 | vh_set_name='IPI', 1736 | ): 1737 | super(VHClusteredAntibodyDataset, self).__init__( 1738 | batch_size=batch_size, 1739 | unlimited_epoch=unlimited_epoch, 1740 | alphabet_type=alphabet_type, 1741 | reverse=reverse, 1742 | matching=matching, 1743 | output_shape=output_shape, 1744 | output_types=output_types, 1745 | include_vh=include_vh, 1746 | vh_set_name=vh_set_name, 1747 | ) 1748 | self.dataset = dataset 1749 | self.working_dir = working_dir 1750 | 1751 | self.clu1_to_clu2s = {} 1752 | self.clu1_to_clu2_to_seqs = {} 1753 | 1754 | self.load_data() 1755 | 1756 | @property 1757 | def clu1_list(self): 1758 | return self.vh_list 1759 | 1760 | def load_data(self): 1761 | filename = self.working_dir + '/datasets/' + self.dataset 1762 | with open(filename, 'r') as fa: 1763 | for i, (title, seq) in enumerate(SimpleFastaParser(fa)): 1764 | name, clu1, clu2 = title.split(':') 1765 | valid = True 1766 | for letter in seq:
1767 | if letter not in self.aa_dict: 1768 | valid = False 1769 | if not valid: 1770 | continue 1771 | 1772 | if clu1 in self.clu1_to_clu2_to_seqs: 1773 | if clu2 in self.clu1_to_clu2_to_seqs[clu1]: 1774 | self.clu1_to_clu2_to_seqs[clu1][clu2].append(seq) 1775 | else: 1776 | self.clu1_to_clu2s[clu1].append(clu2) 1777 | self.clu1_to_clu2_to_seqs[clu1][clu2] = [seq] 1778 | else: 1779 | self.clu1_to_clu2s[clu1] = [clu2] 1780 | self.clu1_to_clu2_to_seqs[clu1] = {clu2: [seq]} 1781 | 1782 | if self.clu1_list is None: 1783 | self.vh_list = list(self.clu1_to_clu2_to_seqs.keys()) 1784 | self._n_eff = sum(len(clu2s) for clu2s in self.clu1_to_clu2s.values()) 1785 | print("Num VH genes:", len(self.clu1_list)) 1786 | print("N_eff:", self.n_eff) 1787 | 1788 | def __getitem__(self, index): 1789 | """ 1790 | :param index: ignored 1791 | :return: batch of size self.batch_size 1792 | """ 1793 | seqs = [] 1794 | vhs = [] 1795 | for i in range(self.batch_size): 1796 | # Pick a VH gene 1797 | clu1_idx = np.random.randint(0, len(self.clu1_list)) 1798 | clu1 = self.clu1_list[clu1_idx] 1799 | 1800 | # Then pick a cluster for that VH gene 1801 | clu2_list = self.clu1_to_clu2s[clu1] 1802 | clu2_idx = np.random.randint(0, len(clu2_list)) 1803 | clu2 = clu2_list[clu2_idx] 1804 | 1805 | # Then pick a random sequence from the cluster 1806 | clu_seqs = self.clu1_to_clu2_to_seqs[clu1][clu2] 1807 | seq_idx = np.random.randint(0, len(clu_seqs)) 1808 | seqs.append(clu_seqs[seq_idx]) 1809 | vhs.append(clu1) 1810 | 1811 | batch = self.sequences_to_onehot(seqs, vhs=vhs) 1812 | return batch 1813 | -------------------------------------------------------------------------------- /src/seqdesign_pt/functions.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.autograd as autograd 6 | 7 | 8 | def gelu(x): 9 | """BERT's implementation of the gelu activation function. 
10 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 11 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 12 | """ 13 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 14 | 15 | 16 | def swish(x): 17 | return x * torch.sigmoid(x) 18 | 19 | 20 | ACT_TO_FUN = { 21 | 'elu': F.elu, 22 | 'relu': F.relu, 23 | 'lrelu': F.leaky_relu, 24 | 'gelu': gelu, 25 | 'swish': swish, 26 | 'none': lambda x: x, 27 | } 28 | 29 | 30 | def nonlinearity(nonlin_type): 31 | return ACT_TO_FUN[nonlin_type] 32 | 33 | 34 | def comb_losses(losses_f, losses_r): 35 | losses_comb = {} 36 | for key in losses_f.keys(): 37 | if 'per_seq' in key or 'per_char' in key: 38 | losses_comb[key + '_f'] = losses_f[key] 39 | losses_comb[key + '_r'] = losses_r[key] 40 | else: 41 | losses_comb[key] = losses_f[key] + losses_r[key] 42 | losses_comb[key + '_f'] = losses_f[key] 43 | losses_comb[key + '_r'] = losses_r[key] 44 | return losses_comb 45 | 46 | 47 | def clamp(x, min_val=0., max_val=1.): 48 | return max(min_val, min(x, max_val)) 49 | 50 | 51 | def l2_normalize(w, dim, eps=1e-12): 52 | """PyTorch implementation of tf.nn.l2_normalize 53 | """ 54 | return w / w.pow(2).sum(dim, keepdim=True).clamp(min=eps).sqrt() 55 | 56 | 57 | def l2_norm_except_dim(w, dim, eps=1e-12): 58 | norm_dims = [i for i, _ in enumerate(w.shape)] 59 | del norm_dims[dim] 60 | return l2_normalize(w, norm_dims, eps) 61 | 62 | 63 | def moments(x, dim, keepdim=False): 64 | """PyTorch implementation of tf.nn.moments over a single dimension 65 | """ 66 | # n = x.numel() / torch.prod(torch.tensor(x.shape)[dim]) # useful for multiple dims 67 | mean = x.mean(dim=dim, keepdim=True) 68 | variance = (x - mean.detach()).pow(2).mean(dim=dim, keepdim=keepdim) 69 | if not keepdim: 70 | mean = mean.squeeze(dim) 71 | return mean, variance 72 | 73 | 74 | class Normalize(autograd.Function): 75 | """Normalize x across dim 76 | """ 77 | @staticmethod 78 | def 
forward(ctx, x, dim, eps=1e-5): 79 | x_mu = x - x.mean(dim=dim, keepdim=True) 80 | inv_std = 1 / (x_mu.pow(2).mean(dim=dim, keepdim=True) + eps).sqrt() 81 | x_norm = x_mu * inv_std 82 | 83 | if ctx is not None: 84 | ctx.save_for_backward(x_mu, inv_std) 85 | ctx.dim = dim 86 | return x_norm 87 | 88 | @staticmethod 89 | def backward(ctx, grad_out): 90 | x_mu, inv_std = ctx.saved_tensors 91 | dim = ctx.dim 92 | n = x_mu.size(dim) 93 | 94 | # adapted from: https://cthorey.github.io/backpropagation/ 95 | # https://wiseodd.github.io/techblog/2016/07/04/batchnorm/ 96 | dx = inv_std / n * ( 97 | grad_out * n - 98 | grad_out.sum(dim, keepdim=True) - 99 | (grad_out * x_mu).sum(dim, keepdim=True) * x_mu * inv_std ** 2 100 | ) 101 | return dx, None, None 102 | 103 | @staticmethod 104 | def test(): 105 | x = torch.DoubleTensor(3, 4, 2, 5).normal_(0, 1).requires_grad_() 106 | inputs = (x, 1) 107 | return autograd.gradcheck(Normalize.apply, inputs) 108 | 109 | 110 | normalize = Normalize.apply 111 | -------------------------------------------------------------------------------- /src/seqdesign_pt/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | 6 | from seqdesign_pt.functions import normalize, l2_norm_except_dim 7 | 8 | 9 | class HyperparameterError(ValueError): 10 | pass 11 | 12 | 13 | class LayerChannelNorm(nn.Module): 14 | """Normalizes across C for input NCHW 15 | """ 16 | __constants__ = ['num_channels', 'dim', 'eps', 'affine', 'weight', 'bias', 'g_init', 'bias_init'] 17 | 18 | def __init__(self, num_channels, dim=1, eps=1e-5, affine=True, g_init=1.0, bias_init=0.1): 19 | super(LayerChannelNorm, self).__init__() 20 | self.num_channels = num_channels 21 | self.dim = dim 22 | self.eps = eps 23 | self.affine = affine 24 | self.g_init = g_init 25 | self.bias_init = bias_init 26 | if self.affine: 27 | 
self.weight = Parameter(torch.Tensor(num_channels)) 28 | self.bias = Parameter(torch.Tensor(num_channels)) 29 | else: 30 | self.register_parameter('weight', None) 31 | self.register_parameter('bias', None) 32 | self.reset_parameters() 33 | 34 | def reset_parameters(self): 35 | if self.affine: 36 | nn.init.normal_(self.weight, mean=self.g_init, std=1e-6) 37 | nn.init.normal_(self.bias, mean=self.bias_init, std=1e-6) 38 | 39 | def forward(self, x): 40 | h = normalize(x, self.dim) 41 | 42 | if self.affine: 43 | shape = [1 for _ in x.shape] 44 | shape[self.dim] = self.num_channels 45 | h = self.weight.view(shape) * h + self.bias.view(shape) 46 | return h 47 | 48 | 49 | class LayerNorm(nn.GroupNorm): 50 | """Normalizes across CHW for input NCHW 51 | """ 52 | def __init__(self, num_channels, eps=1e-5, affine=True, g_init=1.0, bias_init=0.1): 53 | self.g_init = g_init 54 | self.bias_init = bias_init 55 | super(LayerNorm, self).__init__(num_groups=1, num_channels=num_channels, eps=eps, affine=affine) 56 | 57 | def reset_parameters(self): 58 | if self.affine: 59 | nn.init.normal_(self.weight, mean=self.g_init, std=1e-6) 60 | nn.init.normal_(self.bias, mean=self.bias_init, std=1e-6) 61 | 62 | 63 | class Conv2d(nn.Module): 64 | def __init__( 65 | self, 66 | in_channels, 67 | out_channels, 68 | kernel_width=(1, 1), 69 | stride=(1, 1), 70 | dilation=(1, 1), 71 | g_init=1.0, 72 | bias_init=0.1, 73 | causal=False, 74 | activation=None, 75 | ): 76 | super(Conv2d, self).__init__() 77 | self.in_channels = in_channels 78 | self.out_channels = out_channels 79 | self.kernel_width = kernel_width 80 | self.stride = stride 81 | self.dilation = dilation 82 | self.causal = causal 83 | self.activation = activation 84 | 85 | self.generating = False 86 | self.generating_reset = True 87 | self._weight = None 88 | self._input_cache = None 89 | 90 | self.padding = tuple(d * (w-1)//2 for w, d in zip(kernel_width, dilation)) 91 | 92 | self.bias = Parameter(torch.Tensor(out_channels)) 93 | 
self.weight_v = Parameter(torch.Tensor(out_channels, in_channels, *kernel_width)) 94 | self.weight_g = Parameter(torch.Tensor(out_channels)) 95 | 96 | if causal: 97 | if any(w % 2 == 0 for w in kernel_width): 98 | raise HyperparameterError(f"Even kernel width incompatible with causal convolution: {kernel_width}") 99 | if kernel_width == (1, 3): # make common case explicit 100 | mask = torch.Tensor([1., 1., 0.]) 101 | elif kernel_width[0] == 1: 102 | mask = torch.ones(kernel_width) 103 | mask[0, kernel_width[1] // 2 + 1:] = 0 104 | else: 105 | mask = torch.ones(kernel_width) 106 | mask[kernel_width[0] // 2, kernel_width[1] // 2:] = 0 107 | mask[kernel_width[0] // 2 + 1:, :] = 0 108 | 109 | mask = mask.view(1, 1, *kernel_width) 110 | self.register_buffer('mask', mask) 111 | else: 112 | self.register_buffer('mask', None) 113 | 114 | self.reset_parameters(g_init=g_init, bias_init=bias_init) 115 | 116 | def reset_parameters(self, v_mean=0., v_std=0.05, g_init=1.0, bias_init=0.1): 117 | nn.init.normal_(self.weight_v, mean=v_mean, std=v_std) 118 | nn.init.constant_(self.weight_g, val=g_init) 119 | nn.init.constant_(self.bias, val=bias_init) 120 | 121 | def generate(self, mode=True): 122 | self.generating = mode 123 | self.generating_reset = True 124 | self._weight = None 125 | self._input_cache = None 126 | return self 127 | 128 | def weight_costs(self): 129 | return ( 130 | self.weight_v.pow(2).sum(), 131 | self.weight_g.pow(2).sum(), 132 | self.bias.pow(2).sum() 133 | ) 134 | 135 | @property 136 | def weight(self): 137 | shape = (self.out_channels, 1, 1, 1) 138 | weight = l2_norm_except_dim(self.weight_v, 0) * self.weight_g.view(shape) 139 | if self.mask is not None: 140 | weight = weight * self.mask 141 | return weight 142 | 143 | def forward(self, inputs): 144 | """ 145 | :param inputs: (N, C_in, H, W) 146 | :return: (N, C_out, H, W) 147 | """ 148 | if self.generating: 149 | if self.generating_reset: 150 | self.generating_reset = False 151 | if self.kernel_width != 
(1, 1): 152 | self._input_cache = inputs 153 | else: 154 | return self.forward_generate(inputs) 155 | 156 | h = F.conv2d(inputs, self.weight, bias=self.bias, 157 | stride=self.stride, padding=self.padding, dilation=self.dilation) 158 | if self.activation is not None: 159 | h = self.activation(h) 160 | return h 161 | 162 | def forward_generate(self, inputs): 163 | """Calculates forward for the last position in `inputs` 164 | Only implemented for kernel widths (1, 1) and (1, 3) and stride (1, 1). 165 | If the kernel width is (1, 3), causal must be True. 166 | 167 | :param inputs: tensor(N, C_in, 1, 1) 168 | :return: tensor(N, C_out, 1, 1) 169 | """ 170 | if self._weight is None: 171 | self._weight = self.weight 172 | self._weight = self._weight.transpose(0, 1) 173 | if self.kernel_width == (1, 1): 174 | h = inputs[:, :, 0, -1] @ self._weight[:, :, 0, 0] + self.bias.view(1, self.out_channels) 175 | elif self.kernel_width == (1, 3): 176 | h = inputs[:, :, 0, -1] @ self._weight[:, :, 0, 1] 177 | if self.dilation[1] <= self._input_cache.size(3): # position t - dilation exists, including index 0 178 | h += self._input_cache[:, :, 0, -self.dilation[1]] @ self._weight[:, :, 0, 0] 179 | h += self.bias.view(1, self.out_channels) 180 | self._input_cache = torch.cat([self._input_cache, inputs], dim=3) 181 | else: 182 | raise HyperparameterError(f"Generate not supported for kernel width {self.kernel_width}.") 183 | if self.activation is not None: 184 | h = self.activation(h) 185 | return h.unsqueeze(-1).unsqueeze(-1) 186 | 187 | def extra_repr(self): 188 | s = '{in_channels}, {out_channels}, kernel_size={kernel_width}' 189 | if self.stride != (1,) * len(self.stride): 190 | s += ', stride={stride}' 191 | if self.dilation != (1,) * len(self.dilation): 192 | s += ', dilation={dilation}' 193 | if self.causal: 194 | s += ', causal=True' 195 | return s.format(**self.__dict__) 196 | 197 | 198 | class ConvNet1DLayer(nn.Module): 199 | configurations = ['original', 'updated', 'standard'] 200 | dropout_types = ['independent', '2D'] 201 |
202 | def __init__( 203 | self, 204 | channels=48, 205 | dilation=1, 206 | dropout_p=0.5, 207 | dropout_type='independent', 208 | causal=True, 209 | config='original', 210 | add_input_channels=None, 211 | transpose=False, 212 | nonlinearity=F.relu, 213 | ): 214 | super(ConvNet1DLayer, self).__init__() 215 | self.channels = channels 216 | self.dilation = dilation 217 | self.causal = causal 218 | self.dropout_p = dropout_p 219 | self.dropout_type = dropout_type 220 | self.add_input_channels = None if add_input_channels == 0 else add_input_channels 221 | self.transpose = transpose 222 | self.nonlinearity = nonlinearity 223 | self.config = config 224 | 225 | self.generating = False 226 | self.generating_reset = False 227 | self._dropout2d_mask = None 228 | 229 | if config not in self.configurations: 230 | raise HyperparameterError(f"Unknown configuration: '{config}'. Accepts {self.configurations}") 231 | if dropout_type not in self.dropout_types: 232 | raise HyperparameterError(f"Unknown dropout type: '{dropout_type}'. 
Accepts {self.dropout_types}") 233 | 234 | if self.add_input_channels is not None: 235 | input_channels = channels + self.add_input_channels 236 | else: 237 | input_channels = channels 238 | 239 | self.layernorm_1 = LayerChannelNorm(input_channels) 240 | self.layernorm_2 = LayerChannelNorm(channels) 241 | if config == 'standard': 242 | self.layernorm_3 = LayerChannelNorm(channels) 243 | else: 244 | self.register_parameter('layernorm_3', None) 245 | 246 | self.mix_conv_1 = Conv2d(input_channels, channels) 247 | self.dilated_conv = Conv2d( 248 | channels, channels, 249 | kernel_width=(1, 3), 250 | dilation=(1, dilation), 251 | causal=causal, 252 | bias_init=0.0, 253 | ) 254 | self.mix_conv_3 = Conv2d(channels, channels) 255 | 256 | if self.dropout_type == 'independent': 257 | self.dropout = nn.Dropout(p=dropout_p) 258 | elif self.dropout_type == '2D': 259 | self.dropout = nn.Dropout2d(p=dropout_p) # TODO test performance with Dropout2d 260 | 261 | if self.config == 'original': 262 | self.operations = [ 263 | self.layernorm_1, self.mix_conv_1, self.nonlinearity, 264 | self.dilated_conv, self.nonlinearity, 265 | self.mix_conv_3, self.nonlinearity, 266 | self.dropout, self.layernorm_2, 267 | ] 268 | elif self.config == 'updated': 269 | self.operations = [ 270 | self.layernorm_1, self.mix_conv_1, self.nonlinearity, 271 | self.dilated_conv, self.nonlinearity, 272 | self.mix_conv_3, self.nonlinearity, 273 | self.layernorm_2, self.dropout, 274 | ] 275 | elif self.config == 'standard': 276 | self.operations = [ 277 | self.layernorm_1, self.nonlinearity, self.mix_conv_1, 278 | self.layernorm_2, self.nonlinearity, self.dilated_conv, 279 | self.layernorm_3, self.nonlinearity, self.mix_conv_3, 280 | self.dropout, 281 | ] 282 | 283 | def generate(self, mode=True): 284 | self.generating = mode 285 | self.generating_reset = True 286 | self._dropout2d_mask = None 287 | for module in self.children(): 288 | if hasattr(module, "generate") and callable(module.generate): 289 | 
module.generate(mode) 290 | return self 291 | 292 | def weight_costs(self): 293 | return ( 294 | self.mix_conv_1.weight_costs() + 295 | self.dilated_conv.weight_costs() + 296 | self.mix_conv_3.weight_costs() 297 | ) 298 | 299 | def forward(self, inputs, input_masks, additional_input=None): 300 | """ 301 | :param inputs: Tensor(N, C, 1, L) 302 | :param input_masks: Tensor(N, 1, 1, L) 303 | :param additional_input: Tensor(N, C_add, 1, L) 304 | :return: Tensor(N, C, 1, L) 305 | """ 306 | if self.generating: 307 | return self.forward_generate(inputs, input_masks, additional_input) 308 | if self.add_input_channels is not None: 309 | delta_layer = torch.cat([inputs, additional_input], dim=1) 310 | else: 311 | delta_layer = inputs 312 | 313 | for op in self.operations: 314 | delta_layer = op(delta_layer) 315 | 316 | return delta_layer 317 | 318 | def forward_generate(self, inputs, input_masks, additional_input=None): 319 | """ 320 | :param inputs: Tensor(N, C, 1, L) initialization, or Tensor(N, C, 1, 1) afterwards 321 | :param input_masks: Tensor(N, 1, 1, >=L) 322 | :param additional_input: Tensor(N, C_add, 1, >=L) 323 | :return: Tensor(N, C, 1, L) 324 | """ 325 | if self.add_input_channels is not None: 326 | delta_layer = torch.cat([inputs, additional_input[:, :, :, 0:inputs.size(3)]], dim=1) 327 | else: 328 | delta_layer = inputs 329 | 330 | if self.generating_reset: 331 | self.generating_reset = False 332 | if self.training and self.dropout_type == '2D' and self._dropout2d_mask is None: 333 | p = 1 - self.dropout_p 334 | self._dropout2d_mask = torch.bernoulli(torch.full((1, self.channels, 1, 1), p, dtype=inputs.dtype, device=inputs.device)) / p 335 | 336 | for op in self.operations: 337 | if op is self.dropout and self.training and self.dropout_type == '2D': 338 | delta_layer = delta_layer * self._dropout2d_mask 339 | else: 340 | delta_layer = op(delta_layer) 341 | return delta_layer 342 | 343 | def extra_repr(self): 344 | return '{channels}, dilation={dilation}, causal={causal}, config={config}, ' \ 345 |
'add_input_channels={add_input_channels}'.format(**self.__dict__) 346 | 347 | 348 | class ConvNet1D(nn.Module): 349 | additional_input_layers = ['all', 'first'] 350 | 351 | def __init__( 352 | self, 353 | channels=48, 354 | layers=9, 355 | dropout_p=0.5, 356 | dropout_type='independent', 357 | causal=True, 358 | config='original', 359 | add_input_channels=None, 360 | add_input_layer=None, # 'all', 'first' 361 | dilation_schedule=None, 362 | transpose=False, 363 | nonlinearity=F.elu, 364 | ): 365 | super(ConvNet1D, self).__init__() 366 | self.channels = channels 367 | self.num_layers = layers 368 | self.causal = causal 369 | self.dropout_p = dropout_p 370 | self.dropout_type = dropout_type 371 | self.transpose = transpose 372 | self.nonlinearity = nonlinearity 373 | self.add_input_channels = add_input_channels 374 | self.add_input_layer = add_input_layer 375 | self.config = config 376 | 377 | if add_input_layer is not None and add_input_layer not in self.additional_input_layers: 378 | raise HyperparameterError(f"Unknown additional input layer: '{add_input_layer}'. 
" 379 | f"Accepts {self.additional_input_layers}") 380 | 381 | if dilation_schedule is None: 382 | self.dilations = [2 ** i for i in range(layers)] 383 | else: 384 | self.dilations = dilation_schedule 385 | 386 | self.dilation_layers = nn.ModuleList() 387 | 388 | for i_layer, dilation in enumerate(self.dilations): 389 | add_input_c = None 390 | if self.add_input_layer == 'all' or (self.add_input_layer == 'first' and i_layer == 0): 391 | add_input_c = add_input_channels 392 | 393 | self.dilation_layers.append(ConvNet1DLayer( 394 | channels=channels, dilation=dilation, dropout_p=dropout_p, dropout_type=dropout_type, causal=causal, 395 | config=config, add_input_channels=add_input_c, transpose=transpose, nonlinearity=nonlinearity 396 | )) 397 | 398 | if causal: 399 | self.receptive_field = 2 ** (layers-1) 400 | else: 401 | self.receptive_field = 2 ** layers - 1 402 | 403 | def generate(self, mode=True): 404 | for module in self.dilation_layers: 405 | if hasattr(module, "generate") and callable(module.generate): 406 | module.generate(mode) 407 | return self 408 | 409 | def weight_costs(self): 410 | return [cost for layer in self.dilation_layers for cost in layer.weight_costs()] 411 | 412 | def forward(self, inputs, input_masks, additional_input=None): 413 | """ 414 | :param inputs: Tensor(N, C, 1, L) 415 | :param input_masks: Tensor(N, 1, 1, L) 416 | :param additional_input: Tensor(N, C_add, 1, L) 417 | :return: Tensor(N, C, 1, L) 418 | """ 419 | up_layer = inputs 420 | 421 | for layer, dilation in enumerate(self.dilations): 422 | add_input = None 423 | if self.add_input_layer == 'all' or (self.add_input_layer == 'first' and layer == 0): 424 | add_input = additional_input 425 | 426 | delta_layer = self.dilation_layers[layer](up_layer, input_masks, add_input) 427 | up_layer = up_layer + delta_layer 428 | 429 | return up_layer 430 | 431 | def extra_repr(self): 432 | return '{channels}, layers={num_layers}, causal={causal}, config={config}, ' \ 433 | 
'add_input_channels={add_input_channels}'.format(**self.__dict__) 434 | -------------------------------------------------------------------------------- /src/seqdesign_pt/model_logging.py: -------------------------------------------------------------------------------- 1 | # code referenced from https://github.com/vincentherrmann/pytorch-wavenet/blob/master/model_logging.py 2 | import threading 3 | from io import BytesIO 4 | import time 5 | 6 | import numpy as np 7 | from PIL import Image 8 | import torch 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | 12 | class Accumulator: 13 | def __init__(self, *keys): 14 | self._values = {key: 0. for key in keys} 15 | self.log_interval = 0 16 | 17 | def accumulate(self, **kwargs): 18 | for key in kwargs: 19 | self._values[key] += kwargs[key] 20 | self.log_interval += 1 21 | 22 | def reset(self): 23 | for key in self._values: 24 | self._values[key] = 0. 25 | self.log_interval = 0 26 | 27 | @property 28 | def values(self): 29 | return {key: value / self.log_interval for key, value in self._values.items()} 30 | 31 | def __getattr__(self, item): 32 | return self._values[item] / self.log_interval 33 | 34 | 35 | class Logger: 36 | def __init__(self, 37 | log_interval=50, 38 | validation_interval=200, 39 | generate_interval=500, 40 | info_interval=1000, 41 | trainer=None, 42 | generate_function=None): 43 | self.trainer = trainer 44 | self.log_interval = log_interval 45 | self.val_interval = validation_interval 46 | self.gen_interval = generate_interval 47 | self.info_interval = info_interval 48 | self.log_time = time.time() 49 | self.load_time = 0. 
50 | self.accumulator = Accumulator('loss', 'ce_loss', 'bitperchar') 51 | self.generate_function = generate_function 52 | if self.generate_function is not None: 53 | self.generate_thread = threading.Thread(target=self.generate_function) 54 | self.generate_thread.daemon = True 55 | 56 | def log(self, current_step, current_losses, current_grad_norm, load_time=0.): 57 | self.load_time += load_time 58 | self.accumulator.accumulate( 59 | loss=float(current_losses['loss'].detach()), 60 | ce_loss=float(current_losses['ce_loss'].detach()), 61 | bitperchar=float(current_losses['bitperchar'].detach()) if 'bitperchar' in current_losses else 0., 62 | ) 63 | 64 | if current_step % self.log_interval == 0 or current_step < 10: 65 | self.log_loss(current_step) 66 | self.log_time = time.time() 67 | self.load_time = 0. 68 | self.accumulator.reset() 69 | if self.val_interval is not None and self.val_interval > 0 and current_step % self.val_interval == 0: 70 | self.validate(current_step) 71 | if self.gen_interval is not None and self.gen_interval > 0 and current_step % self.gen_interval == 0: 72 | self.generate(current_step) 73 | if self.info_interval is not None and self.info_interval > 0 and current_step % self.info_interval == 0: 74 | self.info(current_step) 75 | 76 | def log_loss(self, current_step): 77 | v = self.accumulator.values 78 | print(f"{time.time() - self.log_time:7.3f} {self.load_time:7.3f} " 79 | f"loss, ce_loss, bitperchar at step {current_step:8d}: " 80 | f"{v['loss']:11.6f}, {v['ce_loss']:11.6f}, {v['bitperchar']:10.6f}", flush=True) 81 | 82 | def validate(self, current_step): 83 | avg_loss, avg_accuracy = self.trainer.validate() 84 | print("validation loss: " + str(avg_loss), flush=True) 85 | print("validation accuracy: " + str(avg_accuracy * 100) + "%", flush=True) 86 | 87 | def generate(self, current_step): 88 | if self.generate_function is None: 89 | return 90 | 91 | if self.generate_thread.is_alive(): 92 | print("Last generate is still running, skipping this
one") 93 | else: 94 | self.generate_thread = threading.Thread(target=self.generate_function, args=[current_step]) 95 | self.generate_thread.daemon = True 96 | self.generate_thread.start() 97 | 98 | def info(self, current_step): 99 | pass 100 | # print( 101 | # 'GPU Mem Allocated:', 102 | # round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 103 | # 'GB, ', 104 | # 'Cached:', 105 | # round(torch.cuda.memory_cached(0) / 1024 ** 3, 1), 106 | # 'GB' 107 | # ) 108 | 109 | 110 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 111 | class TensorboardLogger(Logger): 112 | def __init__(self, 113 | log_interval=50, 114 | validation_interval=200, 115 | generate_interval=500, 116 | info_interval=1000, 117 | trainer=None, 118 | generate_function=None, 119 | log_dir='logs', 120 | log_param_histograms=False, 121 | log_image_summaries=True, 122 | print_output=False, 123 | ): 124 | super().__init__( 125 | log_interval, validation_interval, generate_interval, info_interval, trainer, generate_function) 126 | self.writer = SummaryWriter(log_dir) 127 | self.log_param_histograms = log_param_histograms 128 | self.log_image_summaries = log_image_summaries 129 | self.print_output = print_output 130 | 131 | def log(self, current_step, current_losses, current_grad_norm, load_time=0.): 132 | super(TensorboardLogger, self).log(current_step, current_losses, current_grad_norm, load_time) 133 | self.scalar_summary('grad norm', current_grad_norm, current_step) 134 | self.scalar_summary('loss', current_losses['loss'].detach(), current_step) 135 | self.scalar_summary('ce_loss', current_losses['ce_loss'].detach(), current_step) 136 | if 'accuracy' in current_losses: 137 | self.scalar_summary('accuracy', current_losses['accuracy'].detach(), current_step) 138 | if 'bitperchar' in current_losses: 139 | self.scalar_summary('bitperchar', current_losses['bitperchar'].detach(), current_step) 140 | self.scalar_summary('reconstruction loss', 
current_losses['ce_loss'].detach(), current_step) 141 | self.scalar_summary('regularization loss', current_losses['weight_cost'].detach(), current_step) 142 | 143 | def log_loss(self, current_step): 144 | if self.print_output: 145 | Logger.log_loss(self, current_step) 146 | # loss 147 | v = self.accumulator.values 148 | avg_loss, avg_ce_loss, avg_bitperchar = v['loss'], v['ce_loss'], v['bitperchar'] 149 | self.scalar_summary('avg loss', avg_loss, current_step) 150 | self.scalar_summary('avg ce loss', avg_ce_loss, current_step) 151 | self.scalar_summary('avg bitperchar', avg_bitperchar, current_step) 152 | 153 | if self.log_param_histograms: 154 | for tag, value, in self.trainer.model.named_parameters(): 155 | tag = tag.replace('.', '/') 156 | self.histo_summary(tag, value.data, current_step) 157 | if value.grad is not None: 158 | self.histo_summary(tag + '/grad', value.grad.data, current_step) 159 | 160 | if self.log_image_summaries: 161 | for tag, summary in self.trainer.model.image_summaries.items(): 162 | self.image_summary(tag, summary['img'], current_step, max_outputs=summary.get('max_outputs', 3)) 163 | 164 | def validate(self, current_step): 165 | avg_loss, avg_accuracy = self.trainer.validate() 166 | self.scalar_summary('validation loss', avg_loss, current_step) 167 | self.scalar_summary('validation accuracy', avg_accuracy, current_step) 168 | if self.print_output: 169 | print("validation loss: " + str(avg_loss), flush=True) 170 | print("validation accuracy: " + str(avg_accuracy * 100) + "%", flush=True) 171 | 172 | def scalar_summary(self, tag, value, step): 173 | """Log a scalar variable.""" 174 | if isinstance(value, torch.Tensor): 175 | value = value.item() # value must have 1 element only 176 | self.writer.add_scalar(tag, value, global_step=step) 177 | 178 | def image_summary(self, tag, images, step, max_outputs=3): 179 | """Log a tensor image. 
180 | :param tag: string summary name 181 | :param images: (N, H, W, C) or (N, H, W) 182 | :param step: current step 183 | :param max_outputs: max N images to save 184 | """ 185 | 186 | images = images[:max_outputs] 187 | format = "NHW" if images.dim() == 3 else "NHWC" 188 | self.writer.add_images(tag, images, global_step=step, dataformats=format) 189 | 190 | def histo_summary(self, tag, values, step, bins=200): 191 | """Log a histogram of the tensor of values.""" 192 | self.writer.add_histogram(tag, values, global_step=step, bins=bins) 193 | -------------------------------------------------------------------------------- /src/seqdesign_pt/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronkollasch/seqdesign-pytorch/f9391b2f4689827d7d2993d00830fbe2f406676a/src/seqdesign_pt/scripts/__init__.py -------------------------------------------------------------------------------- /src/seqdesign_pt/scripts/calc_logprobs_seqs_fr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | try: 3 | import tensorflow as tf 4 | from seqdesign_pt.tf_reader import TFReader 5 | except ImportError: 6 | TFReader = None 7 | class tf: 8 | __version__ = None 9 | import numpy as np 10 | import pandas as pd 11 | import argparse 12 | import sys 13 | import os 14 | import glob 15 | import torch 16 | from seqdesign_pt import autoregressive_model 17 | from seqdesign_pt import autoregressive_train 18 | from seqdesign_pt import utils 19 | from seqdesign_pt import aws_utils 20 | from seqdesign_pt import data_loaders 21 | from seqdesign_pt.version import VERSION 22 | 23 | 24 | def main(): 25 | working_dir = '..' 
26 | 27 | parser = argparse.ArgumentParser(description="Calculate the log probability of mutated sequences.") 28 | parser.add_argument("--channels", type=int, default=48, metavar='C', help="Number of channels.") 29 | parser.add_argument("--r-seed", type=int, default=-1, metavar='RSEED', help="Random seed.") 30 | parser.add_argument("--num-samples", type=int, default=1, metavar='N', help="Number of iterations to run the model.") 31 | parser.add_argument("--minibatch-size", type=int, default=100, metavar='B', help="Minibatch size for inferring effect prediction.") 32 | parser.add_argument("--dropout-p", type=float, default=0., metavar='P', help="Dropout p while sampling log p(x).") 33 | parser.add_argument("--sess", type=str, default='', help="Session folder name for restoring a model.", required=True) 34 | parser.add_argument("--checkpoint", type=int, default=None, metavar='CKPT', help="Checkpoint step number.") 35 | parser.add_argument("--input", type=str, default='', help="Directory and filename of the input data.", required=True) 36 | parser.add_argument("--output", type=str, default='', help="Directory and filename of the output data.", required=True) 37 | parser.add_argument("--save-logits", action='store_true', help="Save logprobs matrices.") 38 | parser.add_argument("--save-ce", action='store_true', help="Save cross entropy matrices.") 39 | parser.add_argument("--alphabet-type", type=str, default='protein', metavar='T', help="Alphabet to use for the dataset.", required=False) 40 | parser.add_argument("--s3-path", type=str, default='', help="Base s3:// path (leave blank to disable syncing).") 41 | parser.add_argument("--s3-project", type=str, default=VERSION, help="Project name (subfolder of s3-path).") 42 | parser.add_argument("--num-data-workers", type=int, default=0, help="Number of workers to load data") 43 | parser.add_argument("--no-cuda", action='store_true', help="Disable GPU evaluation") 44 | parser.add_argument("--from-tf", action='store_true',
help="Load model from tensorflow checkpoint") 45 | 46 | args = parser.parse_args() 47 | 48 | print(args) 49 | 50 | print('Call:', ' '.join(sys.argv)) 51 | print("OS: ", sys.platform) 52 | print("Python: ", sys.version) 53 | print("PyTorch: ", torch.__version__) 54 | print("TensorFlow: ", tf.__version__) 55 | print("Numpy: ", np.__version__) 56 | 57 | use_cuda = not args.no_cuda 58 | device = torch.device("cuda:0" if use_cuda and torch.cuda.is_available() else "cpu") 59 | print('Using device:', device) 60 | if device.type == 'cuda': 61 | print(torch.cuda.get_device_name(0)) 62 | print('Memory Usage:') 63 | print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3, 1), 'GB') 64 | print('Cached: ', round(torch.cuda.memory_cached(0)/1024**3, 1), 'GB') 65 | print(utils.get_cuda_version()) 66 | print("CuDNN Version ", utils.get_cudnn_version()) 67 | 68 | print("SeqDesign-PyTorch git hash:", str(utils.get_github_head_hash())) 69 | print() 70 | 71 | sess_name = args.sess 72 | input_filename = args.input 73 | output_filename = args.output 74 | 75 | aws_util = aws_utils.AWSUtility(s3_project=args.s3_project, s3_base_path=args.s3_path) if args.s3_path else None 76 | if not os.path.exists(input_filename) and input_filename.startswith('input/') and aws_util is not None: 77 | folder, filename = input_filename.rsplit('/', 1) 78 | if not aws_util.s3_get_file_grep( 79 | s3_folder=f'calc_logprobs/{folder}', 80 | dest_folder=f'{working_dir}/calc_logprobs/{folder}', 81 | search_pattern=f'{filename}', 82 | ): 83 | raise Exception("Could not download test data from S3.") 84 | 85 | dataset = data_loaders.FastaDataset( 86 | batch_size=args.minibatch_size, 87 | working_dir='.', 88 | dataset=args.input, 89 | matching=True, 90 | unlimited_epoch=False, 91 | alphabet_type=args.alphabet_type, 92 | ) 93 | loader = data_loaders.GeneratorDataLoader( 94 | dataset, 95 | num_workers=args.num_data_workers, 96 | pin_memory=True, 97 | ) 98 | print("Read in test data.") 99 | 100 | if 
args.checkpoint is None: # look for old-style flat session file structure 101 | glob_path = f"{working_dir}/sess/{sess_name}*" 102 | grep_path = f'{sess_name}.*' 103 | sess_namedir = f"{working_dir}/sess/{sess_name}" 104 | else: # look for new folder-based session file structure 105 | glob_path = f"{working_dir}/sess/{sess_name}/{sess_name}.ckpt-{args.checkpoint}*" 106 | grep_path = f'{sess_name}.ckpt-{args.checkpoint}.*' 107 | sess_namedir = f"{working_dir}/sess/{sess_name}/{sess_name}.ckpt-{args.checkpoint}" 108 | 109 | if not glob.glob(glob_path) and aws_util: 110 | if not aws_util.s3_get_file_grep( 111 | f'sess/{sess_name}', 112 | f'{working_dir}/sess/{sess_name}', 113 | grep_path, 114 | ): 115 | raise Exception("Could not download session files from S3.") 116 | 117 | print("Initializing and loading variables") 118 | if args.from_tf: 119 | if not TFReader: 120 | print("Trying to read tensorflow model but tensorflow could not be imported.") 121 | exit(1) 122 | reader = TFReader(sess_namedir) 123 | legacy_version = reader.get_checkpoint_legacy_version() 124 | last_dilation_size = 200 if legacy_version == 0 else 256 125 | hyperparams = {'encoder': {'dilation_schedule': [1, 2, 4, 8, 16, 32, 64, 128, last_dilation_size], 126 | "config": "original", 127 | 'dropout_type': 'independent'}} 128 | model = autoregressive_model.AutoregressiveFR( 129 | channels=args.channels, dropout_p=1-args.dropout_p, hyperparams=hyperparams) 130 | reader.load_autoregressive_fr(model) 131 | else: 132 | checkpoint = torch.load(sess_namedir+'.pth', map_location='cpu') 133 | dims = checkpoint['model_dims'] 134 | hyperparams = checkpoint['model_hyperparams'] 135 | model = autoregressive_model.AutoregressiveFR(dims=dims, hyperparams=hyperparams, dropout_p=1-args.dropout_p) 136 | model.load_state_dict(checkpoint['model_state_dict']) 137 | model.to(device) 138 | print("Num parameters:", model.parameter_count()) 139 | 140 | trainer = autoregressive_train.AutoregressiveTrainer( 141 | model=model, 
142 | data_loader=None, 143 | device=device, 144 | ) 145 | output = trainer.test(loader, model_eval=False, num_samples=args.num_samples, return_logits=args.save_logits, return_ce=args.save_ce) 146 | if args.save_logits or args.save_ce: 147 | output, logits = output 148 | logits_path = os.path.splitext(args.output)[0] 149 | os.makedirs(logits_path, exist_ok=True) 150 | for key, value in logits.items(): 151 | np.save(f"{logits_path}/{key}.npy", value) 152 | 153 | os.makedirs(os.path.dirname(output_filename) or '.', exist_ok=True) 154 | output = pd.DataFrame(output, columns=output.keys()) 155 | output.to_csv(args.output, index=False) 156 | print("Done!") 157 | 158 | if output_filename.startswith('output/') and aws_util: 159 | aws_util.s3_cp( 160 | local_file=output_filename, 161 | s3_file=f'calc_logprobs/output/{output_filename.replace("output/", "")}', 162 | destination='s3' 163 | ) 164 | 165 | 166 | if __name__ == "__main__": 167 | main() 168 | -------------------------------------------------------------------------------- /src/seqdesign_pt/scripts/generate_sample_seqs_fr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | try: 3 | import tensorflow as tf 4 | from seqdesign_pt.tf_reader import TFReader 5 | except ImportError: 6 | TFReader = None 7 | class tf: 8 | __version__ = None 9 | import numpy as np 10 | import time 11 | import sys 12 | import os 13 | import argparse 14 | import glob 15 | import torch 16 | import torch.distributions as dist 17 | from seqdesign_pt import autoregressive_model 18 | from seqdesign_pt import utils 19 | from seqdesign_pt import aws_utils 20 | from seqdesign_pt import data_loaders 21 | from seqdesign_pt.version import VERSION 22 | 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser(description="Generate novel sequences sampled from the model.") 26 | parser.add_argument("--sess", type=str, required=True, help="Session name for restoring a model.") 27 |
parser.add_argument("--checkpoint", type=int, default=None, metavar='CKPT', help="Checkpoint step number.") 28 | parser.add_argument("--channels", type=int, default=48, metavar='C', help="Number of channels.") 29 | parser.add_argument("--r-seed", type=int, default=42, help="Random seed.") 30 | parser.add_argument("--temp", type=float, default=1.0, help="Generation temperature.") 31 | parser.add_argument("--batch-size", type=int, default=500, help="Number of sequences per generation batch.") 32 | parser.add_argument("--num-batches", type=int, default=1000000, help="Number of batches to generate.") 33 | parser.add_argument("--max-steps", type=int, default=50, help="Maximum number of decoding steps per batch.") 34 | parser.add_argument("--fast-generation", action='store_true', help="Use fast generation mode.") 35 | parser.add_argument("--input-seq", type=str, default='default', help="Path to file with starting sequence.") 36 | parser.add_argument("--output-prefix", type=str, default='nanobody', help="Prefix for output fasta file.") 37 | parser.add_argument("--alphabet-type", type=str, default='protein', metavar='T', help="Alphabet to use for the dataset.", required=False) 38 | parser.add_argument("--s3-path", type=str, default='', help="Base s3:// path (leave blank to disable syncing).") 39 | parser.add_argument("--s3-project", type=str, default=VERSION, help="Project name (subfolder of s3-path).") 40 | parser.add_argument("--no-cuda", action='store_true', help="Disable GPU evaluation") 41 | parser.add_argument("--from-tf", action='store_true', help="Load model from tensorflow checkpoint") 42 | 43 | args = parser.parse_args() 44 | 45 | print(args) 46 | 47 | print('Call:', ' '.join(sys.argv)) 48 | print("OS: ", sys.platform) 49 | print("Python: ", sys.version) 50 | print("PyTorch: ", torch.__version__) 51 | print("TensorFlow: ", tf.__version__) 52 | print("Numpy: ", np.__version__) 53 | 54 | use_cuda = not args.no_cuda 55 | device = torch.device("cuda:0" if use_cuda 
and torch.cuda.is_available() else "cpu") 56 | print('Using device:', device) 57 | if device.type == 'cuda': 58 | print(torch.cuda.get_device_name(0)) 59 | print('Memory Usage:') 60 | print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3, 1), 'GB') 61 | print('Cached: ', round(torch.cuda.memory_cached(0)/1024**3, 1), 'GB') 62 | print(utils.get_cuda_version()) 63 | print("CuDNN Version ", utils.get_cudnn_version()) 64 | 65 | print("SeqDesign-PyTorch git hash:", str(utils.get_github_head_hash())) 66 | print() 67 | 68 | aws_util = aws_utils.AWSUtility(s3_project=args.s3_project, s3_base_path=args.s3_path) if args.s3_path else None 69 | 70 | working_dir = "." 71 | 72 | # Variables for runtime modification 73 | sess_name = args.sess 74 | batch_size = args.batch_size 75 | num_batches = args.num_batches 76 | temp = args.temp 77 | r_seed = args.r_seed 78 | 79 | print(r_seed) 80 | 81 | np.random.seed(r_seed) 82 | torch.manual_seed(args.r_seed) 83 | torch.cuda.manual_seed_all(args.r_seed) 84 | 85 | dataset = data_loaders.SequenceDataset(alphabet_type=args.alphabet_type, output_types='decoder', matching=False) 86 | 87 | os.makedirs(os.path.join(working_dir, 'generate_sequences', 'generated'), exist_ok=True) 88 | output_filename = ( 89 | f"{working_dir}/generate_sequences/generated/" 90 | f"{args.output_prefix}_start-{args.input_seq.split('/')[-1].split('.')[0]}" 91 | f"_temp-{temp}_param-{sess_name}_ckpt-{args.checkpoint}_rseed-{r_seed}.fa" 92 | ) 93 | with open(output_filename, "w") as output: 94 | output.write('') 95 | 96 | # Provide the starting sequence to use for generation 97 | if args.input_seq == 'default': 98 | input_seq = "EVQLVESGGGLVQAGGSLRLSCAASGFTFSSYAMGWYRQAPGKEREFVAAISWSGGSTYYADSVKGRFTISRDNAKNTVYLQMNSLKPEDTAVYYC" 99 | elif args.input_seq == 'empty': 100 | input_seq = "" 101 | else: 102 | if not os.path.exists(args.input_seq) and aws_util: 103 | if '/' not in args.input_seq: 104 | args.input_seq = 
f'{working_dir}/generate_sequences/input/{args.input_seq}' 105 | aws_util.s3_get_file_grep( 106 | 'generate_sequences/input', 107 | f'{working_dir}/generate_sequences/input', 108 | f"{args.input_seq.rsplit('/', 1)[-1]}" 109 | ) 110 | with open(args.input_seq) as f: 111 | input_seq = f.read() 112 | input_seq = input_seq.strip() 113 | 114 | if args.checkpoint is None: # look for old-style flat session file structure 115 | glob_path = f"{working_dir}/sess/{sess_name}*" 116 | grep_path = f'{sess_name}.*' 117 | sess_namedir = f"{working_dir}/sess/{sess_name}" 118 | else: # look for new folder-based session file structure 119 | glob_path = f"{working_dir}/sess/{sess_name}/{sess_name}.ckpt-{args.checkpoint}*" 120 | grep_path = f'{sess_name}.ckpt-{args.checkpoint}.*' 121 | sess_namedir = f"{working_dir}/sess/{sess_name}/{sess_name}.ckpt-{args.checkpoint}" 122 | 123 | if not glob.glob(glob_path) and aws_util: 124 | if not aws_util.s3_get_file_grep( 125 | f'sess/{sess_name}', 126 | f'{working_dir}/sess/{sess_name}', 127 | grep_path, 128 | ): 129 | raise Exception("Could not download session files from S3.") 130 | 131 | print("Initializing and loading variables") 132 | if args.from_tf: 133 | if not TFReader: 134 | print("Trying to read tensorflow model but tensorflow could not be imported.") 135 | exit(1) 136 | reader = TFReader(sess_namedir) 137 | legacy_version = reader.get_checkpoint_legacy_version() 138 | last_dilation_size = 200 if legacy_version == 0 else 256 139 | hyperparams = {'encoder': {'dilation_schedule': [1, 2, 4, 8, 16, 32, 64, 128, last_dilation_size], 140 | "config": "original", 141 | 'dropout_type': 'independent'}} 142 | model: autoregressive_model.AutoregressiveFR = autoregressive_model.AutoregressiveFR( 143 | channels=args.channels, dropout_p=0., hyperparams=hyperparams) 144 | reader.load_autoregressive_fr(model) 145 | else: 146 | checkpoint = torch.load(sess_namedir+'.pth', map_location='cpu') 147 | dims = checkpoint['model_dims'] 148 | hyperparams = 
checkpoint['model_hyperparams'] 149 | model: autoregressive_model.AutoregressiveFR = autoregressive_model.AutoregressiveFR( 150 | dims=dims, hyperparams=hyperparams, dropout_p=0. 151 | ) 152 | model.load_state_dict(checkpoint['model_state_dict']) 153 | 154 | model: autoregressive_model.Autoregressive = model.model.model_f # use only forward model 155 | model.to(device) 156 | print("Num parameters:", model.parameter_count()) 157 | 158 | model.eval() 159 | for i in range(num_batches): 160 | start = time.time() 161 | 162 | input_seq_list = batch_size * [input_seq] 163 | batch = dataset.sequences_to_onehot(input_seq_list) 164 | seq_in = batch['decoder_input'].to(device) 165 | completion = torch.zeros(batch_size).to(device) 166 | 167 | for step in range(args.max_steps): 168 | with torch.no_grad(): 169 | model.generate(args.fast_generation) 170 | seq_logits = model.forward(seq_in, None) 171 | output_logits = seq_logits[:, :, 0, -1] * args.temp 172 | output = dist.OneHotCategorical(logits=output_logits).sample() 173 | completion += output[:, dataset.aa_dict['*']] 174 | seq_in = torch.cat([seq_in, output.unsqueeze(-1).unsqueeze(-1)], dim=3) 175 | 176 | if (completion > 0).all(): 177 | break 178 | 179 | batch_seqs = seq_in.argmax(1).squeeze(1).cpu().numpy() # (batch, length); squeeze(1) keeps the batch axis when batch_size == 1 180 | with open(output_filename, "a") as output: 181 | for idx_seq in range(batch_size): 182 | batch_seq = ''.join([dataset.idx_to_aa[idx] for idx in batch_seqs[idx_seq]]) 183 | out_seq = "" 184 | end_seq = False 185 | for idx_aa, aa in enumerate(batch_seq): 186 | if idx_aa != 0: 187 | if end_seq is False: 188 | out_seq += aa 189 | if aa == "*": 190 | end_seq = True 191 | output.write(f">{int(batch_size*i+idx_seq)}\n{out_seq}\n") 192 | 193 | print(f"Batch {i+1} done in {time.time()-start:0.4f} s") 194 | 195 | if aws_util: 196 | aws_util.s3_cp( 197 | local_file=output_filename, 198 | s3_file=f'generate_sequences/generated/{output_filename.rsplit("/", 1)[1]}', 199 | destination='s3' 200 | ) 201 | 202 | 203 | if __name__ ==
"__main__": 204 | main() 205 | -------------------------------------------------------------------------------- /src/seqdesign_pt/scripts/run_autoregressive_fr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import os 4 | import time 5 | from datetime import timedelta 6 | import sys 7 | import json 8 | import glob 9 | import numpy as np 10 | import torch 11 | from seqdesign_pt.version import VERSION 12 | from seqdesign_pt import data_loaders 13 | from seqdesign_pt import autoregressive_train 14 | from seqdesign_pt import autoregressive_model 15 | from seqdesign_pt import model_logging 16 | from seqdesign_pt import utils 17 | from seqdesign_pt import aws_utils 18 | 19 | 20 | def main(working_dir='.'): 21 | start_run_time = time.time() 22 | 23 | parser = argparse.ArgumentParser(description="Train an autoregressive model on a collection of sequences.") 24 | parser.add_argument("--s3-path", type=str, default='', 25 | help="Base s3:// path (leave blank to disable syncing).") 26 | parser.add_argument("--s3-project", type=str, default=VERSION, metavar='P', 27 | help="Project name (subfolder of s3-path).") 28 | parser.add_argument("--run-name-prefix", type=str, default=None, metavar='P', 29 | help="Prefix for run name.") 30 | parser.add_argument("--channels", type=int, default=48, metavar='C', 31 | help="Number of channels.") 32 | parser.add_argument("--batch-size", type=int, default=30, 33 | help="Batch size.") 34 | parser.add_argument("--num-iterations", type=int, default=250005, metavar='N', 35 | help="Number of iterations to run the model.") 36 | parser.add_argument("--snapshot-interval", type=int, default=None, metavar='N', 37 | help="Take a snapshot every N iterations.") 38 | parser.add_argument("--dataset", type=str, default=None, required=True, 39 | help="Dataset name for fitting model. 
Alignment weights must be computed beforehand.") 40 | parser.add_argument("--restore", type=str, default='', 41 | help="Session name for restoring a model to continue training.") 42 | parser.add_argument("--gpu", type=int, default=0, 43 | help="Which gpu to use. Usually 0, 1, 2, etc...") 44 | parser.add_argument("--r-seed", type=int, default=42, metavar='RSEED', 45 | help="Random seed for parameter initialization") 46 | parser.add_argument("--dropout-p", type=float, default=0.5, 47 | help="Dropout probability (drop rate, not keep rate)") 48 | parser.add_argument("--alphabet-type", type=str, default='protein', metavar='T', 49 | help="Type of data to model. Options = [protein, DNA, RNA]") 50 | parser.add_argument("--num-data-workers", type=int, default=0, 51 | help="Number of workers to load data") 52 | parser.add_argument("--no-cuda", action='store_true', 53 | help="Disable GPU evaluation") 54 | args = parser.parse_args() 55 | 56 | ######################## 57 | # MAKE RUN DESCRIPTORS # 58 | ######################## 59 | 60 | dataset_name = args.dataset.rsplit('/', 1)[-1] 61 | if dataset_name.endswith(".fa"): 62 | dataset_name = dataset_name[:-len(".fa")] 63 | elif dataset_name.endswith(".fasta"): 64 | dataset_name = dataset_name[:-len(".fasta")] 65 | if args.restore == '': 66 | folder_time = ( 67 | f"{dataset_name}_{args.s3_project}_channels-{args.channels}" 68 | f"_rseed-{args.r_seed}_{time.strftime('%y%b%d_%I%M%p', time.gmtime())}" 69 | ) 70 | if args.run_name_prefix is not None: 71 | folder_time = args.run_name_prefix + '_' + folder_time 72 | else: 73 | folder_time = args.restore.split('/')[-1] 74 | folder_time = folder_time.split('.ckpt')[0] 75 | 76 | folder = f"{working_dir}/sess/{folder_time}" 77 | os.makedirs(folder, exist_ok=True) 78 | log_f = utils.Tee(f'{folder}/log.txt', 'a') # log stdout to log.txt 79 | 80 | if not args.restore: 81 | restore_args = " \\\n ".join(sys.argv[1:]) 82 | restore_args += f" \\\n --restore {{restore}}" 83 | else: 84 | 
restore_args = sys.argv[1:] 85 | restore_args = " \\\n ".join(restore_args[:restore_args.index('--restore') + 1] + ['{restore}'] + restore_args[restore_args.index('--restore') + 2:]) 86 | 87 | sbatch_executable = f"""#!/bin/bash 88 | #SBATCH -c 4 # Request four cores 89 | #SBATCH -N 1 # Request one node (if you request more than one core with -c, also using 90 | # -N 1 means all cores will be on the same node) 91 | #SBATCH -t 2-11:59 # Runtime in D-HH:MM format 92 | #SBATCH -p gpu # Partition to run in 93 | #SBATCH --gres=gpu:1 94 | #SBATCH --mem=30G # Total memory (30 GB, shared by all cores) 95 | #SBATCH -o slurm_files/slurm-%j.out # File to which STDOUT + STDERR will be written, including job ID in filename 96 | hostname 97 | pwd 98 | module load gcc/6.2.0 cuda/9.0 99 | srun stdbuf -oL -eL {sys.executable} \\ 100 | {sys.argv[0]} \\ 101 | {restore_args} 102 | """ 103 | 104 | #################### 105 | # SET RANDOM SEEDS # 106 | #################### 107 | 108 | if args.restore: 109 | # prevent from repeating batches/seed when restoring at intermediate point 110 | # script is repeatable as long as restored at same step 111 | # assumes restore arg of *.ckpt-(int step), optionally followed by a file extension 112 | restore_ckpt = os.path.splitext(args.restore.split('.ckpt-')[-1])[0] 113 | r_seed = args.r_seed + int(restore_ckpt) 114 | r_seed = r_seed % (2 ** 32 - 1) # limit of np.random.seed 115 | else: 116 | r_seed = args.r_seed 117 | 118 | np.random.seed(r_seed) 119 | torch.manual_seed(r_seed) 120 | torch.cuda.manual_seed_all(r_seed) 121 | 122 | def _init_fn(worker_id): 123 | np.random.seed((r_seed + worker_id) % (2 ** 32 - 1)) 124 | 125 | ##################### 126 | # PRINT SYSTEM INFO # 127 | ##################### 128 | 129 | print(folder) 130 | print(args) 131 | 132 | print("OS: ", sys.platform) 133 | print("Python: ", sys.version) 134 | print("Numpy: ", np.__version__) 135 | 136 | use_cuda = not args.no_cuda 137 | device = torch.device(f"cuda:{args.gpu}" if use_cuda and torch.cuda.is_available() else "cpu") 138 | print('Using device:', device) 139 | if device.type == 'cuda': 140 |
print(torch.cuda.get_device_name(device)) 141 | print('Memory Usage:') 142 | print('Allocated:', round(torch.cuda.memory_allocated(device)/1024**3, 1), 'GB') 143 | print('Cached: ', round(torch.cuda.memory_cached(device)/1024**3, 1), 'GB') 144 | print(utils.get_cuda_version()) 145 | print("CuDNN Version ", utils.get_cudnn_version()) 146 | 147 | print("SeqDesign-PyTorch git hash:", str(utils.get_github_head_hash())) 148 | print() 149 | 150 | print("Run:", folder_time) 151 | 152 | ############# 153 | # LOAD DATA # 154 | ############# 155 | 156 | aws_util = aws_utils.AWSUtility(s3_base_path=args.s3_path, s3_project=args.s3_project) if args.s3_path else None 157 | # for now, we will make all the sequences have the same length of 158 | # encoded matrices, though this is wasteful 159 | if os.path.exists(args.dataset): 160 | filenames = [args.dataset] 161 | else: 162 | if args.dataset.endswith(".fa"): 163 | args.dataset = args.dataset[:-len(".fa")] 164 | filenames = glob.glob(f'{working_dir}/datasets/sequences/{args.dataset}*.fa') 165 | if not filenames and aws_util is not None: 166 | if not aws_util.s3_get_file_grep( 167 | s3_folder='datasets/sequences', 168 | dest_folder=f'{working_dir}/datasets/sequences/', 169 | search_pattern=f'{args.dataset}.*\\.fa', 170 | ): 171 | raise Exception("Could not download dataset files from S3.") 172 | filenames = glob.glob(f'{working_dir}/datasets/sequences/{args.dataset}*.fa') 173 | assert len(filenames) == 1 174 | 175 | dataset = data_loaders.SingleFamilyDataset( 176 | batch_size=args.batch_size, 177 | working_dir=working_dir, 178 | dataset=args.dataset, 179 | matching=True, 180 | unlimited_epoch=True, 181 | output_shape='NCHW', 182 | output_types='decoder', 183 | ) 184 | loader = data_loaders.GeneratorDataLoader( 185 | dataset, 186 | num_workers=args.num_data_workers, 187 | pin_memory=True, 188 | worker_init_fn=_init_fn 189 | ) 190 | 191 | ############## 192 | # LOAD MODEL # 193 | ############## 194 | 195 | if args.restore: 196 | 
print("Restoring model from:", args.restore) 197 | checkpoint = torch.load(args.restore, map_location='cpu' if device.type == 'cpu' else None) 198 | dims = checkpoint['model_dims'] 199 | hyperparams = checkpoint['model_hyperparams'] 200 | trainer_params = checkpoint['train_params'] 201 | model = autoregressive_model.AutoregressiveFR(dims=dims, hyperparams=hyperparams, dropout_p=args.dropout_p) 202 | else: 203 | checkpoint = args.restore 204 | trainer_params = None 205 | dims = {"alphabet": len(dataset.alphabet)} 206 | model = autoregressive_model.AutoregressiveFR(channels=args.channels, dropout_p=args.dropout_p, dims=dims) 207 | model.to(device) 208 | 209 | ################ 210 | # RUN TRAINING # 211 | ################ 212 | 213 | trainer = autoregressive_train.AutoregressiveTrainer( 214 | model=model, 215 | data_loader=loader, 216 | params=trainer_params, 217 | snapshot_path=working_dir + '/sess', 218 | snapshot_name=folder_time, 219 | snapshot_interval=args.num_iterations // 10 if args.snapshot_interval is None else args.snapshot_interval, 220 | snapshot_exec_template=sbatch_executable, 221 | device=device, 222 | # logger=model_logging.Logger(validation_interval=None), 223 | logger=model_logging.TensorboardLogger( 224 | log_interval=500, 225 | validation_interval=None, 226 | generate_interval=None, 227 | log_dir=working_dir + '/log/' + folder_time, 228 | print_output=True, 229 | ) 230 | ) 231 | if args.restore: 232 | trainer.load_state(checkpoint) 233 | 234 | print() 235 | print("Model:", model.__class__.__name__) 236 | print("Hyperparameters:", json.dumps(model.hyperparams, indent=4)) 237 | print("Trainer:", trainer.__class__.__name__) 238 | print("Training parameters:", json.dumps( 239 | {key: value for key, value in trainer.params.items() if key != 'snapshot_exec_template'}, indent=4)) 240 | print("Dataset:", dataset.__class__.__name__) 241 | print("Dataset parameters:", json.dumps(dataset.params, indent=4)) 242 | print("Num trainable parameters:", 
model.parameter_count()) 243 | print(f"Training for {args.num_iterations - model.step} iterations") 244 | 245 | trainer.save_state() 246 | trainer.train(steps=args.num_iterations) 247 | 248 | print(f"Done! Total run time: {timedelta(seconds=time.time()-start_run_time)}") 249 | log_f.flush() 250 | if aws_util: 251 | aws_util.s3_sync(local_folder=folder, s3_folder=f'sess/{folder_time}/', destination='s3') 252 | 253 | if working_dir != '.': 254 | os.makedirs(f'{working_dir}/complete/', exist_ok=True) 255 | OUTPUT = open(f'{working_dir}/complete/{folder_time}.txt', 'w') 256 | OUTPUT.close() 257 | 258 | 259 | if __name__ == "__main__": 260 | main() 261 | -------------------------------------------------------------------------------- /src/seqdesign_pt/scripts/run_autoregressive_vae_fr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | import os 4 | import argparse 5 | import time 6 | import json 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from seqdesign_pt import data_loaders 12 | from seqdesign_pt import autoregressive_model 13 | from seqdesign_pt import autoregressive_train 14 | from seqdesign_pt import model_logging 15 | from seqdesign_pt.utils import get_cuda_version, get_cudnn_version, get_github_head_hash, Tee 16 | 17 | working_dir = '/n/groups/marks/users/aaron/autoregressive' 18 | data_dir = '/n/groups/marks/projects/autoregressive' 19 | 20 | parser = argparse.ArgumentParser(description="Train an autoregressive model on a collection of sequences.") 21 | parser.add_argument("--channels", type=int, default=48, help="Number of channels.") 22 | parser.add_argument("--batch-size", type=int, default=30, help="Batch size.") 23 | parser.add_argument("--num-iterations", type=int, default=250005, 24 | help="Number of iterations to run the model.") 25 | parser.add_argument("--dataset", type=str, default=None, required=True, 26 | help="Dataset name for fitting model.
Alignment weights must be computed beforehand.") 27 | parser.add_argument("--num-data-workers", type=int, default=4, 28 | help="Number of workers to load data") 29 | parser.add_argument("--restore", type=str, default=None, 30 | help="Snapshot path for restoring a model to continue training.") 31 | parser.add_argument("--r-seed", type=int, default=42, 32 | help="Random seed") 33 | parser.add_argument("--no-lag-inf", action='store_true', 34 | help="Disable lagging inference") 35 | parser.add_argument("--lag-inf-max-steps", type=int, default=None, 36 | help="Maximum number of inner-loop steps for lagging inference") 37 | parser.add_argument("--dropout-p", type=float, default=0.5, 38 | help="Decoder dropout probability (drop rate, not keep rate)") 39 | parser.add_argument("--no-cuda", action='store_true', 40 | help="Disable GPU training") 41 | args = parser.parse_args() 42 | 43 | run_name = f"{args.dataset}_VAE_elu_channels-{args.channels}_dropout-{args.dropout_p}_rseed-{args.r_seed}" \ 44 | f"_start-{time.strftime('%y%b%d_%H%M', time.localtime())}" 45 | 46 | sbatch_executable = f"""#!/bin/bash 47 | #SBATCH -c 4 # Request four cores 48 | #SBATCH -N 1 # Request one node (if you request more than one core with -c, also using 49 | # -N 1 means all cores will be on the same node) 50 | #SBATCH -t 2-11:59 # Runtime in D-HH:MM format 51 | #SBATCH -p gpu # Partition to run in 52 | #SBATCH --gres=gpu:1 53 | #SBATCH --mem=30G # Total memory (30 GB, shared by all cores) 54 | #SBATCH -o slurm_files/slurm-%j.out # File to which STDOUT + STDERR will be written, including job ID in filename 55 | hostname 56 | pwd 57 | module load gcc/6.2.0 cuda/9.0 58 | srun stdbuf -oL -eL {sys.executable} \\ 59 | {sys.argv[0]} \\ 60 | --dataset {args.dataset} --num-iterations {args.num_iterations} \\ 61 | --channels {args.channels} --dropout-p {args.dropout_p} --r-seed {args.r_seed} \\ 62 | --restore {{restore}} 63 | """ 64 | 65 | torch.manual_seed(args.r_seed) 66 | torch.cuda.manual_seed_all(args.r_seed) 67 | 68 | 69 | def
_init_fn(worker_id): 70 | np.random.seed(args.r_seed + worker_id) 71 | 72 | 73 | os.makedirs(f'logs/{run_name}', exist_ok=True) 74 | log_f = Tee(f'logs/{run_name}/log.txt', 'a') 75 | 76 | print("OS: ", sys.platform) 77 | print("Python: ", sys.version) 78 | print("PyTorch: ", torch.__version__) 79 | print("Numpy: ", np.__version__) 80 | 81 | USE_CUDA = not args.no_cuda 82 | device = torch.device("cuda:0" if USE_CUDA and torch.cuda.is_available() else "cpu") 83 | print('Using device:', device) 84 | if device.type == 'cuda': 85 | print(torch.cuda.get_device_name(0)) 86 | print('Memory Usage:') 87 | print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3, 1), 'GB') 88 | print('Cached: ', round(torch.cuda.memory_cached(0)/1024**3, 1), 'GB') 89 | print(get_cuda_version()) 90 | print("CuDNN Version ", get_cudnn_version()) 91 | 92 | print("git hash:", str(get_github_head_hash())) 93 | print() 94 | 95 | print("Run:", run_name) 96 | 97 | dataset = data_loaders.SingleFamilyDataset( 98 | batch_size=args.batch_size, 99 | working_dir=data_dir, 100 | dataset=args.dataset, 101 | matching=True, 102 | unlimited_epoch=True, 103 | output_shape='NCHW', 104 | output_types='decoder', 105 | ) 106 | loader = data_loaders.GeneratorDataLoader( 107 | dataset, 108 | num_workers=args.num_data_workers, 109 | pin_memory=True, 110 | worker_init_fn=_init_fn 111 | ) 112 | 113 | if args.restore is not None: 114 | print("Restoring model from:", args.restore) 115 | checkpoint = torch.load(args.restore, map_location='cpu' if device.type == 'cpu' else None) 116 | dims = checkpoint['model_dims'] 117 | hyperparams = checkpoint['model_hyperparams'] 118 | trainer_params = checkpoint['train_params'] 119 | model = autoregressive_model.AutoregressiveVAEFR(dims=dims, hyperparams=hyperparams, dropout_p=args.dropout_p) 120 | else: 121 | checkpoint = args.restore 122 | trainer_params = None 123 | model = autoregressive_model.AutoregressiveVAEFR(channels=args.channels, dropout_p=args.dropout_p)
124 | model.to(device) 125 | 126 | trainer = autoregressive_train.AutoregressiveVAETrainer( 127 | model=model, 128 | data_loader=loader, 129 | params=trainer_params, 130 | snapshot_path=working_dir + '/sess', 131 | snapshot_name=run_name, 132 | snapshot_interval=args.num_iterations // 10, 133 | snapshot_exec_template=sbatch_executable, 134 | device=device, 135 | # logger=model_logging.Logger(validation_interval=None), 136 | logger=model_logging.TensorboardLogger( 137 | log_interval=500, 138 | validation_interval=1000, 139 | generate_interval=5000, 140 | log_dir=working_dir + '/log/' + run_name 141 | ) 142 | ) 143 | if args.restore is not None: 144 | trainer.load_state(checkpoint) 145 | if args.no_lag_inf: 146 | trainer.params['lagging_inference'] = False 147 | if args.lag_inf_max_steps is not None: 148 | trainer.params['lag_inf_inner_loop_max_steps'] = args.lag_inf_max_steps 149 | 150 | print("Hyperparameters:", json.dumps(model.hyperparams, indent=4)) 151 | print("Training parameters:", json.dumps(trainer.params, indent=4)) 152 | print("Num trainable parameters:", model.parameter_count()) 153 | 154 | trainer.train(steps=args.num_iterations) 155 | -------------------------------------------------------------------------------- /src/seqdesign_pt/tf_reader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tensorflow.python import pywrap_tensorflow 3 | 4 | 5 | class ConversionError(Exception): 6 | pass 7 | 8 | 9 | class TFReader: 10 | def __init__(self, filepath): 11 | self.filepath = filepath 12 | self.reader = pywrap_tensorflow.NewCheckpointReader(filepath) 13 | self.unused_keys = [] 14 | self.reset_keys() 15 | 16 | def reset_keys(self): 17 | self.unused_keys = list(self.reader.get_variable_to_shape_map().keys()) 18 | self.unused_keys = [key for key in self.unused_keys if not key.startswith("Backprop")] 19 | 20 | def get_checkpoint_legacy_version(self): 21 | if 
'Forward/Encoder/DilationBlock1/ConvNet1D/Conv8_3x200/DilatedConvGen8/W' in self.unused_keys: 22 | return 0 23 | else: 24 | return 1 25 | 26 | def load_autoregressive_fr(self, model): 27 | for m, m_name in zip([model.model['model_f'], model.model['model_r']], 28 | ['Forward/', 'Reverse/']): 29 | self.load_autoregressive(m, m_name) 30 | 31 | def load_autoregressive(self, model, model_name): 32 | self.load_conv2d(model.start_conv, model_name + 'EncoderPrepareInput/Features1D/') 33 | 34 | for i_block, block in enumerate(model.dilation_blocks): 35 | block_name = model_name + f'Encoder/DilationBlock{i_block + 1}/ConvNet1D/' 36 | dilation_schedule = block.dilations 37 | 38 | for i_layer, layer in enumerate(block.dilation_layers): 39 | layer_name = block_name + f'Conv{i_layer}_3x{dilation_schedule[i_layer]}/' 40 | 41 | self.load_layer_norm(layer.layernorm_1, layer_name + 'ScaleAndShift/') 42 | self.load_layer_norm(layer.layernorm_2, layer_name + 'ScaleShiftDeltaLayer/ScaleAndShift/') 43 | 44 | self.load_conv2d(layer.mix_conv_1, layer_name + f'Mix1{i_layer}/') 45 | self.load_conv2d(layer.dilated_conv, layer_name + f'DilatedConvGen{i_layer}/') 46 | self.load_conv2d(layer.mix_conv_3, layer_name + f'Mix3{i_layer}/') 47 | 48 | self.load_conv2d(model.end_conv, model_name + 'WriteSequence/conv2D/') 49 | 50 | model.step = self.reader.get_tensor('global_step') 51 | 52 | def load_conv2d(self, layer, layer_name): 53 | self.set_parameter(layer.bias, layer_name + 'b') 54 | self.set_parameter(layer.weight_g, layer_name + 'g') 55 | self.set_parameter(layer.weight_v, layer_name + 'W', permute=(3, 2, 0, 1)) # HWIO to OIHW 56 | 57 | def load_layer_norm(self, layer, layer_name): 58 | self.set_parameter(layer.bias, layer_name + 'b') 59 | self.set_parameter(layer.weight, layer_name + 'g') 60 | 61 | def set_parameter(self, parameter, name='', permute=()): 62 | new_data = self.reader.get_tensor(name) 63 | new_data = torch.as_tensor(new_data, dtype=parameter.dtype, device=parameter.device) 64 | 
new_data.requires_grad_(True) 65 | 66 | if permute: 67 | new_data = new_data.permute(*permute).contiguous() 68 | 69 | if new_data.shape != parameter.shape: 70 | raise ConversionError('mismatched shapes: {} to {} at {}'.format( 71 | tuple(new_data.shape), tuple(parameter.shape), name)) 72 | 73 | parameter.data = new_data 74 | self.unused_keys.remove(name) 75 | 76 | 77 | if __name__ == '__main__': 78 | import sys 79 | import autoregressive_model 80 | model_test = autoregressive_model.AutoregressiveFR() 81 | reader = TFReader(sys.argv[1]) 82 | reader.load_autoregressive_fr(model_test) 83 | print([key for key in reader.unused_keys if not key.startswith("Backprop")]) 84 | -------------------------------------------------------------------------------- /src/seqdesign_pt/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | import glob 5 | from collections.abc import Mapping 6 | import shutil 7 | import contextlib 8 | 9 | import numpy as np 10 | import git 11 | 12 | 13 | def recursive_update(orig_dict, update_dict): 14 | """Update the contents of orig_dict with update_dict""" 15 | for key, val in update_dict.items(): 16 | if isinstance(val, Mapping): 17 | orig_dict[key] = recursive_update(orig_dict.get(key, {}), val) 18 | else: 19 | orig_dict[key] = val 20 | return orig_dict 21 | 22 | 23 | # https://stackoverflow.com/questions/49555991/can-i-create-a-local-numpy-random-seed 24 | @contextlib.contextmanager 25 | def temp_seed(seed): 26 | state = np.random.get_state() 27 | np.random.seed(seed) 28 | try: 29 | yield 30 | finally: 31 | np.random.set_state(state) 32 | 33 | 34 | # https://github.com/ilkarman/DeepLearningFrameworks/blob/master/notebooks/common/utils.py 35 | def get_gpu_name(): 36 | try: 37 | out_str = subprocess.run(["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"], stdout=subprocess.PIPE).stdout 38 | out_list = out_str.decode("utf-8").split('\n') 39 | out_list = 
out_list[1:-1] 40 | return out_list 41 | except Exception as e: 42 | print(e) 43 | 44 | 45 | def get_cuda_path(): 46 | nvcc_path = shutil.which('nvcc') 47 | if nvcc_path is not None: 48 | return nvcc_path.replace('bin/nvcc', '') 49 | else: 50 | return None 51 | 52 | 53 | # https://github.com/ilkarman/DeepLearningFrameworks/blob/master/notebooks/common/utils.py 54 | def get_cuda_version(): 55 | """Get CUDA version""" 56 | path = get_cuda_path() 57 | if path is not None and os.path.isfile(path + 'version.txt'): 58 | with open(path + 'version.txt', 'r') as f: 59 | data = f.read().replace('\n', '') 60 | return data 61 | else: 62 | return "No CUDA in this machine" 63 | 64 | 65 | # https://github.com/ilkarman/DeepLearningFrameworks/blob/master/notebooks/common/utils.py 66 | def get_cudnn_version(): 67 | """Get CUDNN version""" 68 | if sys.platform == 'win32': 69 | raise NotImplementedError("Implement this!") 70 | # This breaks on linux: 71 | # cuda=!ls "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA" 72 | # candidates = ["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\" + str(cuda[0]) +"\\include\\cudnn.h"] 73 | elif sys.platform == 'linux': 74 | candidates = ['/usr/include/x86_64-linux-gnu/cudnn_v[0-99].h', 75 | '/usr/local/cuda/include/cudnn.h', 76 | '/usr/include/cudnn.h'] 77 | elif sys.platform == 'darwin': 78 | candidates = ['/usr/local/cuda/include/cudnn.h', 79 | '/usr/include/cudnn.h'] 80 | else: 81 | raise ValueError("Not in Windows, Linux or Mac") 82 | cuda_path = get_cuda_path() 83 | if cuda_path is not None: 84 | candidates.append(cuda_path + 'include/cudnn.h') 85 | file = None 86 | for c in candidates: 87 | file = glob.glob(c) 88 | if file: 89 | break 90 | if file: 91 | with open(file[0], 'r') as f: 92 | version = '' 93 | for line in f: 94 | if "#define CUDNN_MAJOR" in line: 95 | version = line.split()[-1] 96 | if "#define CUDNN_MINOR" in line: 97 | version += '.' + line.split()[-1] 98 | if "#define CUDNN_PATCHLEVEL" in line: 99 | version += '.' 
+ line.split()[-1] 100 | if version: 101 | return version 102 | else: 103 | return "Cannot find CUDNN version" 104 | else: 105 | return "No CUDNN in this machine" 106 | 107 | 108 | def get_github_head_hash(): 109 | """ 110 | https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script 111 | :return: hex SHA of the current HEAD commit, or None if not inside a git repository 112 | """ 113 | try: 114 | repo = git.Repo(search_parent_directories=True) 115 | return repo.head.object.hexsha 116 | except git.GitError: 117 | return None 118 | 119 | 120 | class Tee(object): 121 | """Duplicate output to file and sys.stdout 122 | From https://stackoverflow.com/questions/616645/how-to-duplicate-sys-stdout-to-a-log-file""" 123 | 124 | def __init__(self, name, mode): 125 | self.file = open(name, mode) 126 | self.stdout = sys.stdout 127 | sys.stdout = self 128 | 129 | def __del__(self): 130 | self.close() 131 | 132 | def close(self): 133 | sys.stdout = self.stdout 134 | self.file.close() 135 | 136 | def write(self, data): 137 | self.file.write(data) 138 | self.stdout.write(data) 139 | 140 | def flush(self): 141 | self.file.flush() 142 | self.stdout.flush() 143 | 144 | def __enter__(self): 145 | return self 146 | 147 | def __exit__(self, _type, _value, _traceback): 148 | self.close() 149 | -------------------------------------------------------------------------------- /src/seqdesign_pt/version.py: -------------------------------------------------------------------------------- 1 | VERSION = 'v3-pt' 2 | --------------------------------------------------------------------------------
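
In `TFReader.load_conv2d`, kernels are copied with `permute=(3, 2, 0, 1)` and the comment "HWIO to OIHW". In other words, a TensorFlow conv kernel stored as (height, width, input channels, output channels) is reordered into PyTorch's (output channels, input channels, height, width) layout. The NumPy sketch below shows that axis reordering on its own. The helper name `hwio_to_oihw` and the 1x3 kernel with 48 input and 64 output channels are illustrative assumptions and not part of SeqDesign.

```python
import numpy as np


def hwio_to_oihw(kernel: np.ndarray) -> np.ndarray:
    """Reorder a TensorFlow conv kernel (H, W, in, out) into PyTorch's (out, in, H, W) layout.

    Hypothetical helper that mirrors TFReader.set_parameter(..., permute=(3, 2, 0, 1)).
    """
    if kernel.ndim != 4:
        raise ValueError(f"expected a 4-D kernel, got shape {kernel.shape}")
    # New axis 0 <- old axis 3 (out), 1 <- 2 (in), 2 <- 0 (H), 3 <- 1 (W).
    # np.ascontiguousarray plays the role of .contiguous() in the PyTorch code.
    return np.ascontiguousarray(np.transpose(kernel, (3, 2, 0, 1)))


# Illustrative kernel: height 1, width 3, 48 input channels, 64 output channels.
tf_kernel = np.arange(1 * 3 * 48 * 64, dtype=np.float32).reshape(1, 3, 48, 64)
pt_kernel = hwio_to_oihw(tf_kernel)
print(pt_kernel.shape)  # (64, 48, 1, 3)

# Each PyTorch weight [o, i, h, w] equals the TensorFlow weight [h, w, i, o].
assert pt_kernel[5, 7, 0, 2] == tf_kernel[0, 2, 7, 5]
```

Because PyTorch puts the spatial axes last, the inverse permutation `(2, 3, 1, 0)` recovers the original TensorFlow array. This gives a quick way to sanity-check a converted weight.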