├── contextual_lenses ├── __init__.py ├── cloud_utils │ ├── __init__.py │ ├── cpu_init.sh │ ├── gpu_init.sh │ └── tpu_init.py ├── resources │ ├── __init__.py │ ├── reduce_fn_kwargs_resources │ │ ├── pool.json │ │ ├── linear_pool_256.json │ │ ├── linear_pool_512.json │ │ └── linear_pool_1024.json │ └── encoder_fn_kwargs_resources │ │ ├── 1-layer_cnn_kwargs.json │ │ ├── large_transformer_kwargs.json │ │ ├── medium_transformer_kwargs.json │ │ ├── small_transformer_kwargs.json │ │ └── 2-layer_cnn_kwargs.json ├── loss_fns.py ├── load_transformer.py ├── test │ ├── test_loading.py │ ├── test_pooling.py │ └── test_learning.py ├── encoders.py ├── blast_baseline.py ├── contextual_lenses.py ├── pfam_utils.py └── train_utils.py ├── .gitmodules ├── figures ├── 1-sample_test_knn_accuracy.png ├── 5-sample_test_knn_accuracy.png ├── 10-sample_test_knn_accuracy.png └── 50-sample_test_knn_accuracy.png ├── demo.json ├── docs ├── contributing.md └── code-of-conduct.md ├── setup.py ├── LICENSE ├── README.md ├── pfam_experiment.py └── generate_params.py /contextual_lenses/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /contextual_lenses/cloud_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /contextual_lenses/resources/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /contextual_lenses/resources/reduce_fn_kwargs_resources/pool.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /contextual_lenses/resources/reduce_fn_kwargs_resources/linear_pool_256.json: -------------------------------------------------------------------------------- 1 | {"rep_size": 256} -------------------------------------------------------------------------------- /contextual_lenses/resources/reduce_fn_kwargs_resources/linear_pool_512.json: -------------------------------------------------------------------------------- 1 | {"rep_size": 512} -------------------------------------------------------------------------------- /contextual_lenses/resources/reduce_fn_kwargs_resources/linear_pool_1024.json: -------------------------------------------------------------------------------- 1 | {"rep_size": 1024} -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "google-research"] 2 | path = google_research 3 | url = https://github.com/google-research/google-research.git -------------------------------------------------------------------------------- /figures/1-sample_test_knn_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/googleinterns/protein-embedding-retrieval/HEAD/figures/1-sample_test_knn_accuracy.png -------------------------------------------------------------------------------- /figures/5-sample_test_knn_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/googleinterns/protein-embedding-retrieval/HEAD/figures/5-sample_test_knn_accuracy.png -------------------------------------------------------------------------------- /figures/10-sample_test_knn_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/googleinterns/protein-embedding-retrieval/HEAD/figures/10-sample_test_knn_accuracy.png -------------------------------------------------------------------------------- /figures/50-sample_test_knn_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/googleinterns/protein-embedding-retrieval/HEAD/figures/50-sample_test_knn_accuracy.png -------------------------------------------------------------------------------- /contextual_lenses/resources/encoder_fn_kwargs_resources/1-layer_cnn_kwargs.json: -------------------------------------------------------------------------------- 1 | {"n_layers": 1, "n_features": [1024], "n_kernel_sizes": [12], "n_kernel_dilations": null} -------------------------------------------------------------------------------- /contextual_lenses/resources/encoder_fn_kwargs_resources/large_transformer_kwargs.json: -------------------------------------------------------------------------------- 1 | {"emb_dim": 512, "num_heads": 8, "qkv_dim": 512, "mlp_dim": 1024, "num_layers": 36} -------------------------------------------------------------------------------- /contextual_lenses/resources/encoder_fn_kwargs_resources/medium_transformer_kwargs.json: -------------------------------------------------------------------------------- 1 | {"emb_dim": 512, "num_heads": 8, "qkv_dim": 512, "mlp_dim": 2048, "num_layers": 6} -------------------------------------------------------------------------------- /contextual_lenses/resources/encoder_fn_kwargs_resources/small_transformer_kwargs.json: -------------------------------------------------------------------------------- 1 | {"emb_dim": 256, "num_heads": 4, "qkv_dim": 256, "mlp_dim": 1024, "num_layers": 2} -------------------------------------------------------------------------------- /contextual_lenses/cloud_utils/cpu_init.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pip install --upgrade pip 4 | pip install --upgrade jax jaxlib flax 5 | export TF_FORCE_GPU_ALLOW_GROWTH=true -------------------------------------------------------------------------------- /contextual_lenses/resources/encoder_fn_kwargs_resources/2-layer_cnn_kwargs.json: -------------------------------------------------------------------------------- 1 | {"n_layers": 2, "n_features": [1024, 1024], "n_kernel_sizes": [5, 5], "n_kernel_dilations": null} -------------------------------------------------------------------------------- /contextual_lenses/cloud_utils/gpu_init.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pip install --upgrade pip 4 | 5 | PYTHON_VERSION=cp37 6 | CUDA_VERSION=cuda101 7 | PLATFORM=linux_x86_64 8 | BASE_URL='https://storage.googleapis.com/jax-releases' 9 | 10 | pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.48-$PYTHON_VERSION-none-$PLATFORM.whl 11 | pip install --upgrade jax flax 12 | 13 | export TF_FORCE_GPU_ALLOW_GROWTH=true -------------------------------------------------------------------------------- /contextual_lenses/loss_fns.py: -------------------------------------------------------------------------------- 1 | """Loss functions 2 | 3 | Jax loss functions for computing gradient updates. 4 | """ 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | from jax.config import config 9 | config.enable_omnistaging() 10 | 11 | 12 | def mse_loss(Y, Y_hat): 13 | """Squeezes predictions and returns MSE loss.""" 14 | 15 | if len(Y_hat.shape) > 1: 16 | Y_hat = jnp.squeeze(Y_hat, axis=1) 17 | 18 | loss = jnp.mean(jnp.square(Y - Y_hat)) 19 | 20 | return loss 21 | 22 | 23 | def cross_entropy_loss(Y, Y_hat, num_classes): 24 | """Applies log-softmax to predictions and one-hot encodes true values 25 | to compute and return cross-entropy loss. 26 | """ 27 | 28 | Y_hat = jax.nn.log_softmax(Y_hat) 29 | 30 | Y = jax.nn.one_hot(Y, num_classes=num_classes) 31 | 32 | loss = -jnp.sum(Y * Y_hat) 33 | 34 | return loss 35 | -------------------------------------------------------------------------------- /contextual_lenses/cloud_utils/tpu_init.py: -------------------------------------------------------------------------------- 1 | """Function to connect VM instance to Cloud TPU.""" 2 | 3 | 4 | import requests 5 | import yaml 6 | import subprocess as sp 7 | from jax.config import config 8 | 9 | 10 | def connect_tpu(tpu_name=None): 11 | """Runs necessary commands to connect VM to Cloud TPU.""" 12 | 13 | if tpu_name is not None: 14 | command = 'gcloud compute tpus describe ' + tpu_name 15 | 16 | output = sp.getoutput(command) 17 | output = yaml.load(stream=output, Loader=yaml.SafeLoader) 18 | 19 | ip_address = output['ipAddress'] 20 | port = output['port'] 21 | 22 | url = 'http://' + ip_address + ':8475/requestversion/tpu_driver_nightly' 23 | requests.post(url) 24 | 25 | config.FLAGS.jax_xla_backend = 'tpu_driver' 26 | config.FLAGS.jax_backend_target = "grpc://" + ip_address + ':' + port 27 | 28 | print('Successfully connected to TPU named \"' + tpu_name + '\"!') 29 | -------------------------------------------------------------------------------- /demo.json: -------------------------------------------------------------------------------- 1 | {"encoder_fn_kwargs_path": "small_transformer_kwargs", "restore_transformer_dir": "gs://sequin-public/transformer_models/small_trembl_bert/", "lens_shuffle_seed": 0, "label": "demo", "train_families": 10000, "data_partitions_dirpath": "random_split/", "last_test_family": 16000, "load_model_step": 0, "knn_sample_random_state": 1, "predictor_wd": 0.2, "lens_wd": 0.2, "save_model": true, "model_random_key": 0, "encoder_wd": 0.0, "use_bert": true, "first_test_family": 15000, "load_model_dir": "", "save_gcs_bucket": "neuralblast_public", "predictor_lr": 0.001, "knn_shuffle_seed": 1, "load_model": false, "load_gcs_bucket": "neuralblast_public", "knn_batch_size": 64, "reduce_fn_name": "linear_max_pool", "epochs": 10, "results_save_dir": "demo/", "lens_train_samples": 50, "measurements": 1, "reduce_fn_kwargs_path": "linear_pool_1024", "save_model_dir": "demo/model/", "lens_sample_random_state": 0, "use_transformer": true, "lens_lr": 0.0001, "encoder_lr": 0.0, "encoder_fn_name": "transformer", "lens_batch_size": 64} -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google/conduct/). 29 | -------------------------------------------------------------------------------- /contextual_lenses/load_transformer.py: -------------------------------------------------------------------------------- 1 | """Utils for loading FlaxLM or FlaxBERT models or parameters.""" 2 | 3 | import os 4 | 5 | import tensorflow as tf 6 | 7 | import gin 8 | 9 | from protein_lm import data, models 10 | 11 | 12 | def load_transformer_model(ckpt_dir, model_cls, domain=None): 13 | """Loads a model from directory.""" 14 | 15 | if domain is None: 16 | domain = data.protein_domain 17 | 18 | config_path = os.path.join(ckpt_dir, 'config.gin') 19 | with gin.config_scope('load_model'): 20 | with tf.io.gfile.GFile(config_path) as f: 21 | gin.parse_config(f, skip_unknown=True) 22 | model = model_cls(domain=domain) 23 | model.load_checkpoint(ckpt_dir) 24 | 25 | return model 26 | 27 | 28 | def load_transformer_params(ckpt_dir, model_cls, domain=None): 29 | """Returns parameters of a loaded model.""" 30 | 31 | model = load_transformer_model(ckpt_dir, model_cls, domain=domain) 32 | params = models.jax_utils.unreplicate(model._optimizer.target).params 33 | 34 | return params 35 | 36 | 37 | def load_transformer_encoder(ckpt_dir, model_cls, domain=None): 38 | """Returns parameters of a loaded model.""" 39 | 40 | model = load_transformer_model(ckpt_dir, model_cls, domain=domain) 41 | encoder = models.jax_utils.unreplicate(model._optimizer.target) 42 | 43 | return encoder 44 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | # This follows the style of Jaxlib installation here: 4 | # https://github.com/google/jax#pip-installation 5 | PYTHON_VERSION = "cp37" 6 | CUDA_VERSION = "cuda101" # alternatives: cuda90, cuda92, cuda100 7 | PLATFORM = "manylinux2010_x86_64" # alternatives: linux_x86_64 8 | BASE_URL = "https://storage.googleapis.com/jax-releases" 9 | 10 | 11 | def jax_artifact(version, gpu=False): 12 | if gpu: 13 | prefix = f"{BASE_URL}/{CUDA_VERSION}/jaxlib" 14 | wheel_suffix = f"{PYTHON_VERSION}-none-{PLATFORM}.whl" 15 | location = f"{prefix}-{version}-{wheel_suffix}" 16 | return f"jaxlib @ {location}" 17 | return f"jaxlib=={version}" 18 | 19 | def readme(): 20 | try: 21 | with open('README.md') as rf: 22 | return rf.read() 23 | except FileNotFoundError: 24 | return None 25 | 26 | JAXLIB_VERSION = "0.1.52" 27 | JAX_VERSION = "0.1.75" 28 | FLAX_VERSION = "0.2.0" 29 | 30 | REQUIRED_PACKAGES = [ 31 | "tensorflow", 32 | "pandas", 33 | "numpy", 34 | "gin-config", 35 | "sklearn", 36 | "dm-tree", 37 | "fs", 38 | "fs-gcsfs", 39 | f"jax=={JAX_VERSION}", 40 | f"flax=={FLAX_VERSION}" 41 | ] 42 | 43 | setup(name='contextual_lenses', 44 | version='1.0', 45 | description='Protein contextual lenses.', 46 | long_description=readme(), 47 | author='Amir Shanehsazzadeh', 48 | author_email='amirshanehsaz@google.com', 49 | packages=find_packages(exclude=('docs')), 50 | install_requires=REQUIRED_PACKAGES, 51 | extras_require={ 52 | "cpu": [jax_artifact(JAXLIB_VERSION, gpu=False)], 53 | "gpu": [jax_artifact(JAXLIB_VERSION, gpu=True)], 54 | }) 55 | -------------------------------------------------------------------------------- /contextual_lenses/test/test_loading.py: -------------------------------------------------------------------------------- 1 | """Tests for saving and restoring checkpoints.""" 2 | 3 | 4 | import os 5 | 6 | from absl.testing import parameterized 7 | from absl.testing import absltest 8 | 9 | from flax.training import checkpoints 10 | 11 | from contextual_lenses import mean_pool, max_pool 12 | 13 | from train_utils import create_optimizer, create_representation_model 14 | 15 | from encoders import one_hot_encoder 16 | 17 | 18 | # Test cases: 19 | test1 = { 20 | 'encoder_fn': one_hot_encoder, 21 | 'encoder_fn_kwargs' : { 22 | 23 | }, 24 | 'reduce_fn': mean_pool, 25 | 'reduce_fn_kwargs': { 26 | 27 | } 28 | } 29 | 30 | test2 = { 31 | 'encoder_fn': one_hot_encoder, 32 | 'encoder_fn_kwargs' : { 33 | 34 | }, 35 | 'reduce_fn': max_pool, 36 | 'reduce_fn_kwargs': { 37 | 38 | } 39 | } 40 | 41 | tests = (test1, test2) 42 | 43 | 44 | class TestLoading(parameterized.TestCase): 45 | """Abstract method for testing saving and restoring optimizer checkpoints.""" 46 | 47 | @parameterized.parameters( 48 | *tests 49 | ) 50 | def test_loading(self, encoder_fn, encoder_fn_kwargs, reduce_fn, reduce_fn_kwargs): 51 | 52 | model = create_representation_model(encoder_fn=encoder_fn, 53 | encoder_fn_kwargs=encoder_fn_kwargs, 54 | reduce_fn=reduce_fn, 55 | reduce_fn_kwargs=reduce_fn_kwargs, 56 | num_categories=21, 57 | output_features=1) 58 | 59 | optimizer = create_optimizer(model, learning_rate=1e-3, weight_decay=0.) 60 | 61 | params = optimizer.target.params 62 | step = optimizer.state.step 63 | 64 | os.mkdir('tmp/') 65 | 66 | checkpoints.save_checkpoint(ckpt_dir='tmp/', target=optimizer, step=optimizer.state.step) 67 | 68 | loaded_optimizer = checkpoints.restore_checkpoint(ckpt_dir='tmp/', target=optimizer) 69 | 70 | loaded_params = loaded_optimizer.target.params 71 | loaded_step = loaded_optimizer.state.step 72 | 73 | os.system('rm -rf tmp/') 74 | 75 | self.assertTrue(step==loaded_step) 76 | 77 | self.assertTrue(list(params.keys())==list(loaded_params.keys())) 78 | for key in params.keys(): 79 | sub_params = params[key] 80 | loaded_sub_params = loaded_params[key] 81 | self.assertTrue(list(sub_params.keys())==list(loaded_sub_params.keys())) 82 | for sub_key in sub_params.keys(): 83 | self.assertTrue((sub_params[sub_key]==loaded_sub_params[sub_key]).all()) 84 | 85 | 86 | if __name__ == '__main__': 87 | absltest.main() 88 | 89 | -------------------------------------------------------------------------------- /contextual_lenses/encoders.py: -------------------------------------------------------------------------------- 1 | """Encoder functions 2 | 3 | Fixed and learnable transformations for embedding sequences. 4 | """ 5 | 6 | import flax 7 | from flax import nn 8 | 9 | import jax 10 | from jax import lax 11 | import jax.nn 12 | import jax.numpy as jnp 13 | from jax.config import config 14 | config.enable_omnistaging() 15 | 16 | import numpy as np 17 | 18 | from operator import itemgetter 19 | 20 | 21 | def one_hot_encoder(batch_inds, num_categories): 22 | """Applies one-hot encoding from jax.nn.""" 23 | 24 | one_hots = jax.nn.one_hot(batch_inds, num_classes=num_categories) 25 | 26 | return one_hots 27 | 28 | 29 | class CNN(nn.Module): 30 | """A simple 1D CNN model.""" 31 | def apply(self, x, n_layers, n_features, n_kernel_sizes, 32 | n_kernel_dilations): 33 | 34 | if n_kernel_dilations is None: 35 | n_kernel_dilations = [1] * n_layers 36 | 37 | x = jnp.expand_dims(x, axis=2) 38 | 39 | for layer in range(n_layers): 40 | features = n_features[layer] 41 | kernel_size = (n_kernel_sizes[layer], 1) 42 | kernel_dilation = (n_kernel_dilations[layer], 1) 43 | x = nn.Conv(x, 44 | features=features, 45 | kernel_size=kernel_size, 46 | kernel_dilation=kernel_dilation) 47 | x = nn.relu(x) 48 | 49 | x = jnp.squeeze(x, axis=2) 50 | 51 | return x 52 | 53 | 54 | def cnn_one_hot_encoder(batch_inds, 55 | num_categories, 56 | n_layers, 57 | n_features, 58 | n_kernel_sizes, 59 | n_kernel_dilations=None): 60 | """Applies one-hot encoding followed by 1D CNN.""" 61 | 62 | one_hots = one_hot_encoder(batch_inds, num_categories) 63 | cnn_one_hots = CNN(one_hots, n_layers, n_features, n_kernel_sizes, 64 | n_kernel_dilations) 65 | 66 | return cnn_one_hots 67 | 68 | 69 | def encoder_fn_name_to_fn(encoder_fn_name): 70 | """Returns encoder_fn corresponding to encoder_fn_name.""" 71 | 72 | if encoder_fn_name is None or encoder_fn_name == 'transformer': 73 | encoder_fn = None 74 | elif encoder_fn_name == 'one_hot': 75 | encoder_fn = one_hot_encoder 76 | elif encoder_fn_name == 'cnn_one_hot': 77 | encoder_fn = cnn_one_hot_encoder 78 | elif encoder_fn_name == 'one_hot_pos_emb': 79 | encoder_fn = one_hot_pos_emb_encoder 80 | elif encoder_fn_name == 'cnn_one_hot_pos_emb': 81 | encoder_fn = cnn_one_hot_pos_emb_encoder 82 | else: 83 | raise ValueError('Incorrect encoder name specified.') 84 | 85 | return encoder_fn 86 | -------------------------------------------------------------------------------- /docs/code-of-conduct.md: -------------------------------------------------------------------------------- 1 | # Google Open Source Community Guidelines 2 | 3 | At Google, we recognize and celebrate the creativity and collaboration of open 4 | source contributors and the diversity of skills, experiences, cultures, and 5 | opinions they bring to the projects and communities they participate in. 6 | 7 | Every one of Google's open source projects and communities are inclusive 8 | environments, based on treating all individuals respectfully, regardless of 9 | gender identity and expression, sexual orientation, disabilities, 10 | neurodiversity, physical appearance, body size, ethnicity, nationality, race, 11 | age, religion, or similar personal characteristic. 12 | 13 | We value diverse opinions, but we value respectful behavior more. 14 | 15 | Respectful behavior includes: 16 | 17 | * Being considerate, kind, constructive, and helpful. 18 | * Not engaging in demeaning, discriminatory, harassing, hateful, sexualized, or 19 | physically threatening behavior, speech, and imagery. 20 | * Not engaging in unwanted physical contact. 21 | 22 | Some Google open source projects [may adopt][] an explicit project code of 23 | conduct, which may have additional detailed expectations for participants. Most 24 | of those projects will use our [modified Contributor Covenant][]. 25 | 26 | [may adopt]: https://opensource.google/docs/releasing/preparing/#conduct 27 | [modified Contributor Covenant]: https://opensource.google/docs/releasing/template/CODE_OF_CONDUCT/ 28 | 29 | ## Resolve peacefully 30 | 31 | We do not believe that all conflict is necessarily bad; healthy debate and 32 | disagreement often yields positive results. However, it is never okay to be 33 | disrespectful. 34 | 35 | If you see someone behaving disrespectfully, you are encouraged to address the 36 | behavior directly with those involved. Many issues can be resolved quickly and 37 | easily, and this gives people more control over the outcome of their dispute. 38 | If you are unable to resolve the matter for any reason, or if the behavior is 39 | threatening or harassing, report it. We are dedicated to providing an 40 | environment where participants feel welcome and safe. 41 | 42 | ## Reporting problems 43 | 44 | Some Google open source projects may adopt a project-specific code of conduct. 45 | In those cases, a Google employee will be identified as the Project Steward, 46 | who will receive and handle reports of code of conduct violations. In the event 47 | that a project hasn’t identified a Project Steward, you can report problems by 48 | emailing opensource@google.com. 49 | 50 | We will investigate every complaint, but you may not receive a direct response. 51 | We will use our discretion in determining when and how to follow up on reported 52 | incidents, which may range from not taking action to permanent expulsion from 53 | the project and project-sponsored spaces. We will notify the accused of the 54 | report and provide them an opportunity to discuss it before any action is 55 | taken. The identity of the reporter will be omitted from the details of the 56 | report supplied to the accused. In potentially harmful situations, such as 57 | ongoing harassment or threats to anyone's safety, we may take action without 58 | notice. 59 | 60 | *This document was adapted from the [IndieWeb Code of Conduct][] and can also 61 | be found at .* 62 | 63 | [IndieWeb Code of Conduct]: https://indieweb.org/code-of-conduct 64 | -------------------------------------------------------------------------------- /contextual_lenses/blast_baseline.py: -------------------------------------------------------------------------------- 1 | """Computes accuracy of 1 nearest neighbor classification using BLAST. 2 | 3 | Example usage: 4 | blast_baseline.py \ 5 | --train_file=./resources/knn_data/5-samples_train_knn_data_families_15001-16000.csv \ 6 | --test_file=./resources/knn_data/test_knn_data_families_15001-16000.csv 7 | """ 8 | 9 | import os 10 | import subprocess 11 | import tempfile 12 | 13 | from absl import app 14 | from absl import flags 15 | import numpy as np 16 | import pandas as pd 17 | 18 | # Need to install locally from https://github.com/google-research/proteinfer 19 | from proteinfer import baseline_utils 20 | 21 | # We also assume that blast is installed locally. 22 | # sudo apt-get install ncbi-blast+ 23 | 24 | 25 | flags.DEFINE_string('train_file', '', 'Input train csv file.') 26 | flags.DEFINE_string('test_file', '', 'Input test csv file.') 27 | 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | 32 | _BLAST_FLAGS = '-outfmt 6 -max_hsps 1 -num_threads 10 -num_alignments 1' 33 | 34 | 35 | def _get_header(row): 36 | accession = row.accession.replace('/', '_').replace('-', '_') 37 | return '>accession="%s"\tlabels="%s"' % (accession, row.label) 38 | 39 | 40 | def _get_fasta_entry(row): 41 | header = _get_header(row) 42 | return '\n'.join([header, row.sequence]) 43 | 44 | 45 | def _write_fasta(df, output_file): 46 | entries = df.apply(_get_fasta_entry, axis=1) 47 | with open(output_file, 'w') as file: 48 | file.write('\n'.join(entries)) 49 | 50 | 51 | def _run_cmd(cmd_string): 52 | subprocess.run(cmd_string.split(' '), check=True) 53 | 54 | 55 | class BlastClassifier(object): 56 | """Stateful wrapper for BLAST system calls.""" 57 | 58 | def __init__(self, df): 59 | _, self._train_fasta = tempfile.mkstemp() 60 | _, self._blast_db = tempfile.mkstemp() 61 | _write_fasta(df, self._train_fasta) 62 | print(self._train_fasta) 63 | self._train_df = baseline_utils.load_ground_truth(self._train_fasta) 64 | self._label_vocab = df.label.unique() 65 | cmd = 'makeblastdb -in %s -dbtype prot -out %s' % (self._train_fasta, 66 | self._blast_db) 67 | _run_cmd(cmd) 68 | 69 | def __del__(self): 70 | os.remove(self._train_fasta) 71 | os.remove(self._blast_db) 72 | 73 | def predict(self, df): 74 | """Predicts labels by propagating labels from BLAST top hit.""" 75 | 76 | _, query_fasta = tempfile.mkstemp() 77 | _, blast_output = tempfile.mkstemp() 78 | _write_fasta(df, query_fasta) 79 | cmd = 'blastp -query %s -db %s %s -out %s' % (query_fasta, self._blast_db, 80 | _BLAST_FLAGS, blast_output) 81 | _run_cmd(cmd) 82 | 83 | assert df.label.isin(self._label_vocab).all() 84 | query_df = baseline_utils.load_ground_truth(query_fasta) 85 | results_df = baseline_utils.load_blast_output(blast_output, 86 | self._label_vocab, 87 | self._train_df, 88 | query_df) 89 | 90 | os.remove(query_fasta) 91 | os.remove(blast_output) 92 | return results_df 93 | 94 | 95 | def _get_label(label_set): 96 | if label_set: 97 | assert len(label_set) == 1 98 | return next(iter(label_set)) 99 | else: 100 | return None 101 | 102 | 103 | def _compute_accuracy(df): 104 | prediction = df.predicted_label.apply(_get_label) 105 | true_label = df.true_label.apply(_get_label) 106 | return np.mean(prediction == true_label) 107 | 108 | 109 | def _load(filename): 110 | df = pd.read_csv(filename) 111 | df.rename(columns=dict(sequence_name='accession'), inplace=True) 112 | return df 113 | 114 | 115 | def main(argv): 116 | if len(argv) > 1: 117 | raise app.UsageError('Too many command-line arguments.') 118 | 119 | train_df = _load(FLAGS.train_file) 120 | test_df = _load(FLAGS.test_file) 121 | 122 | blast_classifier = BlastClassifier(df=train_df) 123 | output_df = blast_classifier.predict(test_df) 124 | 125 | accuracy = _compute_accuracy(output_df) 126 | print('Accuracy = %f' % accuracy) 127 | 128 | 129 | if __name__ == '__main__': 130 | app.run(main) 131 | -------------------------------------------------------------------------------- /contextual_lenses/contextual_lenses.py: -------------------------------------------------------------------------------- 1 | """Contextual lenses 2 | 3 | Creates sequence length independent representation of embedded sequences 4 | Original paper: https://arxiv.org/pdf/2002.08866.pdf. 5 | """ 6 | 7 | import flax 8 | from flax import nn 9 | 10 | import jax 11 | import jax.nn 12 | import jax.numpy as jnp 13 | from jax.config import config 14 | config.enable_omnistaging() 15 | 16 | import numpy as np 17 | 18 | from operator import itemgetter 19 | 20 | 21 | def max_pool(x, padding_mask=None, pad_constant=1e8): 22 | """Apply padding, take maximum over sequence length axis.""" 23 | 24 | if padding_mask is not None: 25 | x = x * padding_mask 26 | neg_mask = -pad_constant * (1 - padding_mask) 27 | x = x + neg_mask 28 | 29 | rep = jnp.max(x, axis=-2) 30 | 31 | return rep 32 | 33 | 34 | def mean_pool(x, padding_mask=None): 35 | """Apply padding, take mean over sequence length axis.""" 36 | 37 | if padding_mask is not None: 38 | x = x * padding_mask 39 | rep = jnp.sum(x, axis=-2) / jnp.sum(padding_mask, axis=-2) 40 | else: 41 | rep = jnp.mean(x, axis=-2) 42 | 43 | return rep 44 | 45 | 46 | def linear_max_pool(x, rep_size, padding_mask=None): 47 | """Apply linear transformation + ReLU, apply padding, 48 | take maximum over sequence length. 49 | """ 50 | 51 | x = nn.Dense(x, 52 | rep_size, 53 | kernel_init=nn.initializers.xavier_uniform(), 54 | bias_init=nn.initializers.normal(stddev=1e-6)) 55 | 56 | x = nn.relu(x) 57 | 58 | rep = max_pool(x, padding_mask=padding_mask) 59 | 60 | return rep 61 | 62 | 63 | def linear_mean_pool(x, rep_size, padding_mask=None): 64 | """Apply linear transformation + ReLU, apply padding, 65 | take mean over sequence length. 66 | """ 67 | 68 | x = nn.Dense(x, 69 | rep_size, 70 | kernel_init=nn.initializers.xavier_uniform(), 71 | bias_init=nn.initializers.normal(stddev=1e-6)) 72 | 73 | x = nn.relu(x) 74 | 75 | rep = mean_pool(x, padding_mask=padding_mask) 76 | 77 | return rep 78 | 79 | 80 | class GatedConv(nn.Module): 81 | """Gated Convolutional lens followed by max pooling, 82 | see original paper for details. 83 | """ 84 | def apply(self, 85 | x, 86 | rep_size, 87 | m_layers, 88 | m_features, 89 | m_kernel_sizes, 90 | conv_rep_size, 91 | padding_mask=None): 92 | 93 | H_0 = nn.relu(nn.Dense(x, conv_rep_size)) 94 | G_0 = nn.relu(nn.Dense(x, conv_rep_size)) 95 | H, G = jnp.expand_dims(H_0, axis=2), jnp.expand_dims(G_0, axis=2) 96 | 97 | for layer in range(1, m_layers + 1): 98 | 99 | if layer < m_layers: 100 | H_features, G_features = m_features[layer - 1] 101 | else: 102 | H_features, G_features = conv_rep_size, conv_rep_size 103 | 104 | H_kernel_size, G_kernel_size = m_kernel_sizes[layer - 1] 105 | 106 | H = nn.Conv(H, features=H_features, kernel_size=(H_kernel_size, 1)) 107 | G = nn.Conv(G, features=G_features, kernel_size=(G_kernel_size, 1)) 108 | 109 | if layer < m_layers: 110 | H = nn.relu(H) 111 | G = nn.relu(G) 112 | else: 113 | H = nn.tanh(H) 114 | G = nn.sigmoid(G) 115 | 116 | H, G = jnp.squeeze(H, axis=2), jnp.squeeze(G, axis=2) 117 | 118 | F = H * G + G_0 119 | 120 | rep = linear_max_pool(F, padding_mask=padding_mask, rep_size=rep_size) 121 | 122 | return rep 123 | 124 | 125 | def gated_conv(x, 126 | rep_size, 127 | m_layers, 128 | m_features, 129 | m_kernel_sizes, 130 | conv_rep_size, 131 | padding_mask=None): 132 | """Calls GatedConv method for use as a lens.""" 133 | 134 | rep = GatedConv(x, 135 | rep_size=rep_size, 136 | m_features=m_features, 137 | m_layers=m_layers, 138 | m_kernel_sizes=m_kernel_sizes, 139 | conv_rep_size=conv_rep_size, 140 | padding_mask=padding_mask) 141 | 142 | return rep 143 | 144 | 145 | def reduce_fn_name_to_fn(reduce_fn_name): 146 | """Returns reduce_fn corresponding to reduce_fn_name.""" 147 | 148 | if reduce_fn_name == 'mean_pool': 149 | reduce_fn = mean_pool 150 | elif reduce_fn_name == 'max_pool': 151 | reduce_fn = max_pool 152 | elif reduce_fn_name == 'linear_mean_pool': 153 | reduce_fn = linear_mean_pool 154 | elif reduce_fn_name == 'linear_max_pool': 155 | reduce_fn = linear_max_pool 156 | elif reduce_fn_name == 'gated_conv': 157 | reduce_fn = gated_conv 158 | else: 159 | raise ValueError('Incorrect lens name specified.') 160 | 161 | return reduce_fn 162 | -------------------------------------------------------------------------------- /contextual_lenses/test/test_pooling.py: -------------------------------------------------------------------------------- 1 | """Tests for mean and max pool reduce functions.""" 2 | 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | import numpy as np 8 | 9 | from absl.testing import parameterized 10 | from absl.testing import absltest 11 | 12 | from contextual_lenses import mean_pool, max_pool 13 | 14 | 15 | # Test cases: 16 | test1 = { 17 | 'x': jnp.array([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], 18 | [[0, -1, 2, -3], [4, -5, 6, -7], [8, -9, 10, -11]]]), 19 | 'padding_mask' : None, 20 | 'mean_pool_rep': jnp.array([[4, 5, 6, 7], 21 | [4, -5, 6, -7]]), 22 | 'max_pool_rep': jnp.array([[8, 9, 10, 11], 23 | [8, -1, 10, -3]]) 24 | } 25 | 26 | test2 = { 27 | 'x': jnp.array([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], 28 | [[0, -1, 2, -3], [4, -5, 6, -7], [8, -9, 10, -11]]]), 29 | 'padding_mask' : jnp.array([[[1], [1], [1]] 30 | [1], [1], [1]]), 31 | 'mean_pool_rep': jnp.array([[4, 5, 6, 7], 32 | [4, -5, 6, -7]]), 33 | 'max_pool_rep': jnp.array([[8, 9, 10, 11], 34 | [8, -1, 10, -3]]) 35 | } 36 | 37 | test3 = { 38 | 'x': jnp.array([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], 39 | [[0, -1, 2, -3], [4, -5, 6, -7], [8, -9, 10, -11]]]), 40 | 'padding_mask' : jnp.array([[[1], [1], [1]], 41 | [[1], [1], [0]]]), 42 | 'mean_pool_rep': jnp.array([[4, 5, 6, 7], 43 | [2, -3, 4, -5]]), 44 | 'max_pool_rep': jnp.array([[8, 9, 10, 11], 45 | [4, -1, 6, -3]]) 46 | } 47 | 48 | test4 = { 49 | 'x': jnp.array([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], 50 | [[0, -1, 2, -3], [4, -5, 6, -7], [8, -9, 10, -11]]]), 51 | 'padding_mask' : jnp.array([[[1], [0], [1]], 52 | [[0], [1], [0]]]), 53 | 'mean_pool_rep': jnp.array([[4, 5, 6, 7], 54 | [4, -5, 6, -7]]), 55 | 'max_pool_rep': jnp.array([[8, 9, 10, 11], 56 | [4, -5, 6, -7]]) 57 | } 58 | 59 | test5 = { 60 | 'x': jnp.array([[[5, 2, -5], [1, -3, 4], [-3, -8, 1]], 61 | [[-2, -7, 4], [9, -4, 5], [-1, 2, -3]]]), 62 | 'padding_mask' : None, 63 | 'mean_pool_rep': jnp.array([[1, -3, 0], 64 | [2, -3, 2]]), 65 | 'max_pool_rep': jnp.array([[5, 2, 4], 66 | [9, 2, 5]]) 67 | } 68 | 69 | test6 = { 70 | 'x': jnp.array([[[5, 2, -5], [1, -3, 4], [-3, -8, 1]], 71 | [[-2, -7, 4], [9, -4, 5], [-1, 2, -3]]]), 72 | 'padding_mask' : jnp.array([[[1], [1], [1]], 73 | [[1], [1], [1]]]), 74 | 'mean_pool_rep': jnp.array([[1, -3, 0], 75 | [2, -3, 2]]), 76 | 'max_pool_rep': jnp.array([[5, 2, 4], 77 | [9, 2, 5]]) 78 | } 79 | 80 | test7 = { 81 | 'x': jnp.array([[[6, 2, -5], [1, -2, 4], [-3, -8, 0]], 82 | [[-2, -7, 4], [10, -4, 5], [-1, 3, -3]]]), 83 | 'padding_mask' : jnp.array([[[0], [1], [1]], 84 | [[1], [0], [0]]]), 85 | 'mean_pool_rep': jnp.array([[-1, -5, 2], 86 | [-2, -7, 4]]), 87 | 'max_pool_rep': jnp.array([[1, -2, 4], 88 | [-2, -7, 4]]) 89 | } 90 | 91 | test8 = { 92 | 'x': jnp.array([[[6, 2, -5], [1, -2, 4], [-3, -8, 0]], 93 | [[-2, -7, 4], [10, -4, 5], [-1, 3, -3]]]), 94 | 'padding_mask' : jnp.array([[[1], [0], [1]], 95 | [[1], [1], [0]]]), 96 | 'mean_pool_rep': jnp.array([[3/2, -3, -5/2], 97 | [4, -11/2, 9/2]]), 98 | 'max_pool_rep': jnp.array([[6, 2, 0], 99 | [10, -4, 5]]) 100 | } 101 | 102 | tests = (test1, 103 | test2, 104 | test3, 105 | test4, 106 | test5, 107 | test6, 108 | test7, 109 | test8) 110 | 111 | 112 | class TestPooling(parameterized.TestCase): 113 | """Abstract method for testing mean and max pool reduce functions.""" 114 | 115 | @parameterized.parameters( 116 | *tests 117 | ) 118 | def test_mean_pool(self, x, padding_mask, mean_pool_rep, **unused_kwargs): 119 | self.assertTrue(jnp.array_equal(mean_pool(x, padding_mask), mean_pool_rep)) 120 | 121 | @parameterized.parameters( 122 | *tests 123 | ) 124 | def test_max_pool(self, x, padding_mask, max_pool_rep, **unused_kwargs): 125 | self.assertTrue(jnp.array_equal(max_pool(x, padding_mask), max_pool_rep)) 126 | 127 | @parameterized.parameters( 128 | *tests 129 | ) 130 | def test_max_pool_greater_than_mean_pool(self, x, padding_mask, **unused_kwargs): 131 | self.assertTrue((max_pool(x, padding_mask=None) >= mean_pool(x, padding_mask=None)).all()) 132 | 133 | 134 | if __name__ == '__main__': 135 | absltest.main() 136 | -------------------------------------------------------------------------------- /contextual_lenses/pfam_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for Pfam family classification experiments.""" 2 | 3 | import os 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | import numpy as np 9 | 10 | import tensorflow as tf 11 | 12 | import pandas as pd 13 | 14 | import matplotlib.pyplot as plt 15 | 16 | import scipy.stats 17 | 18 | import sklearn.metrics as metrics 19 | from sklearn.neighbors import KNeighborsClassifier as knn 20 | 21 | from pkg_resources import resource_filename 22 | 23 | from fs_gcsfs import GCSFS 24 | 25 | from google_research.protein_lm import domains 26 | 27 | from contextual_lenses.train_utils import create_data_iterator 28 | 29 | from contextual_lenses.loss_fns import cross_entropy_loss 30 | 31 | 32 | # Data preprocessing. 33 | # Original code source: https://www.kaggle.com/drewbryant/starter-pfam-seed-random-split. 34 | def read_all_shards(partition, data_dir, bucket_name): 35 | """Combines different CSVs into a single dataframe.""" 36 | 37 | shards = [] 38 | gcsfs = GCSFS(bucket_name) 39 | for fn in gcsfs.listdir(os.path.join(data_dir, partition)): 40 | with gcsfs.open(os.path.join(data_dir, partition, fn)) as f: 41 | shards.append(pd.read_csv(f, index_col=None)) 42 | 43 | return pd.concat(shards) 44 | 45 | 46 | def mod_family_accession(family_accession): 47 | """Reduces family accession to everything prior to '.'.""" 48 | 49 | return family_accession[:family_accession.index('.')] 50 | 51 | 52 | # Pfam protein_lm domain. 53 | PFAM_PROTEIN_DOMAIN = domains.VariableLengthDiscreteDomain( 54 | vocab=domains.ProteinVocab(include_anomalous_amino_acids=True, 55 | include_bos=True, 56 | include_eos=True, 57 | include_pad=True, 58 | include_mask=True), 59 | length=512) 60 | 61 | 62 | # Number of categories for one-hot encoding. 63 | PFAM_NUM_CATEGORIES = 27 64 | 65 | 66 | def residues_to_one_hot_inds(seq): 67 | """Converts amino acid residues to one hot indices.""" 68 | 69 | one_hot_inds = PFAM_PROTEIN_DOMAIN.encode([seq])[0] 70 | 71 | return one_hot_inds 72 | 73 | 74 | def get_family_ids(): 75 | """Pfam family ids.""" 76 | 77 | family_ids = open( 78 | resource_filename('contextual_lenses.resources', 'pfam_family_ids.txt'), 79 | 'r').readlines() 80 | 81 | return family_ids 82 | 83 | def get_family_id_to_index(): 84 | """Dictionary mapping family id to index.""" 85 | 86 | family_ids = open( 87 | resource_filename('contextual_lenses.resources', 'pfam_family_ids.txt'), 88 | 'r').readlines() 89 | family_id_to_index = {} 90 | for i, family_id in enumerate(family_ids): 91 | family_id_to_index[family_id.replace('\n', '')] = i 92 | 93 | return family_id_to_index 94 | 95 | 96 | def create_pfam_df(family_accessions, 97 | test=False, 98 | samples=None, 99 | random_state=0, 100 | data_partitions_dirpath='random_split/', 101 | gcs_bucket='neuralblast_public'): 102 | """Processes Pfam data into a featurized dataframe with samples many entries per family.""" 103 | 104 | family_id_to_index = get_family_id_to_index() 105 | 106 | if test: 107 | pfam_df = read_all_shards(partition='test', 108 | data_dir=data_partitions_dirpath, 109 | bucket_name=gcs_bucket) 110 | else: 111 | pfam_df = read_all_shards(partition='train', 112 | data_dir=data_partitions_dirpath, 113 | bucket_name=gcs_bucket) 114 | 115 | pfam_df['mod_family_accession'] = pfam_df.family_accession.apply( 116 | lambda x: mod_family_accession(x)) 117 | pfam_df = pfam_df[pfam_df.mod_family_accession.isin(family_accessions)] 118 | pfam_df['index'] = pfam_df.family_id.apply(lambda x: family_id_to_index[x]) 119 | 120 | pfam_df['one_hot_inds'] = pfam_df.sequence.apply( 121 | lambda x: residues_to_one_hot_inds(x[:512])) 122 | 123 | if samples is not None: 124 | pfam_df = pfam_df.sample(frac=1, 125 | replace=False, 126 | random_state=random_state) 127 | pfam_df = pfam_df.groupby('mod_family_accession').head( 128 | samples).reset_index() 129 | 130 | return pfam_df 131 | 132 | 133 | def create_pfam_seq_batches(family_accessions, 134 | batch_size, 135 | test=False, 136 | samples=None, 137 | epochs=1, 138 | drop_remainder=False, 139 | buffer_size=None, 140 | shuffle_seed=0, 141 | sample_random_state=0, 142 | data_partitions_dirpath='random_split/', 143 | gcs_bucket='neuralblast_public', 144 | as_numpy=False): 145 | """Creates iterable object of Pfam sequences.""" 146 | 147 | pfam_df = create_pfam_df(family_accessions, 148 | test=test, 149 | samples=samples, 150 | random_state=sample_random_state, 151 | data_partitions_dirpath=data_partitions_dirpath, 152 | gcs_bucket=gcs_bucket) 153 | 154 | pfam_batches = create_data_iterator(df=pfam_df, 155 | input_col='one_hot_inds', 156 | output_col='index', 157 | batch_size=batch_size, 158 | epochs=epochs, 159 | buffer_size=buffer_size, 160 | seed=shuffle_seed, 161 | drop_remainder=drop_remainder, 162 | add_outputs=False, 163 | as_numpy=as_numpy) 164 | 165 | return pfam_batches 166 | 167 | 168 | def create_pfam_batches(family_accessions, 169 | batch_size, 170 | test=False, 171 | samples=None, 172 | epochs=1, 173 | drop_remainder=False, 174 | buffer_size=None, 175 | shuffle_seed=0, 176 | sample_random_state=0, 177 | data_partitions_dirpath='random_split/', 178 | gcs_bucket='neuralblast_public', 179 | as_numpy=True): 180 | """Creates iterable object of Pfam data batches.""" 181 | 182 | pfam_df = create_pfam_df(family_accessions, 183 | test=test, 184 | samples=samples, 185 | random_state=sample_random_state, 186 | data_partitions_dirpath=data_partitions_dirpath, 187 | gcs_bucket=gcs_bucket) 188 | 189 | pfam_indexes = pfam_df['index'].values 190 | 191 | pfam_batches = create_data_iterator(df=pfam_df, 192 | input_col='one_hot_inds', 193 | output_col='index', 194 | batch_size=batch_size, 195 | epochs=epochs, 196 | buffer_size=buffer_size, 197 | seed=shuffle_seed, 198 | drop_remainder=drop_remainder, 199 | as_numpy=as_numpy) 200 | 201 | return pfam_batches, pfam_indexes 202 | 203 | 204 | # Model evaluation. 205 | def pfam_evaluate(predict_fn, 206 | test_family_accessions, 207 | title, 208 | loss_fn_kwargs, 209 | batch_size=512, 210 | data_partitions_dirpath='random_split/', 211 | gcs_bucket='neuralblast_public'): 212 | """Computes predicted family ids and measures performance in cross entropy and accuracy.""" 213 | 214 | test_batches, test_indexes = create_pfam_batches( 215 | family_accessions=test_family_accessions, 216 | batch_size=batch_size, 217 | test=True, 218 | buffer_size=1, 219 | gcs_bucket=gcs_bucket, 220 | data_partitions_dirpath=data_partitions_dirpath) 221 | 222 | pred_indexes = [] 223 | cross_entropy = 0. 224 | 225 | for batch in iter(test_batches): 226 | 227 | X, Y = batch 228 | 229 | Y_hat = predict_fn(X) 230 | 231 | cross_entropy += cross_entropy_loss(Y, Y_hat, **loss_fn_kwargs) 232 | 233 | preds = jnp.argmax(Y_hat, axis=1) 234 | for pred in preds: 235 | pred_indexes.append(pred) 236 | 237 | pred_indexes = np.array(pred_indexes) 238 | 239 | acc = metrics.accuracy_score(test_indexes, pred_indexes) 240 | 241 | results = { 242 | 'title': title, 243 | 'cross_entropy': cross_entropy, 244 | 'accuracy': acc, 245 | } 246 | 247 | return results, pred_indexes 248 | 249 | 250 | def compute_embeddings(encoder, data_batches): 251 | """Computes sequence embeddings according to a specified encoder.""" 252 | 253 | vectors = [] 254 | for batch in iter(data_batches): 255 | X, Y = batch 256 | X_embedded = encoder(X) 257 | for vec in np.array(X_embedded): 258 | vectors.append(vec) 259 | vectors = np.array(vectors) 260 | 261 | return vectors 262 | 263 | 264 | def pfam_nearest_neighbors_classification( 265 | encoder, 266 | family_accessions, 267 | batch_size=512, 268 | n_neighbors=1, 269 | train_samples=None, 270 | test_samples=None, 271 | shuffle_seed=0, 272 | sample_random_state=0, 273 | data_partitions_dirpath='random_split/', 274 | gcs_bucket='neuralblast_public'): 275 | """Nearest neighbors classification on Pfam families using specified encoder.""" 276 | 277 | train_batches, train_indexes = create_pfam_batches( 278 | family_accessions=family_accessions, 279 | batch_size=batch_size, 280 | samples=train_samples, 281 | buffer_size=1, 282 | shuffle_seed=shuffle_seed, 283 | sample_random_state=sample_random_state, 284 | data_partitions_dirpath=data_partitions_dirpath, 285 | gcs_bucket=gcs_bucket) 286 | test_batches, test_indexes = create_pfam_batches( 287 | family_accessions=family_accessions, 288 | batch_size=batch_size, 289 | test=True, 290 | samples=test_samples, 291 | buffer_size=1, 292 | shuffle_seed=shuffle_seed, 293 | sample_random_state=sample_random_state, 294 | data_partitions_dirpath=data_partitions_dirpath, 295 | gcs_bucket=gcs_bucket) 296 | 297 | train_vectors = compute_embeddings(encoder, train_batches) 298 | test_vectors = compute_embeddings(encoder, test_batches) 299 | 300 | knn_classifier = knn(n_neighbors=n_neighbors) 301 | knn_classifier.fit(train_vectors, train_indexes) 302 | knn_predictions = knn_classifier.predict(test_vectors) 303 | 304 | knn_accuracy = metrics.accuracy_score(test_indexes, knn_predictions) 305 | 306 | results = { 307 | str(n_neighbors) + "-nn accuracy": knn_accuracy, 308 | 'train_samples': train_samples, 309 | 'test_samples': test_samples 310 | } 311 | 312 | return results, knn_predictions, knn_classifier 313 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Protein Embedding Search 2 | 3 | **This is not an officially supported Google product.** 4 | 5 | Protein database search tools such as [BLAST](https://blast.ncbi.nlm.nih.gov/Blast.cgi) are instrumental for research in the life sciences. However, they are slow and are based on surface-level sequence similarity. We are exploring using neural networks to improve the speed and accuracy of finding relevant sequences from these databases. 6 | 7 | More specifically, we are aiming to learn fixed-length protein embeddings using [contextual lenses](https://arxiv.org/pdf/2002.08866.pdf). Generally speaking, a sequence level protein representation, such as a one-hot encoding, is an array of the the form (sequence_length, n) where n is the amino acid embedding dimension. A contextual lens is a (learnable) map from the (sequence_length, n)-array to an (m,)-vector where m is independent of sequence_length. Embeddings are constructed using an encoder function followed by a contextual lens. To learn these embeddings a downstream prediction task is performed using a single dense layer. Gradients are backpropagated through all 3 components of the architecture (encoder, lens, predictor) using the Adam optimizer with variable (potentially zero) learning rates and weight decays per component. 8 | 9 | ### Encoders 10 | - [One-hot](https://github.com/googleinterns/protein-embedding-retrieval/blob/master/contextual_lenses/encoders.py#L21): non-learnable 11 | - [CNN](https://github.com/googleinterns/protein-embedding-retrieval/blob/master/contextual_lenses/encoders.py#L46): learnable 12 | - [Transformer](https://github.com/google-research/google-research/blob/master/protein_lm/models.py#L870): learnable and pretrainable 13 | 14 | ### Lenses 15 | - [Mean/Max-Pool](https://github.com/googleinterns/protein-embedding-retrieval/blob/master/contextual_lenses/contextual_lenses.py#L21): non-learnable 16 | - [Linear-Mean/Max-Pool](https://github.com/googleinterns/protein-embedding-retrieval/blob/master/contextual_lenses/contextual_lenses.py#L46): learnable 17 | - [GatedConvolution](https://github.com/googleinterns/protein-embedding-retrieval/blob/master/contextual_lenses/contextual_lenses.py#L125): learnable and self-attentive 18 | 19 | ## TAPE Protein Engineering Tasks 20 | [TAPE](https://arxiv.org/pdf/1906.08230.pdf) proposes two protein engineering tasks: fluorescence prediction and stability prediction. We implement our lens architectures on these tasks in a [Google Colab notebook](https://github.com/googleinterns/protein-embedding-retrieval/blob/master/cnn_protein_landscapes.ipynb). We find that for the fluorescence task both linear regression on the one-hot encodings and 1-layer convolution compete with and outperform the best pretrained language models in TAPE. Likewise, we find that for the stability task 3-layer convolution competes with and outperforms TAPE's models. See below for our results compared to TAPE's results (bold represents best performance). 21 | 22 | ### Fluorescence 23 | MSE is mean squared error and rho represents Spearman's rank correlation coefficient. 24 | | Model Type | Model | Full Test Set (MSE, rho) | Bright Mode (MSE, rho) | Dark Mode (MSE, rho) | 25 | | ---------- | ----- | :----------------------: | :--------------------: | :------------------: | 26 | | Baseline | Linear Regression | (0.35, **0.69**) | (0.09, **0.68**) | (0.33, **0.05**) | 27 | | Lens Architecture | 1-Layer CNN + MaxPool | (0.26, **0.69**) | (0.09, 0.65) | (0.29, **0.05**) | 28 | | Lens Architecture | 1-Layer CNN + LinearMaxPool | (0.23, **0.69**) | (0.12, 0.66) | (0.28, **0.05**) | 29 | | TAPE | Best of all models | (**0.19**, 0.68) | (**0.07**, 0.63) | (**0.22**, **0.05**)| 30 | 31 | 32 | ### Stability 33 | Accuracy (Acc) is measured using the parent protein as a decision boundary and labeling mutations as beneficial if predicted stability is greater than predicted parent stability and deleterious if the opposite is true. rho represents Spearman's rank correlation coefficient. The letters A and B represent the alpha and beta topologies, respectively. 34 | | Model Type | Model | Full Test Set (rho, Acc) | AAA (rho, Acc) | ABBA (rho, Acc) | BABB (rho, Acc) | BBABB (rho, Acc) | 35 | | ---------- | ----- | :---------------------------: | :-----------------: | :------------------: | :------------------: | :-------------------: | 36 | | Baseline | Linear Regression | (0.49, 0.60) | (0.21, 0.66) | (-0.03, 0.6) | (0.51, 0.64) | (0.38, 0.61) | 37 | | Lens Architecture | 3-Layer CNN + MaxPool | (0.76, 0.75) | (0.69, **0.71**) | (0.37, 0.70) | (0.50, 0.72) | (0.60, 0.68) | 38 | | Lens Architecture | Dilated 3-Layer CNN + MaxPool | (0.75, 0.73) | (0.67, 0.69) | (0.49, 0.69) | (0.61, 0.70) | (0.53, 0.64) | 39 | | Lens Architecture | 3-Layer CNN + LinearMaxPool | (0.71, **0.77**) | (0.59, 0.69) | (**0.52**, 0.77) | (0.55, **0.73**) | (0.60, **0.70**) | 40 | | Lens Architecture | Ensemble (Average) of above CNN models | (**0.79**, **0.77**) | (0.67, **0.71**) | (**0.53**, 0.75) | (0.65, **0.74**) | (0.60, **0.70**) | 41 | | TAPE | Best of all models | (0.73, 0.70) | (**0.72**, 0.70) | (0.48, **0.79**) | (**0.68**, 0.71) | (**0.67**, **0.70**) | 42 | 43 | 44 | ## Downstream Task 45 | The downstream task we use to train embeddings is Pfam family classification. We pick an encoder and a lens and train the architecture to predict a protein's family using only its primary sequence. We train on 10000 families in the data set and measure **Lens Accuracy**: the accuracy achieved on the *test set of train families* by the architecture trained for family prediction on the *train set of train families*. We then take the embeddings from this trained model and use them to do family prediction on 1000 holdout families with KNN (using 1 neighbor). This test allows us to assess the extent of transfer learning by seeing how much the embeddings have learned about the holdout families from the train families. In theory, a perfect model would map all proteins that are members of the same family to a single vector. To test for this we measure **n-Sample Test KNN Accuracy**: The accuracy achieved on the *test set of test families* by a KNN classifier trained on the embeddings (from our architecture) of the *train set of test families* using *at most n samples per family*. 46 | 47 | ### Pretraining 48 | We also measure the effect that pretraining has on the performance of a language model encoder. There has been a great deal of interest in measuring the degree to which pretraining protein language models improves their performance on downstream tasks. TAPE investigates this and proposes baselines. Our results indicate that pretraining offers a substantial boost in performance on the family classification task. We use transformer language models, specifically BERT models similar to to the the [ProGen model](https://www.biorxiv.org/content/10.1101/2020.03.07.982272v2.full.pdf) and the [models used by FAIR](https://www.biorxiv.org/content/10.1101/622803v2.full.pdf). Our [models](https://github.com/google-research/google-research/tree/master/protein_lm) are implemented in jax/flax and pretrained on the [TrEMBL protein corpus](https://www.uniprot.org/statistics/TrEMBL). 49 | 50 | ## Results 51 | In the table below we show the accuracies achieved using KNN on the model embeddings as well as KNN using BLAST's weighted edit distance. We show a simple 2-layer CNN, 3 different size language models both with and without pretraining, and the [Blundell CNN model](https://www.biorxiv.org/content/10.1101/626507v4.full.pdf). All bolded numbers represent better performance compared to BLAST. The key takeways are the performance of pretrained language models on 1-sample classification and the substantial performance boost from pretraining said models. 52 | 53 | | Model | 1-Sample Accuracy | 5-Sample Accuracy | 10-Sample Accuracy | 50-Sample Accuracy | 54 | |--------------------------------------------|:-----------------:|:-----------------:|:------------------:|:------------------:| 55 | | BLAST | 0.860750 | 0.978355 | 0.991342 | 0.996392 | 56 | | 2-layer CNN | 0.687815 | 0.870944 | 0.914924 | 0.956741 | 57 | | Small Transformer | 0.769286 | 0.920692 | 0.952415 | 0.974045 | 58 | | Pretrained Small Transformer | **0.873828** | 0.968998 | 0.979813 | 0.992790 | 59 | | Medium Transformer | 0.778659 | 0.921413 | 0.956741 | 0.981255 | 60 | | Pretrained Medium Transformer | **0.863775** | 0.968277 | 0.984859 | 0.994232 | 61 | | Large Transformer | 0.749820 | 0.894737 | 0.937996 | 0.970440 | 62 | | Pretrained Large Transformer | **0.865898** | 0.974045 | 0.984859 | 0.995674 | 63 | | Blundell Lens-Family CNN | **0.877345** | **0.980519** | **0.992063** | 0.993506 | 64 | | Blundell Full-Family CNN** | **0.923521** | **0.984848** | **0.992785** | 0.995671 | 65 | | Blundell Full-Family CNN** w/ Whitening*** | **0.940837** | **0.988456** | **0.996392** | **0.996392** | 66 | 67 | ** The Full-Family Blundell CNN is not performing transfer learning. It was trained on families that appear in the KNN task. 68 | 69 | *** Whitened embeddings are obtained by performing PCA on the embeddings of all Pfam seed sequences and applying the corresponding whitening transformation to the KNN train and test sequences. 70 | 71 | Below we show plots of the top 10 n-Sample Test KNN Accuracies vs. Lens Accuracies for different models and for n = 1, 5, 10, 50. The key takeaways are the noticable boost pretraining the language models provides, the fact that Lens Accuracy is not a perfect predictor of Test KNN Accuracy, and the independence of performance and transformer size. 72 | 73 | ![1-sample](/figures/1-sample_test_knn_accuracy.png) 74 | 75 | ![5-sample](/figures/5-sample_test_knn_accuracy.png) 76 | 77 | ![10-sample](/figures/10-sample_test_knn_accuracy.png) 78 | 79 | ![50-sample](/figures/50-sample_test_knn_accuracy.png) 80 | 81 | ## Quickstart 82 | To clone this project run 83 | ``` 84 | git clone --recurse-submodules https://github.com/googleinterns/protein-embedding-retrieval.git 85 | ``` 86 | 87 | Once the project is cloned, the first step is to install [Caliban](https://github.com/google/caliban). We use Caliban for running individual jobs and parallelizing many jobs on GCP (Google Cloud Platform). 88 | 89 | For a simple demo on your machine (recommended only if it is equipped with a GPU) run 90 | ``` 91 | caliban run --experiment_config demo.json pfam_experiment.py 92 | ``` 93 | The demo takes ~3 hours with an Nvidia Tesla P100 GPU (the Caliban default). By default this will load data from and save data to the 'neuralblast_public' GCS bucket. You can change this by modifying the values of 'load_gcs_bucket' and 'save_gcs_bucket' in demo.json. 94 | 95 | To run on cloud first connect to a GCP project (equipped with GPU resources) by running 96 | ``` 97 | gcloud init 98 | ``` 99 | If your GCP project is named MY_PROJECT run 100 | ``` 101 | export PROJECT_ID=MY_PROJECT 102 | ``` 103 | Finally, run 104 | ``` 105 | caliban cloud --experiment_config demo.json pfam_experiment.py 106 | ``` 107 | 108 | 109 | ## Reproducing our Results 110 | To reproduce our results you first need to connect to a GCP project, ideally one with a large number of GPUs, and clone the project. Then take the [generate_params.py](https://github.com/googleinterns/protein-embedding-retrieval/blob/master/generate_params.py) script, and modify the variable 'save_gcs_bucket' in the call to main() using the GCS bucket you wish to save to (and potentially do the same for the one you want to load from 'load_gcs_bucket'). Run the script to generate the appropriate parameter combinations and run 111 | ``` 112 | caliban cloud --experiment_config params_combinations.json pfam_experiment.py 113 | ``` 114 | 115 | ## Source Code Headers 116 | 117 | Every file containing source code must include copyright and license 118 | information. This includes any JS/CSS files that you might be serving out to 119 | browsers. (This is to help well-intentioned people avoid accidental copying that 120 | doesn't comply with the license.) 121 | 122 | Apache header: 123 | 124 | Copyright 2020 Google LLC 125 | 126 | Licensed under the Apache License, Version 2.0 (the "License"); 127 | you may not use this file except in compliance with the License. 128 | You may obtain a copy of the License at 129 | 130 | https://www.apache.org/licenses/LICENSE-2.0 131 | 132 | Unless required by applicable law or agreed to in writing, software 133 | distributed under the License is distributed on an "AS IS" BASIS, 134 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 135 | See the License for the specific language governing permissions and 136 | limitations under the License. 137 | -------------------------------------------------------------------------------- /contextual_lenses/test/test_learning.py: -------------------------------------------------------------------------------- 1 | """End-to-end learning tests using MSE loss for contextual_lenses.py.""" 2 | 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | import flax 8 | from flax import nn 9 | 10 | import numpy as np 11 | 12 | from absl.testing import parameterized 13 | from absl.testing import absltest 14 | 15 | from contextual_lenses import mean_pool, max_pool, \ 16 | linear_max_pool, linear_mean_pool, gated_conv 17 | 18 | from train_utils import create_optimizer, train_step, \ 19 | create_representation_model 20 | 21 | from encoders import one_hot_encoder, cnn_one_hot_encoder, \ 22 | one_hot_pos_emb_encoder, cnn_one_hot_pos_emb_encoder 23 | 24 | from loss_fns import mse_loss 25 | 26 | 27 | def generate_random_sequences(batch_size=3, seq_len=12, num_categories=21): 28 | """Generates batch_size many random sequences of length seq_len from num_categories indices.""" 29 | 30 | np.random.seed(0) 31 | input_data = jnp.array(np.random.randint(0, num_categories-1, size=(batch_size, seq_len))) 32 | 33 | return input_data 34 | 35 | 36 | def generate_random_targets(batch_size=3): 37 | """Generates batch_size many scalar targets.""" 38 | 39 | np.random.seed(0) 40 | output_data = jnp.array(np.random.normal(size=(batch_size,))) 41 | 42 | return output_data 43 | 44 | 45 | def train(model, input_data, output_data, learning_rate=1e-3, epochs=1000): 46 | """Fits model to training data and returns train loss.""" 47 | 48 | optimizer = create_optimizer(model, learning_rate=learning_rate, weight_decay=0) 49 | 50 | loss_fn_kwargs={} 51 | 52 | for epoch in range(epochs): 53 | optimizer = train_step(optimizer, input_data, output_data, mse_loss, loss_fn_kwargs) 54 | 55 | preds = jnp.squeeze(optimizer.target(input_data), axis=1) 56 | 57 | train_loss = jnp.mean(jnp.square(preds-output_data)) 58 | 59 | return train_loss 60 | 61 | 62 | # Test cases: 63 | 64 | # One-hot: 65 | test1 = { 66 | 'testcase_name': 'max_pool', 67 | 'encoder_fn': one_hot_encoder, 68 | 'encoder_fn_kwargs': { 69 | 70 | }, 71 | 'reduce_fn': max_pool, 72 | 'reduce_fn_kwargs': { 73 | 74 | }, 75 | 'learning_rate': 1e-2, 76 | 'epochs': 1000, 77 | 'loss_threshold': 1e-4 78 | } 79 | 80 | test2 = { 81 | 'testcase_name': 'mean_pool', 82 | 'encoder_fn': one_hot_encoder, 83 | 'encoder_fn_kwargs': { 84 | 85 | }, 86 | 'reduce_fn': mean_pool, 87 | 'reduce_fn_kwargs': { 88 | 89 | }, 90 | 'learning_rate': 1e-2, 91 | 'epochs': 1000, 92 | 'loss_threshold': 1e-4 93 | } 94 | 95 | test3 = { 96 | 'testcase_name': 'linear_max_pool', 97 | 'encoder_fn': one_hot_encoder, 98 | 'encoder_fn_kwargs': { 99 | 100 | }, 101 | 'reduce_fn': linear_max_pool, 102 | 'reduce_fn_kwargs': { 103 | 'rep_size': 512 104 | }, 105 | 'learning_rate': 1e-3, 106 | 'epochs': 100, 107 | 'loss_threshold': 1e-4 108 | } 109 | 110 | test4 = { 111 | 'testcase_name': 'linear_mean_pool', 112 | 'encoder_fn': one_hot_encoder, 113 | 'encoder_fn_kwargs': { 114 | 115 | }, 116 | 'reduce_fn': linear_mean_pool, 117 | 'reduce_fn_kwargs': { 118 | 'rep_size': 2048 119 | }, 120 | 'learning_rate': 1e-3, 121 | 'epochs': 100, 122 | 'loss_threshold': 1e-4 123 | } 124 | 125 | 126 | test5 = { 127 | 'testcase_name': 'gated_conv', 128 | 'encoder_fn': one_hot_encoder, 129 | 'encoder_fn_kwargs': { 130 | 131 | }, 132 | 'reduce_fn': gated_conv, 133 | 'reduce_fn_kwargs': { 134 | 'rep_size': 256, 135 | 'm_layers': 3, 136 | 'm_features': [[512, 512], [512, 512]], 137 | 'm_kernel_sizes': [[12, 12], [10, 10], [8, 8]], 138 | 'conv_rep_size': 256 139 | }, 140 | 'learning_rate': 1e-3, 141 | 'epochs': 100, 142 | 'loss_threshold': 1e-4 143 | } 144 | 145 | 146 | # CNN + one-hot: 147 | test6 = { 148 | 'testcase_name': 'cnn_max_pool', 149 | 'encoder_fn': cnn_one_hot_encoder, 150 | 'encoder_fn_kwargs': { 151 | 'n_layers': 1, 152 | 'n_features': [256], 153 | 'n_kernel_sizes': [3] 154 | }, 155 | 'reduce_fn': max_pool, 156 | 'reduce_fn_kwargs': { 157 | 158 | }, 159 | 'learning_rate': 1e-3, 160 | 'epochs': 100, 161 | 'loss_threshold': 1e-4 162 | } 163 | 164 | 165 | test7 = { 166 | 'testcase_name': 'cnn_mean_pool', 167 | 'encoder_fn': cnn_one_hot_encoder, 168 | 'encoder_fn_kwargs': { 169 | 'n_layers': 1, 170 | 'n_features': [512], 171 | 'n_kernel_sizes': [5] 172 | }, 173 | 'reduce_fn': mean_pool, 174 | 'reduce_fn_kwargs': { 175 | 176 | }, 177 | 'learning_rate': 1e-3, 178 | 'epochs': 100, 179 | 'loss_threshold': 1e-4 180 | } 181 | 182 | 183 | test8 = { 184 | 'testcase_name': 'cnn_linear_max_pool', 185 | 'encoder_fn': cnn_one_hot_encoder, 186 | 'encoder_fn_kwargs': { 187 | 'n_layers': 1, 188 | 'n_features': [32], 189 | 'n_kernel_sizes': [3] 190 | }, 191 | 'reduce_fn': linear_max_pool, 192 | 'reduce_fn_kwargs': { 193 | 'rep_size': 256 194 | }, 195 | 'learning_rate': 1e-3, 196 | 'epochs': 100, 197 | 'loss_threshold': 1e-4 198 | } 199 | 200 | 201 | test9 = { 202 | 'testcase_name': 'cnn_linear_mean_pool', 203 | 'encoder_fn': cnn_one_hot_encoder, 204 | 'encoder_fn_kwargs': { 205 | 'n_layers': 1, 206 | 'n_features': [512], 207 | 'n_kernel_sizes': [5] 208 | }, 209 | 'reduce_fn': linear_mean_pool, 210 | 'reduce_fn_kwargs': { 211 | 'rep_size': 256 212 | }, 213 | 'learning_rate': 1e-3, 214 | 'epochs': 100, 215 | 'loss_threshold': 1e-4 216 | } 217 | 218 | 219 | test10 = { 220 | 'testcase_name': 'cnn_gated_conv', 221 | 'encoder_fn': cnn_one_hot_encoder, 222 | 'encoder_fn_kwargs': { 223 | 'n_layers': 1, 224 | 'n_features': [32], 225 | 'n_kernel_sizes': [3] 226 | }, 227 | 'reduce_fn': gated_conv, 228 | 'reduce_fn_kwargs': { 229 | 'rep_size': 256, 230 | 'm_layers': 3, 231 | 'm_features': [[512, 512], [512, 512]], 232 | 'm_kernel_sizes': [[12, 12], [10, 10], [8, 8]], 233 | 'conv_rep_size': 256 234 | }, 235 | 'learning_rate': 1e-3, 236 | 'epochs': 100, 237 | 'loss_threshold': 1e-4 238 | } 239 | 240 | 241 | # One-hot + positional embeddings: 242 | test11 = { 243 | 'testcase_name': 'pos_emb_max_pool', 244 | 'encoder_fn': one_hot_pos_emb_encoder, 245 | 'encoder_fn_kwargs': { 246 | 'max_len': 512, 247 | 'posemb_init': nn.initializers.normal(stddev=1e-6) 248 | }, 249 | 'reduce_fn': max_pool, 250 | 'reduce_fn_kwargs': { 251 | 252 | }, 253 | 'learning_rate': 1e-2, 254 | 'epochs': 1000, 255 | 'loss_threshold': 1e-4 256 | } 257 | 258 | test12 = { 259 | 'testcase_name': 'pos_emb_mean_pool', 260 | 'encoder_fn': one_hot_pos_emb_encoder, 261 | 'encoder_fn_kwargs': { 262 | 'max_len': 512, 263 | 'posemb_init': nn.initializers.normal(stddev=1e-6) 264 | }, 265 | 'reduce_fn': mean_pool, 266 | 'reduce_fn_kwargs': { 267 | 268 | }, 269 | 'learning_rate': 1e-2, 270 | 'epochs': 1000, 271 | 'loss_threshold': 1e-4 272 | } 273 | 274 | test13 = { 275 | 'testcase_name': 'pos_emb_linear_max_pool', 276 | 'encoder_fn': one_hot_pos_emb_encoder, 277 | 'encoder_fn_kwargs': { 278 | 'max_len': 512, 279 | 'posemb_init': nn.initializers.normal(stddev=1e-6) 280 | }, 281 | 'reduce_fn': linear_max_pool, 282 | 'reduce_fn_kwargs': { 283 | 'rep_size': 256 284 | }, 285 | 'learning_rate': 1e-3, 286 | 'epochs': 100, 287 | 'loss_threshold': 1e-4 288 | } 289 | 290 | test14 = { 291 | 'testcase_name': 'pos_emb_linear_mean_pool', 292 | 'encoder_fn': one_hot_pos_emb_encoder, 293 | 'encoder_fn_kwargs': { 294 | 'max_len': 512, 295 | 'posemb_init': nn.initializers.normal(stddev=1e-6) 296 | }, 297 | 'reduce_fn': linear_mean_pool, 298 | 'reduce_fn_kwargs': { 299 | 'rep_size': 4096 300 | }, 301 | 'learning_rate': 1e-3, 302 | 'epochs': 100, 303 | 'loss_threshold': 1e-4 304 | } 305 | 306 | 307 | test15 = { 308 | 'testcase_name': 'pos_emb_gated_conv', 309 | 'encoder_fn': one_hot_pos_emb_encoder, 310 | 'encoder_fn_kwargs': { 311 | 'max_len': 512, 312 | 'posemb_init': nn.initializers.normal(stddev=1e-6) 313 | }, 314 | 'reduce_fn': gated_conv, 315 | 'reduce_fn_kwargs': { 316 | 'rep_size': 256, 317 | 'm_layers': 3, 318 | 'm_features': [[512, 512], [512, 512]], 319 | 'm_kernel_sizes': [[12, 12], [10, 10], [8, 8]], 320 | 'conv_rep_size': 256 321 | }, 322 | 'learning_rate': 1e-3, 323 | 'epochs': 100, 324 | 'loss_threshold': 1e-4 325 | } 326 | 327 | 328 | # CNN + one-hot + positional embeddings: 329 | test16 = { 330 | 'testcase_name': 'cnn_pos_emb_max_pool', 331 | 'encoder_fn': cnn_one_hot_pos_emb_encoder, 332 | 'encoder_fn_kwargs': { 333 | 'n_layers': 1, 334 | 'n_features': [256], 335 | 'n_kernel_sizes': [3], 336 | 'max_len': 512, 337 | 'posemb_init': nn.initializers.normal(stddev=1e-6) 338 | }, 339 | 'reduce_fn': max_pool, 340 | 'reduce_fn_kwargs': { 341 | 342 | }, 343 | 'learning_rate': 1e-3, 344 | 'epochs': 100, 345 | 'loss_threshold': 1e-4 346 | } 347 | 348 | 349 | test17 = { 350 | 'testcase_name': 'cnn_pos_emb_mean_pool', 351 | 'encoder_fn': cnn_one_hot_pos_emb_encoder, 352 | 'encoder_fn_kwargs': { 353 | 'n_layers': 1, 354 | 'n_features': [512], 355 | 'n_kernel_sizes': [5], 356 | 'max_len': 512, 357 | 'posemb_init': nn.initializers.normal(stddev=1e-6) 358 | }, 359 | 'reduce_fn': mean_pool, 360 | 'reduce_fn_kwargs': { 361 | 362 | }, 363 | 'learning_rate': 1e-3, 364 | 'epochs': 100, 365 | 'loss_threshold': 1e-4 366 | } 367 | 368 | 369 | test18 = { 370 | 'testcase_name': 'cnn_pos_emb_linear_max_pool', 371 | 'encoder_fn': cnn_one_hot_pos_emb_encoder, 372 | 'encoder_fn_kwargs': { 373 | 'n_layers': 1, 374 | 'n_features': [32], 375 | 'n_kernel_sizes': [3], 376 | 'max_len': 512, 377 | 'posemb_init': nn.initializers.normal(stddev=1e-6) 378 | }, 379 | 'reduce_fn': linear_max_pool, 380 | 'reduce_fn_kwargs': { 381 | 'rep_size': 256 382 | }, 383 | 'learning_rate': 1e-3, 384 | 'epochs': 100, 385 | 'loss_threshold': 1e-4 386 | } 387 | 388 | 389 | test19 = { 390 | 'testcase_name': 'cnn_pos_emb_linear_mean_pool', 391 | 'encoder_fn': cnn_one_hot_pos_emb_encoder, 392 | 'encoder_fn_kwargs': { 393 | 'n_layers': 1, 394 | 'n_features': [512], 395 | 'n_kernel_sizes': [5], 396 | 'max_len': 512, 397 | 'posemb_init': nn.initializers.normal(stddev=1e-6) 398 | }, 399 | 'reduce_fn': linear_mean_pool, 400 | 'reduce_fn_kwargs': { 401 | 'rep_size': 256 402 | }, 403 | 'learning_rate': 1e-3, 404 | 'epochs': 100, 405 | 'loss_threshold': 1e-4 406 | } 407 | 408 | 409 | test20 = { 410 | 'testcase_name': 'cnn_pos_emb_gated_conv', 411 | 'encoder_fn': cnn_one_hot_pos_emb_encoder, 412 | 'encoder_fn_kwargs': { 413 | 'n_layers': 1, 414 | 'n_features': [32], 415 | 'n_kernel_sizes': [3], 416 | 'max_len': 512, 417 | 'posemb_init': nn.initializers.normal(stddev=1e-6) 418 | }, 419 | 'reduce_fn': gated_conv, 420 | 'reduce_fn_kwargs': { 421 | 'rep_size': 256, 422 | 'm_layers': 3, 423 | 'm_features': [[512, 512], [512, 512]], 424 | 'm_kernel_sizes': [[12, 12], [10, 10], [8, 8]], 425 | 'conv_rep_size': 256 426 | }, 427 | 'learning_rate': 1e-3, 428 | 'epochs': 100, 429 | 'loss_threshold': 1e-4 430 | } 431 | 432 | 433 | tests = (test1, 434 | test2, 435 | test3, 436 | test4, 437 | test5, 438 | test6, 439 | test7, 440 | test8, 441 | test9, 442 | test10, 443 | test11, 444 | test12, 445 | test13, 446 | test14, 447 | test15, 448 | test16, 449 | test17, 450 | test18, 451 | test19, 452 | test20) 453 | 454 | 455 | class TestLearning(parameterized.TestCase): 456 | """Abstract method for testing synthetic learning of encoders and lenses.""" 457 | 458 | @parameterized.named_parameters( 459 | *tests 460 | ) 461 | def test_learning(self, encoder_fn, encoder_fn_kwargs, reduce_fn, reduce_fn_kwargs, 462 | learning_rate=1e-3, epochs=1000, loss_threshold=1e-4): 463 | 464 | input_data = generate_random_sequences(batch_size=3, seq_len=12, num_categories=21) 465 | output_data = generate_random_targets(batch_size=3) 466 | 467 | model = create_representation_model(encoder_fn=encoder_fn, 468 | encoder_fn_kwargs=encoder_fn_kwargs, 469 | reduce_fn=reduce_fn, 470 | reduce_fn_kwargs=reduce_fn_kwargs, 471 | num_categories=21, 472 | output_features=1) 473 | 474 | train_loss = train(model, input_data, output_data, 475 | learning_rate=learning_rate, epochs=epochs) 476 | 477 | self.assertTrue(train_loss < loss_threshold) 478 | 479 | 480 | if __name__ == '__main__': 481 | absltest.main() 482 | -------------------------------------------------------------------------------- /contextual_lenses/train_utils.py: -------------------------------------------------------------------------------- 1 | """Train utils 2 | 3 | General tools for instantiating and training models. 4 | """ 5 | 6 | import flax 7 | from flax import nn 8 | from flax import optim 9 | from flax.training import checkpoints 10 | from flax.training import common_utils 11 | 12 | import jax 13 | from jax import random 14 | import jax.nn 15 | import jax.numpy as jnp 16 | from jax.config import config 17 | config.enable_omnistaging() 18 | 19 | import tensorflow as tf 20 | 21 | import numpy as np 22 | 23 | import functools 24 | 25 | import copy 26 | 27 | from google_research.protein_lm import models 28 | 29 | 30 | # Data batching. 31 | def create_data_iterator(df, 32 | input_col, 33 | output_col, 34 | batch_size, 35 | epochs=1, 36 | buffer_size=None, 37 | seed=0, 38 | drop_remainder=False, 39 | add_outputs=True, 40 | as_numpy=True): 41 | """Creates iterator of batches of (inputs) or (inputs, outputs).""" 42 | 43 | if buffer_size is None: 44 | buffer_size = len(df) 45 | 46 | inputs = list(df[input_col].values) 47 | inputs = tf.data.Dataset.from_tensor_slices(inputs) 48 | 49 | outputs = df[output_col].values 50 | outputs = tf.data.Dataset.from_tensor_slices(outputs) 51 | 52 | if add_outputs: 53 | batches = tf.data.Dataset.zip( 54 | (inputs, outputs)).shuffle(buffer_size=buffer_size, 55 | seed=seed, 56 | reshuffle_each_iteration=True) 57 | else: 58 | batches = inputs.shuffle(buffer_size=buffer_size, 59 | seed=seed, 60 | reshuffle_each_iteration=True) 61 | 62 | batches = batches.repeat(epochs).batch(batch_size=batch_size, 63 | drop_remainder=drop_remainder) 64 | 65 | if as_numpy: 66 | batches = batches.as_numpy_iterator() 67 | 68 | return batches 69 | 70 | 71 | def path_inclusion_filter_fn(path, param, layer): 72 | """Returns whether or not layer name is contained in path.""" 73 | 74 | return layer in path 75 | 76 | 77 | def create_optimizer(model, learning_rate, weight_decay, layers=None): 78 | """Instantiates Adam multi-optimizer.""" 79 | 80 | if layers is None: 81 | assert ( 82 | type(learning_rate) == type(weight_decay) == float 83 | ), 'Specify float values for moded learning rate and weight decay!' 84 | optimizer_def = optim.Adam(learning_rate=learning_rate, 85 | weight_decay=weight_decay) 86 | optimizer = optimizer_def.create(model) 87 | 88 | else: 89 | assert ( 90 | len(learning_rate) == len(weight_decay) == len(layers) 91 | ), 'Number of specified learning rates, weight decays, and layers must be equal!' 92 | optimizers = [] 93 | for lr, wd, layer in zip(learning_rate, weight_decay, layers): 94 | if lr > 0: 95 | opt = optim.Adam(learning_rate=lr, weight_decay=wd) 96 | filter_fn = functools.partial(path_inclusion_filter_fn, 97 | layer=layer) 98 | traversal = optim.ModelParamTraversal(filter_fn) 99 | traversal_opt = (traversal, opt) 100 | optimizers.append(traversal_opt) 101 | optimizer_def = optim.MultiOptimizer(*optimizers) 102 | optimizer = optimizer_def.create(model) 103 | 104 | return optimizer 105 | 106 | 107 | @functools.partial(jax.jit, static_argnums=(3, 4)) 108 | def train_step(optimizer, X, Y, loss_fn, loss_fn_kwargs): 109 | """Trains model (optimizer.target) using specified loss function.""" 110 | def compute_loss_fn(model, X, Y, loss_fn, loss_fn_kwargs): 111 | Y_hat = model(X) 112 | loss = loss_fn(Y, Y_hat, **loss_fn_kwargs) 113 | return loss 114 | 115 | grad_fn = jax.value_and_grad(compute_loss_fn) 116 | _, grad = grad_fn(optimizer.target, X, Y, loss_fn, loss_fn_kwargs) 117 | optimizer = optimizer.apply_gradient(grad) 118 | 119 | return optimizer 120 | 121 | 122 | def get_p_train_step(): 123 | """Wraps train_step with jax.pmap.""" 124 | 125 | p_train_step = jax.pmap(train_step, 126 | axis_name='batch', 127 | static_broadcasted_argnums=(3, 4)) 128 | 129 | return p_train_step 130 | 131 | 132 | def train(model, 133 | train_data, 134 | loss_fn, 135 | loss_fn_kwargs, 136 | learning_rate=1e-4, 137 | weight_decay=0.1, 138 | layers=None, 139 | restore_dir=None, 140 | save_dir=None, 141 | use_pmap=False): 142 | """Instantiates optimizer, applies train_step/p_train_step over training data.""" 143 | 144 | optimizer = create_optimizer(model, 145 | learning_rate=learning_rate, 146 | weight_decay=weight_decay, 147 | layers=layers) 148 | 149 | if restore_dir is not None: 150 | optimizer = checkpoints.restore_checkpoint(ckpt_dir=restore_dir, 151 | target=optimizer) 152 | 153 | if use_pmap: 154 | p_train_step = get_p_train_step() 155 | optimizer = optimizer.replicate() 156 | 157 | for batch in iter(train_data): 158 | X, Y = batch 159 | X, Y = common_utils.shard(X), common_utils.shard(Y) 160 | optimizer = p_train_step(optimizer, X, Y, loss_fn, loss_fn_kwargs) 161 | 162 | optimizer = optimizer.unreplicate() 163 | 164 | else: 165 | for batch in iter(train_data): 166 | X, Y = batch 167 | optimizer = train_step(optimizer, X, Y, loss_fn, loss_fn_kwargs) 168 | 169 | if save_dir is not None: 170 | state = optimizer.state 171 | if type(state) == list: 172 | step = [sub_state.step for sub_state in state] 173 | else: 174 | step = state.step 175 | checkpoints.save_checkpoint(ckpt_dir=save_dir, 176 | target=optimizer, 177 | step=step) 178 | 179 | return optimizer 180 | 181 | 182 | def load_params(params, 183 | encoder_fn_params=None, 184 | reduce_fn_params=None, 185 | predict_fn_params=None): 186 | """Updates randomly initialized parameters using loaded parameters.""" 187 | 188 | loaded_params = copy.deepcopy(params) 189 | fn_names = list(loaded_params.keys()) 190 | 191 | num_learnable_layers = len([ 192 | params_dict for params_dict in 193 | [encoder_fn_params, reduce_fn_params, predict_fn_params] 194 | if params_dict is not None 195 | ]) 196 | if encoder_fn_params is not None: 197 | encoder_fn_ind = '_0' 198 | if reduce_fn_params is not None: 199 | reduce_fn_ind = '_1' 200 | predict_fn_ind = '_2' 201 | else: 202 | predict_fn_ind = '_1' 203 | else: 204 | if reduce_fn_params is not None: 205 | reduce_fn_ind = '_0' 206 | predict_fn_ind = '_1' 207 | else: 208 | predict_fn_ind = '_0' 209 | 210 | assert (len(loaded_params.keys()) >= num_learnable_layers 211 | ), 'Model encoder and lens architecture incorrectly specified!' 212 | 213 | encoder_fn_name = None 214 | if encoder_fn_params is not None: 215 | for fn_name in fn_names: 216 | if encoder_fn_ind in fn_name: 217 | if encoder_fn_name is not None: 218 | raise ValueError( 219 | 'Multiple instances of encoder_fn detected. %s' % 220 | fn_name) 221 | encoder_fn_name = fn_name 222 | loaded_params[encoder_fn_name] = encoder_fn_params 223 | 224 | reduce_fn_name = None 225 | if reduce_fn_params is not None: 226 | for fn_name in fn_names: 227 | if reduce_fn_ind in fn_name: 228 | if reduce_fn_name is not None: 229 | raise ValueError( 230 | 'Multiple instances of reduce_fn detected. %s' % 231 | fn_name) 232 | reduce_fn_name = fn_name 233 | loaded_params[reduce_fn_name] = reduce_fn_params 234 | 235 | predict_fn_name = None 236 | if predict_fn_params is not None: 237 | for fn_name in fn_names: 238 | if predict_fn_ind in fn_name: 239 | if predict_fn_name is not None: 240 | raise ValueError( 241 | 'Multiple instances of predict_fn detected. %s' % 242 | fn_name) 243 | predict_fn_name = fn_name 244 | loaded_params[predict_fn_name] = predict_fn_params 245 | 246 | return loaded_params 247 | 248 | 249 | class RepresentationModel(nn.Module): 250 | def apply(self, 251 | x, 252 | encoder_fn, 253 | encoder_fn_kwargs, 254 | reduce_fn, 255 | reduce_fn_kwargs, 256 | num_categories, 257 | output_features, 258 | output='prediction', 259 | use_transformer=False, 260 | padding_mask=None): 261 | """Computes padding mask, encodes indices using embeddings, 262 | applies lensing operation, predicts scalar value. 263 | """ 264 | 265 | outputs = dict() 266 | 267 | if padding_mask is None: 268 | padding_mask = jnp.expand_dims(jnp.where(x < num_categories - 1, 1, 269 | 0), 270 | axis=2) 271 | 272 | if not use_transformer: 273 | x = encoder_fn(x, 274 | num_categories=num_categories, 275 | **encoder_fn_kwargs) 276 | else: 277 | x = encoder_fn(x) 278 | 279 | rep = reduce_fn(x, padding_mask=padding_mask, **reduce_fn_kwargs) 280 | 281 | outputs['embedding'] = rep 282 | 283 | out = nn.Dense(rep, 284 | output_features, 285 | kernel_init=nn.initializers.xavier_uniform(), 286 | bias_init=nn.initializers.normal(stddev=1e-6)) 287 | 288 | outputs['prediction'] = out 289 | 290 | return outputs[output] 291 | 292 | 293 | def create_representation_model(encoder_fn, 294 | encoder_fn_kwargs, 295 | reduce_fn, 296 | reduce_fn_kwargs, 297 | num_categories, 298 | output_features, 299 | output='prediction', 300 | key=random.PRNGKey(0), 301 | encoder_fn_params=None, 302 | reduce_fn_params=None, 303 | predict_fn_params=None): 304 | """Instantiates a RepresentationModel object.""" 305 | 306 | module = RepresentationModel.partial(encoder_fn=encoder_fn, 307 | encoder_fn_kwargs=encoder_fn_kwargs, 308 | reduce_fn=reduce_fn, 309 | reduce_fn_kwargs=reduce_fn_kwargs, 310 | num_categories=num_categories, 311 | output_features=output_features, 312 | output=output, 313 | use_transformer=False) 314 | 315 | _, initial_params = RepresentationModel.init_by_shape( 316 | key, 317 | input_specs=[((1, 1), jnp.float32)], 318 | encoder_fn=encoder_fn, 319 | encoder_fn_kwargs=encoder_fn_kwargs, 320 | reduce_fn=reduce_fn, 321 | reduce_fn_kwargs=reduce_fn_kwargs, 322 | num_categories=num_categories, 323 | output_features=output_features, 324 | output=output, 325 | use_transformer=False) 326 | 327 | loaded_params = load_params(initial_params, encoder_fn_params, 328 | reduce_fn_params, predict_fn_params) 329 | 330 | model = nn.Model(module, loaded_params) 331 | 332 | return model 333 | 334 | 335 | def create_transformer_representation_model(transformer_kwargs, 336 | reduce_fn, 337 | reduce_fn_kwargs, 338 | num_categories, 339 | output_features, 340 | bidirectional=False, 341 | output='prediction', 342 | key=random.PRNGKey(0), 343 | encoder_fn_params=None, 344 | reduce_fn_params=None, 345 | predict_fn_params=None): 346 | """Instantiates a RepresentationModel object with Transformer encoder.""" 347 | 348 | if not bidirectional: 349 | transformer = models.FlaxLM(**transformer_kwargs) 350 | else: 351 | transformer = models.FlaxBERT(**transformer_kwargs) 352 | transformer_optimizer = transformer._optimizer 353 | transformer_model = models.jax_utils.unreplicate( 354 | transformer_optimizer.target) 355 | transformer_encoder = transformer_model.module.partial( 356 | output_head='output_emb') 357 | 358 | module = RepresentationModel.partial(encoder_fn=transformer_encoder, 359 | encoder_fn_kwargs={}, 360 | reduce_fn=reduce_fn, 361 | reduce_fn_kwargs=reduce_fn_kwargs, 362 | num_categories=num_categories, 363 | output_features=output_features, 364 | output=output, 365 | use_transformer=True) 366 | 367 | _, initial_params = RepresentationModel.init_by_shape( 368 | key, 369 | input_specs=[((1, 1), jnp.float32)], 370 | encoder_fn=transformer_encoder, 371 | encoder_fn_kwargs={}, 372 | reduce_fn=reduce_fn, 373 | reduce_fn_kwargs=reduce_fn_kwargs, 374 | num_categories=num_categories, 375 | output_features=output_features, 376 | output=output, 377 | use_transformer=True) 378 | 379 | loaded_params = load_params(initial_params, encoder_fn_params, 380 | reduce_fn_params, predict_fn_params) 381 | 382 | model = nn.Model(module, loaded_params) 383 | 384 | return model 385 | 386 | 387 | def architecture_to_layers(encoder_fn_name, reduce_fn_name): 388 | 389 | layers = [] 390 | 391 | no_trainable_encoder = False 392 | if encoder_fn_name is None or encoder_fn_name == 'transformer': 393 | layers.append('Transformer_0') 394 | elif encoder_fn_name == 'one_hot': 395 | no_trainable_encoder = True 396 | elif encoder_fn_name == 'cnn_one_hot': 397 | layers.append('CNN_0') 398 | else: 399 | raise ValueError('Incorrect encoder name specified.') 400 | 401 | no_trainable_lens = False 402 | if reduce_fn_name == 'mean_pool' or reduce_fn_name == 'max_pool': 403 | no_trainable_lens = True 404 | elif reduce_fn_name == 'linear_mean_pool' or reduce_fn_name == 'linear_max_pool': 405 | if no_trainable_encoder: 406 | layers.append('Dense_0') 407 | else: 408 | layers.append('Dense_1') 409 | elif reduce_fn_name == 'gated_conv': 410 | if no_trainable_encoder: 411 | layers.append('GatedConv_0') 412 | else: 413 | layers.append('GatedConv_1') 414 | else: 415 | raise ValueError('Incorrect lens name specified.') 416 | 417 | if no_trainable_encoder: 418 | if no_trainable_lens: 419 | layers.append('Dense_0') 420 | else: 421 | layers.append('Dense_1') 422 | else: 423 | if no_trainable_lens: 424 | layers.append('Dense_1') 425 | else: 426 | layers.append('Dense_2') 427 | 428 | trainable_encoder = not no_trainable_encoder 429 | 430 | return layers, trainable_encoder 431 | -------------------------------------------------------------------------------- /pfam_experiment.py: -------------------------------------------------------------------------------- 1 | """Lens training + nearest neighbors classification pipeline.""" 2 | 3 | import os 4 | 5 | import sys 6 | sys.path.insert(1, 'google_research/') 7 | 8 | import flax 9 | from flax import nn 10 | from flax.training import checkpoints 11 | 12 | import jax 13 | from jax import random 14 | import jax.numpy as jnp 15 | from jax.config import config 16 | config.enable_omnistaging() 17 | 18 | import numpy as np 19 | 20 | import pandas as pd 21 | 22 | import json 23 | 24 | import copy 25 | 26 | import time 27 | 28 | from pkg_resources import resource_filename 29 | 30 | from fs_gcsfs import GCSFS 31 | 32 | from google_research.protein_lm import domains, models 33 | 34 | from contextual_lenses.contextual_lenses import reduce_fn_name_to_fn 35 | 36 | from contextual_lenses.train_utils import create_optimizer, train, \ 37 | create_representation_model, create_transformer_representation_model, \ 38 | architecture_to_layers 39 | 40 | from contextual_lenses.encoders import encoder_fn_name_to_fn 41 | 42 | from contextual_lenses.loss_fns import cross_entropy_loss 43 | 44 | from contextual_lenses.pfam_utils import get_family_ids, PFAM_NUM_CATEGORIES, \ 45 | pfam_evaluate, create_pfam_batches, pfam_nearest_neighbors_classification 46 | 47 | from contextual_lenses.load_transformer import load_transformer_params 48 | 49 | from absl import app, flags 50 | 51 | # Define flags. 52 | FLAGS = flags.FLAGS 53 | 54 | flags.DEFINE_string('encoder_fn_name', 'cnn_one_hot', 55 | 'Name of encoder_fn to use. None if using Transformer.') 56 | flags.DEFINE_string('encoder_fn_kwargs_path', 'cnn_kwargs', 57 | 'Path to encoder_fn_kwargs.') 58 | flags.DEFINE_string('reduce_fn_name', 'linear_max_pool', 59 | 'Name of reduce_fn to use.') 60 | flags.DEFINE_string('reduce_fn_kwargs_path', 'linear_pool_1024', 61 | 'Path to reduce_fn_kwargs.') 62 | 63 | flags.DEFINE_integer('epochs', 10, 'Number of epochs for lens training.') 64 | flags.DEFINE_integer( 65 | 'measurements', 1, 66 | 'Number of times to interrupt lens training loop to take measurements (1 = no interruption).' 67 | ) 68 | flags.DEFINE_integer('lens_batch_size', 64, 'Batch size for lens training.') 69 | flags.DEFINE_integer('knn_batch_size', 64, 70 | 'Batch size for KNN vector computation.') 71 | 72 | flags.DEFINE_float('encoder_lr', 0.0, 'Encoder learning rate.') 73 | flags.DEFINE_float('lens_lr', 1e-5, 'Lens learning rate.') 74 | flags.DEFINE_float('predictor_lr', 1e-3, 'Predictor learning rate.') 75 | 76 | flags.DEFINE_float('encoder_wd', 0.0, 'Encoder weight decay.') 77 | flags.DEFINE_float('lens_wd', 0.0, 'Lens weight decay.') 78 | flags.DEFINE_float('predictor_wd', 0.0, 'Predictor weight decay.') 79 | 80 | flags.DEFINE_integer('train_families', 10000, 81 | 'Number of famlies used to train lens.') 82 | flags.DEFINE_integer('lens_train_samples', 50, 83 | 'Number of samples used to train lens.') 84 | flags.DEFINE_integer('first_test_family', 15001, 'First family to test on.') 85 | flags.DEFINE_integer('last_test_family', 16000, 'Last family to test on.') 86 | 87 | flags.DEFINE_integer('lens_shuffle_seed', 0, 88 | 'Random seed used for lens training data batching.') 89 | flags.DEFINE_integer('lens_sample_random_state', 0, 90 | 'Random state used for lens training data sampling.') 91 | flags.DEFINE_integer('knn_shuffle_seed', 1, 92 | 'Random seed used for KNN data batching.') 93 | flags.DEFINE_integer('knn_sample_random_state', 1, 94 | 'Random state used for KNN data sampling.') 95 | flags.DEFINE_integer('model_random_key', 0, 96 | 'Random key used for model instantiation.') 97 | 98 | flags.DEFINE_boolean('use_transformer', False, 99 | 'Whether or not to use transformer encoder') 100 | flags.DEFINE_boolean('use_bert', False, 101 | 'Whether or not to use bidirectional transformer.') 102 | flags.DEFINE_string('restore_transformer_dir', None, 103 | 'Directory to load pretrained transformer from.') 104 | 105 | flags.DEFINE_string('load_gcs_bucket', 'neuralblast_public', 106 | 'GCS bucket to load from.') 107 | flags.DEFINE_string('data_partitions_dirpath', 'random_split/', 108 | 'Location of Pfam data in load GCS bucket.') 109 | 110 | flags.DEFINE_string('save_gcs_bucket', 'sequin-public', 111 | 'GCS bucket to save to.') 112 | flags.DEFINE_string('results_save_dir', '', 113 | 'Directory in save GCS bucket to save to.') 114 | 115 | flags.DEFINE_boolean('load_model', False, 116 | 'Whether or not to load a trained model.') 117 | flags.DEFINE_string('load_model_dir', '', 118 | 'Directory in load GCS bucket to load trained optimizer from.') 119 | flags.DEFINE_integer( 120 | 'load_model_step', 0, 121 | 'Number of steps optimizer to be loaded has been trained for.') 122 | 123 | flags.DEFINE_boolean('save_model', False, 124 | 'Whether or not to save trained model.') 125 | flags.DEFINE_string('save_model_dir', '', 126 | 'Directory in save GCS bucket to save trained optimizer to.') 127 | 128 | flags.DEFINE_string('label', '', 'Label used to save experiment results.') 129 | 130 | 131 | def get_model_kwargs(encoder_fn_name, encoder_fn_kwargs_path, reduce_fn_name, 132 | reduce_fn_kwargs_path): 133 | """Determines model components using string names.""" 134 | 135 | encoder_fn = encoder_fn_name_to_fn(encoder_fn_name) 136 | encoder_fn_kwargs = json.load( 137 | open( 138 | resource_filename( 139 | 'contextual_lenses.resources', 140 | os.path.join('encoder_fn_kwargs_resources', 141 | encoder_fn_kwargs_path + '.json')))) 142 | 143 | reduce_fn = reduce_fn_name_to_fn(reduce_fn_name) 144 | reduce_fn_kwargs = json.load( 145 | open( 146 | resource_filename( 147 | 'contextual_lenses.resources', 148 | os.path.join('reduce_fn_kwargs_resources', 149 | reduce_fn_kwargs_path + '.json')))) 150 | 151 | layers, trainable_encoder = architecture_to_layers(encoder_fn_name, 152 | reduce_fn_name) 153 | 154 | return encoder_fn, encoder_fn_kwargs, reduce_fn, reduce_fn_kwargs, layers 155 | 156 | 157 | def create_model(encoder_fn, 158 | encoder_fn_kwargs, 159 | reduce_fn, 160 | reduce_fn_kwargs, 161 | layers, 162 | output='prediction', 163 | use_transformer=False, 164 | use_bert=False, 165 | restore_transformer_dir=None, 166 | encoder_fn_params=None, 167 | reduce_fn_params=None, 168 | predict_fn_params=None, 169 | random_key=0): 170 | """Creates representation model (encoder --> lens --> predictor) architecture.""" 171 | 172 | family_ids = get_family_ids() 173 | num_families = len(family_ids) 174 | 175 | if use_transformer: 176 | 177 | if use_bert: 178 | model_cls = models.FlaxBERT 179 | else: 180 | model_cls = models.FlaxLM 181 | 182 | if encoder_fn_params is not None: 183 | pretrained_transformer_params = encoder_fn_params 184 | else: 185 | if restore_transformer_dir is not None: 186 | pretrained_transformer_params = load_transformer_params( 187 | restore_transformer_dir, model_cls) 188 | else: 189 | pretrained_transformer_params = None 190 | 191 | model = create_transformer_representation_model( 192 | transformer_kwargs=encoder_fn_kwargs, 193 | reduce_fn=reduce_fn, 194 | reduce_fn_kwargs=reduce_fn_kwargs, 195 | num_categories=PFAM_NUM_CATEGORIES, 196 | output_features=num_families, 197 | bidirectional=use_bert, 198 | output=output, 199 | key=random.PRNGKey(random_key), 200 | encoder_fn_params=pretrained_transformer_params, 201 | reduce_fn_params=reduce_fn_params, 202 | predict_fn_params=predict_fn_params) 203 | 204 | else: 205 | model = create_representation_model( 206 | encoder_fn=encoder_fn, 207 | encoder_fn_kwargs=encoder_fn_kwargs, 208 | reduce_fn=reduce_fn, 209 | reduce_fn_kwargs=reduce_fn_kwargs, 210 | num_categories=PFAM_NUM_CATEGORIES, 211 | output_features=num_families, 212 | output=output, 213 | key=random.PRNGKey(random_key), 214 | encoder_fn_params=encoder_fn_params, 215 | reduce_fn_params=reduce_fn_params, 216 | predict_fn_params=predict_fn_params) 217 | 218 | return model 219 | 220 | 221 | def set_model_parameters(model, params): 222 | """Updates a model's parameters using a parameters dictionary.""" 223 | 224 | params = copy.deepcopy(params) 225 | 226 | assert ( 227 | model.params.keys() == params.keys()), 'Model parameters do not match!' 228 | 229 | for layer in model.params.keys(): 230 | model.params[layer] = params[layer] 231 | 232 | return model 233 | 234 | 235 | def measure_nearest_neighbor_performance(accuracy_label, encoder, 236 | family_accessions, batch_size, 237 | train_samples, shuffle_seed, 238 | sample_random_state): 239 | """Measures nearest neighbor classification performance and updates datum.""" 240 | 241 | results = pfam_nearest_neighbors_classification( 242 | encoder=encoder, 243 | family_accessions=family_accessions, 244 | batch_size=batch_size, 245 | train_samples=train_samples, 246 | shuffle_seed=shuffle_seed, 247 | sample_random_state=sample_random_state, 248 | data_partitions_dirpath=FLAGS.data_partitions_dirpath, 249 | gcs_bucket=FLAGS.load_gcs_bucket)[0] 250 | 251 | accuracy = results['1-nn accuracy'] 252 | 253 | accuracy_dict = {accuracy_label: accuracy} 254 | 255 | return accuracy_dict 256 | 257 | 258 | # Train lens and measure performance of lens and nearest neighbors classifier. 259 | def main(_): 260 | 261 | if FLAGS.use_transformer: 262 | assert ( 263 | FLAGS.encoder_fn_name == 'transformer' 264 | ), 'encoder_fn_name must be transformer if use_transformer is True!' 265 | 266 | assert (FLAGS.epochs % FLAGS.measurements == 0 267 | ), 'Number of measurements must divide number of epochs!' 268 | measurement_epochs = FLAGS.epochs // FLAGS.measurements 269 | 270 | assert FLAGS.results_save_dir != '', 'Specify results_save_dir!' 271 | 272 | assert FLAGS.label != '', 'Specify label!' 273 | 274 | if FLAGS.load_model: 275 | assert FLAGS.load_model_dir != '', 'Specify load_model_dir!' 276 | assert FLAGS.load_model_step > 0, 'Loaded model must have been trained for more than 0 steps.' 277 | 278 | if FLAGS.save_model: 279 | assert FLAGS.save_model_dir != '', 'Specify save_model_dir!' 280 | 281 | datum = { 282 | 'label': FLAGS.label, 283 | 'encoder_fn_name': FLAGS.encoder_fn_name, 284 | 'encoder_fn_kwargs_path': FLAGS.encoder_fn_kwargs_path, 285 | 'reduce_fn_name': FLAGS.reduce_fn_name, 286 | 'reduce_fn_kwargs_path': FLAGS.reduce_fn_kwargs_path, 287 | 'epochs': FLAGS.epochs, 288 | 'measurements': FLAGS.measurements, 289 | 'lens_batch_size': FLAGS.lens_batch_size, 290 | 'knn_batch_size': FLAGS.knn_batch_size, 291 | 'encoder_lr': FLAGS.encoder_lr, 292 | 'lens_lr': FLAGS.lens_lr, 293 | 'predictor_lr': FLAGS.predictor_lr, 294 | 'encoder_wd': FLAGS.encoder_wd, 295 | 'lens_wd': FLAGS.lens_wd, 296 | 'predictor_wd': FLAGS.predictor_wd, 297 | 'train_families': FLAGS.train_families, 298 | 'lens_train_samples': FLAGS.lens_train_samples, 299 | 'first_test_family': FLAGS.first_test_family, 300 | 'last_test_family': FLAGS.last_test_family, 301 | 'lens_shuffle_seed': FLAGS.lens_shuffle_seed, 302 | 'lens_sample_random_state': FLAGS.lens_sample_random_state, 303 | 'knn_shuffle_seed': FLAGS.knn_shuffle_seed, 304 | 'knn_sample_random_state': FLAGS.knn_sample_random_state, 305 | 'model_random_key': FLAGS.model_random_key, 306 | 'use_transformer': FLAGS.use_transformer, 307 | 'use_bert': FLAGS.use_bert, 308 | 'restore_transformer_dir': FLAGS.restore_transformer_dir, 309 | 'load_gcs_bucket': FLAGS.load_gcs_bucket, 310 | 'data_partitions_dirpath': FLAGS.data_partitions_dirpath, 311 | 'save_gcs_bucket': FLAGS.save_gcs_bucket, 312 | 'results_save_dir': FLAGS.results_save_dir, 313 | 'load_model': FLAGS.load_model, 314 | 'load_model_dir': FLAGS.load_model_dir, 315 | 'load_model_step': FLAGS.load_model_step, 316 | 'save_model': FLAGS.save_model, 317 | 'save_model_dir': FLAGS.save_model_dir 318 | } 319 | 320 | gcsfs = GCSFS(FLAGS.save_gcs_bucket) 321 | 322 | print(datum) 323 | df = pd.DataFrame([datum]) 324 | with gcsfs.open(os.path.join(FLAGS.results_save_dir, FLAGS.label + '.csv'), 325 | 'w') as gcs_file: 326 | df.to_csv(gcs_file, index=False) 327 | 328 | knn_train_samples_ = [1, 5, 10, 50] 329 | 330 | family_ids = get_family_ids() 331 | num_families = len(family_ids) 332 | loss_fn_kwargs = {'num_classes': num_families} 333 | 334 | lens_knn_train_family_accessions = [] 335 | for _ in range(1, FLAGS.train_families + 1): 336 | family_name = 'PF%05d' % _ 337 | lens_knn_train_family_accessions.append(family_name) 338 | 339 | knn_test_family_accessions = [] 340 | for _ in range(FLAGS.first_test_family, FLAGS.last_test_family + 1): 341 | family_name = 'PF%05d' % _ 342 | knn_test_family_accessions.append(family_name) 343 | 344 | encoder_fn, encoder_fn_kwargs, reduce_fn, reduce_fn_kwargs, layers = get_model_kwargs( 345 | encoder_fn_name=FLAGS.encoder_fn_name, 346 | encoder_fn_kwargs_path=FLAGS.encoder_fn_kwargs_path, 347 | reduce_fn_name=FLAGS.reduce_fn_name, 348 | reduce_fn_kwargs_path=FLAGS.reduce_fn_kwargs_path) 349 | 350 | embedding_model = create_model( 351 | encoder_fn=encoder_fn, 352 | encoder_fn_kwargs=encoder_fn_kwargs, 353 | reduce_fn=reduce_fn, 354 | reduce_fn_kwargs=reduce_fn_kwargs, 355 | layers=layers, 356 | output='embedding', 357 | use_transformer=FLAGS.use_transformer, 358 | use_bert=FLAGS.use_bert, 359 | restore_transformer_dir=FLAGS.restore_transformer_dir, 360 | random_key=FLAGS.model_random_key) 361 | 362 | datum.update( 363 | measure_nearest_neighbor_performance( 364 | accuracy_label= 365 | 'train_knn_accuracy_untrained_lens_1_knn_train_samples', 366 | encoder=embedding_model, 367 | family_accessions=lens_knn_train_family_accessions, 368 | batch_size=FLAGS.knn_batch_size, 369 | train_samples=1, 370 | shuffle_seed=FLAGS.knn_shuffle_seed, 371 | sample_random_state=FLAGS.knn_sample_random_state)) 372 | 373 | for knn_train_samples in knn_train_samples_: 374 | 375 | datum.update( 376 | measure_nearest_neighbor_performance( 377 | accuracy_label='test_knn_accuracy_untrained_lens_' + 378 | str(knn_train_samples) + '_knn_train_samples', 379 | encoder=embedding_model, 380 | family_accessions=knn_test_family_accessions, 381 | batch_size=FLAGS.knn_batch_size, 382 | train_samples=knn_train_samples, 383 | shuffle_seed=FLAGS.knn_shuffle_seed, 384 | sample_random_state=FLAGS.knn_sample_random_state)) 385 | 386 | model = create_model(encoder_fn=encoder_fn, 387 | encoder_fn_kwargs=encoder_fn_kwargs, 388 | reduce_fn=reduce_fn, 389 | reduce_fn_kwargs=reduce_fn_kwargs, 390 | layers=layers, 391 | output='prediction', 392 | use_transformer=FLAGS.use_transformer, 393 | use_bert=FLAGS.use_bert, 394 | restore_transformer_dir=FLAGS.restore_transformer_dir, 395 | random_key=FLAGS.model_random_key) 396 | 397 | optimizer = create_optimizer( 398 | model=model, 399 | learning_rate=[FLAGS.encoder_lr, FLAGS.lens_lr, FLAGS.predictor_lr], 400 | weight_decay=[FLAGS.encoder_wd, FLAGS.lens_wd, FLAGS.predictor_wd], 401 | layers=layers) 402 | 403 | if FLAGS.load_model: 404 | optimizer = checkpoints.restore_checkpoint(ckpt_dir=os.path.join( 405 | 'gs://' + FLAGS.load_gcs_bucket, FLAGS.load_model_dir), 406 | target=optimizer, 407 | step=FLAGS.load_model_step) 408 | 409 | trained_params = optimizer.target.params 410 | embedding_model = set_model_parameters(model=embedding_model, 411 | params=trained_params) 412 | 413 | if FLAGS.save_model: 414 | checkpoints.save_checkpoint(ckpt_dir=os.path.join( 415 | 'gs://' + FLAGS.save_gcs_bucket, FLAGS.save_model_dir), 416 | target=optimizer, 417 | step=FLAGS.load_model_step) 418 | 419 | for i in range(FLAGS.measurements): 420 | 421 | train_batches, train_indexes = create_pfam_batches( 422 | family_accessions=lens_knn_train_family_accessions, 423 | batch_size=FLAGS.lens_batch_size, 424 | samples=FLAGS.lens_train_samples, 425 | epochs=measurement_epochs, 426 | drop_remainder=True, 427 | shuffle_seed=FLAGS.lens_shuffle_seed + i, 428 | sample_random_state=FLAGS.lens_sample_random_state) 429 | 430 | optimizer = train( 431 | model=optimizer.target, 432 | train_data=train_batches, 433 | loss_fn=cross_entropy_loss, 434 | loss_fn_kwargs=loss_fn_kwargs, 435 | learning_rate=[ 436 | FLAGS.encoder_lr, FLAGS.lens_lr, FLAGS.predictor_lr 437 | ], 438 | weight_decay=[FLAGS.encoder_wd, FLAGS.lens_wd, FLAGS.predictor_wd], 439 | layers=layers) 440 | 441 | results, preds = pfam_evaluate( 442 | predict_fn=optimizer.target, 443 | test_family_accessions=lens_knn_train_family_accessions, 444 | title=None, 445 | loss_fn_kwargs=loss_fn_kwargs, 446 | batch_size=FLAGS.lens_batch_size, 447 | data_partitions_dirpath=FLAGS.data_partitions_dirpath, 448 | gcs_bucket=FLAGS.load_gcs_bucket) 449 | 450 | lens_accuracy = results['accuracy'] 451 | datum['lens_accuracy' + '_measurement_' + str(i)] = lens_accuracy 452 | 453 | lens_cross_entropy = float(results['cross_entropy']) 454 | datum['lens_cross_entropy' + '_measurement_' + 455 | str(i)] = lens_cross_entropy 456 | 457 | trained_params = optimizer.target.params 458 | embedding_model = set_model_parameters(model=embedding_model, 459 | params=trained_params) 460 | 461 | datum.update( 462 | measure_nearest_neighbor_performance( 463 | accuracy_label= 464 | 'train_knn_accuracy_trained_lens_1_knn_train_samples' + 465 | '_measurement_' + str(i), 466 | encoder=embedding_model, 467 | family_accessions=lens_knn_train_family_accessions, 468 | batch_size=FLAGS.knn_batch_size, 469 | train_samples=1, 470 | shuffle_seed=FLAGS.knn_shuffle_seed, 471 | sample_random_state=FLAGS.knn_sample_random_state)) 472 | 473 | for knn_train_samples in knn_train_samples_: 474 | 475 | datum.update( 476 | measure_nearest_neighbor_performance( 477 | accuracy_label='test_knn_accuracy_trained_lens_' + 478 | str(knn_train_samples) + '_knn_train_samples' + 479 | '_measurement_' + str(i), 480 | encoder=embedding_model, 481 | family_accessions=knn_test_family_accessions, 482 | batch_size=FLAGS.knn_batch_size, 483 | train_samples=knn_train_samples, 484 | shuffle_seed=FLAGS.knn_shuffle_seed, 485 | sample_random_state=FLAGS.knn_sample_random_state)) 486 | 487 | print(datum) 488 | df = pd.DataFrame([datum]) 489 | with gcsfs.open(os.path.join(FLAGS.results_save_dir, FLAGS.label + '.csv'), 490 | 'w') as gcs_file: 491 | df.to_csv(gcs_file, index=False) 492 | 493 | if FLAGS.save_model: 494 | checkpoints.save_checkpoint(ckpt_dir=os.path.join( 495 | 'gs://' + FLAGS.save_gcs_bucket, FLAGS.save_model_dir), 496 | target=optimizer, 497 | step=FLAGS.load_model_step + FLAGS.epochs) 498 | 499 | 500 | if __name__ == '__main__': 501 | app.run(main) 502 | -------------------------------------------------------------------------------- /generate_params.py: -------------------------------------------------------------------------------- 1 | """Generate JSON files with hyperparameter combos for caliban.""" 2 | 3 | import json 4 | 5 | import itertools 6 | 7 | import os 8 | 9 | import numpy as np 10 | 11 | from frozendict import frozendict 12 | 13 | 14 | def create_params(encoder_lrs, 15 | lens_lrs, 16 | predictor_lrs, 17 | encoder_wds, 18 | lens_wds, 19 | predictor_wds, 20 | reduce_fn_kwargs_paths, 21 | lens_train_samples, 22 | train_families, 23 | epochs, 24 | measurements, 25 | encoder_fn_name, 26 | encoder_fn_kwargs_path, 27 | reduce_fn_name, 28 | lens_batch_size=64, 29 | knn_batch_size=64, 30 | use_transformer=False, 31 | use_bert=False, 32 | restore_transformer_dir=None, 33 | model_random_keys=[0], 34 | first_test_family=15000, 35 | last_test_family=16000, 36 | load_gcs_bucket='neuralblast_public', 37 | data_partitions_dirpath='random_split/', 38 | save_gcs_bucket='sequin-public', 39 | results_save_dir='pfam_experiment_data', 40 | lens_shuffle_seed=0, 41 | lens_sample_random_state=0, 42 | knn_shuffle_seed=1, 43 | knn_sample_random_state=1, 44 | load_model=False, 45 | load_model_dir='', 46 | load_model_step=0, 47 | save_model=False, 48 | save_model_dir=''): 49 | """Generates parameters from lists of parameters.""" 50 | 51 | params = [] 52 | 53 | for encoder_lr, lens_lr, predictor_lr, encoder_wd, lens_wd, predictor_wd, reduce_fn_kwargs_path, lens_train_samples, model_random_key in \ 54 | itertools.product(encoder_lrs, lens_lrs, predictor_lrs, encoder_wds, lens_wds, predictor_wds, reduce_fn_kwargs_paths, lens_train_samples, model_random_keys): 55 | 56 | param_dict = { 57 | 'encoder_fn_name': encoder_fn_name, 58 | 'encoder_fn_kwargs_path': encoder_fn_kwargs_path, 59 | 'reduce_fn_name': reduce_fn_name, 60 | 'reduce_fn_kwargs_path': reduce_fn_kwargs_path, 61 | 'epochs': epochs, 62 | 'measurements': measurements, 63 | 'lens_batch_size': lens_batch_size, 64 | 'knn_batch_size': knn_batch_size, 65 | 'encoder_lr': encoder_lr, 66 | 'lens_lr': lens_lr, 67 | 'predictor_lr': predictor_lr, 68 | 'encoder_wd': encoder_wd, 69 | 'lens_wd': lens_wd, 70 | 'predictor_wd': predictor_wd, 71 | 'train_families': train_families, 72 | 'lens_train_samples': lens_train_samples, 73 | 'first_test_family': first_test_family, 74 | 'last_test_family': last_test_family, 75 | 'lens_shuffle_seed': lens_shuffle_seed, 76 | 'lens_sample_random_state': lens_sample_random_state, 77 | 'knn_shuffle_seed': knn_shuffle_seed, 78 | 'knn_sample_random_state': knn_sample_random_state, 79 | 'model_random_key': model_random_key, 80 | 'use_transformer': use_transformer, 81 | 'use_bert': use_bert, 82 | 'load_gcs_bucket': load_gcs_bucket, 83 | 'data_partitions_dirpath': data_partitions_dirpath, 84 | 'save_gcs_bucket': save_gcs_bucket, 85 | 'results_save_dir': results_save_dir, 86 | 'load_model': load_model, 87 | 'load_model_dir': load_model_dir, 88 | 'load_model_step': load_model_step, 89 | 'save_model': save_model, 90 | 'save_model_dir': save_model_dir 91 | } 92 | 93 | if restore_transformer_dir is not None: 94 | param_dict['restore_transformer_dir'] = restore_transformer_dir 95 | 96 | params.append(param_dict) 97 | 98 | return params 99 | 100 | 101 | # Generate parameters from different sets of parameter combinations. 102 | def main(load_gcs_bucket, save_gcs_bucket): 103 | 104 | params = [] 105 | 106 | # 1000 train families sweep 107 | params += create_params( 108 | encoder_lrs=[0.0], 109 | lens_lrs=[1e-3, 5e-4, 1e-4, 5e-5, 1e-5], 110 | predictor_lrs=[1e-3, 5e-4, 1e-4, 5e-5, 1e-5], 111 | encoder_wds=[0.0], 112 | lens_wds=[0.0, 0.1], 113 | predictor_wds=[0.0, 0.1], 114 | reduce_fn_kwargs_paths=['linear_pool_256', 'linear_pool_1024'], 115 | lens_train_samples=[50], 116 | train_families=1000, 117 | epochs=50, 118 | measurements=1, 119 | encoder_fn_name='transformer', 120 | encoder_fn_kwargs_path='medium_transformer_kwargs', 121 | reduce_fn_name='linear_max_pool', 122 | lens_batch_size=64, 123 | knn_batch_size=64, 124 | use_transformer=True, 125 | use_bert=True, 126 | restore_transformer_dir=None, 127 | load_gcs_bucket=load_gcs_bucket, 128 | save_gcs_bucket=save_gcs_bucket) 129 | 130 | params += create_params( 131 | encoder_lrs=[0.0], 132 | lens_lrs=[1e-3, 5e-4, 1e-4, 5e-5, 1e-5], 133 | predictor_lrs=[1e-3, 5e-4, 1e-4, 5e-5, 1e-5], 134 | encoder_wds=[0.0], 135 | lens_wds=[0.0, 0.1], 136 | predictor_wds=[0.0, 0.1], 137 | reduce_fn_kwargs_paths=['linear_pool_256', 'linear_pool_1024'], 138 | lens_train_samples=[50], 139 | train_families=1000, 140 | epochs=50, 141 | measurements=1, 142 | encoder_fn_name='transformer', 143 | encoder_fn_kwargs_path='medium_transformer_kwargs', 144 | reduce_fn_name='linear_max_pool', 145 | lens_batch_size=64, 146 | knn_batch_size=64, 147 | use_transformer=True, 148 | use_bert=True, 149 | restore_transformer_dir= 150 | 'gs://sequin-public/transformer_models/medium_trembl_bert/', 151 | load_gcs_bucket=load_gcs_bucket, 152 | save_gcs_bucket=save_gcs_bucket) 153 | 154 | params += create_params( 155 | encoder_lrs=[1e-3, 5e-4, 1e-4, 5e-5, 1e-5], 156 | lens_lrs=[1e-3, 5e-4, 1e-4, 5e-5, 1e-5], 157 | predictor_lrs=[1e-3, 5e-4, 1e-4, 5e-5, 1e-5], 158 | encoder_wds=[0.0, 0.1], 159 | lens_wds=[0.0, 0.1], 160 | predictor_wds=[0.0, 0.1], 161 | reduce_fn_kwargs_paths=['linear_pool_256', 'linear_pool_1024'], 162 | lens_train_samples=[50], 163 | train_families=1000, 164 | epochs=50, 165 | measurements=1, 166 | encoder_fn_name='cnn_one_hot', 167 | encoder_fn_kwargs_path='1-layer_cnn_kwargs', 168 | reduce_fn_name='linear_max_pool', 169 | lens_batch_size=256, 170 | knn_batch_size=256, 171 | use_transformer=False, 172 | use_bert=False, 173 | restore_transformer_dir=None, 174 | load_gcs_bucket=load_gcs_bucket, 175 | save_gcs_bucket=save_gcs_bucket) 176 | 177 | # 10000 train families sweep 178 | params += create_params(encoder_lrs=[0.0], 179 | lens_lrs=[1e-3, 5e-4, 1e-4, 5e-5], 180 | predictor_lrs=[1e-3, 5e-4, 1e-4, 5e-5], 181 | encoder_wds=[0.0], 182 | lens_wds=[0.0, 0.05, 0.1, 0.2], 183 | predictor_wds=[0.0, 0.05, 0.1, 0.2], 184 | reduce_fn_kwargs_paths=['linear_pool_1024'], 185 | lens_train_samples=[50], 186 | train_families=10000, 187 | epochs=10, 188 | measurements=2, 189 | encoder_fn_name='transformer', 190 | encoder_fn_kwargs_path='medium_transformer_kwargs', 191 | reduce_fn_name='linear_max_pool', 192 | lens_batch_size=64, 193 | knn_batch_size=64, 194 | use_transformer=True, 195 | use_bert=True, 196 | restore_transformer_dir=None, 197 | load_gcs_bucket=load_gcs_bucket, 198 | save_gcs_bucket=save_gcs_bucket) 199 | 200 | params += create_params( 201 | encoder_lrs=[0.0], 202 | lens_lrs=[1e-3, 5e-4, 1e-4, 5e-5], 203 | predictor_lrs=[1e-3, 5e-4, 1e-4, 5e-5], 204 | encoder_wds=[0.0], 205 | lens_wds=[0.0, 0.05, 0.1, 0.2], 206 | predictor_wds=[0.0, 0.05, 0.1, 0.2], 207 | reduce_fn_kwargs_paths=['linear_pool_1024'], 208 | lens_train_samples=[50], 209 | train_families=10000, 210 | epochs=10, 211 | measurements=2, 212 | encoder_fn_name='transformer', 213 | encoder_fn_kwargs_path='medium_transformer_kwargs', 214 | reduce_fn_name='linear_max_pool', 215 | lens_batch_size=64, 216 | knn_batch_size=64, 217 | use_transformer=True, 218 | use_bert=True, 219 | restore_transformer_dir= 220 | 'gs://sequin-public/transformer_models/medium_trembl_bert/', 221 | load_gcs_bucket=load_gcs_bucket, 222 | save_gcs_bucket=save_gcs_bucket) 223 | 224 | params += create_params(encoder_lrs=[1e-3, 1e-4, 1e-5], 225 | lens_lrs=[1e-3, 1e-4, 1e-5], 226 | predictor_lrs=[1e-3, 1e-4, 1e-5], 227 | encoder_wds=[0.0, 0.1, 0.2], 228 | lens_wds=[0.0, 0.1, 0.2], 229 | predictor_wds=[0.0, 0.1, 0.2], 230 | reduce_fn_kwargs_paths=['linear_pool_1024'], 231 | lens_train_samples=[50], 232 | train_families=10000, 233 | epochs=10, 234 | measurements=2, 235 | encoder_fn_name='cnn_one_hot', 236 | encoder_fn_kwargs_path='2-layer_cnn_kwargs', 237 | reduce_fn_name='linear_max_pool', 238 | lens_batch_size=512, 239 | knn_batch_size=512, 240 | use_transformer=False, 241 | use_bert=False, 242 | restore_transformer_dir=None, 243 | load_gcs_bucket=load_gcs_bucket, 244 | save_gcs_bucket=save_gcs_bucket) 245 | 246 | params += create_params(encoder_lrs=[0.0], 247 | lens_lrs=[1e-4, 5e-5, 1e-5], 248 | predictor_lrs=[1e-3, 5e-4, 1e-4], 249 | encoder_wds=[0.0], 250 | lens_wds=[0.05, 0.1, 0.2], 251 | predictor_wds=[0.0, 0.05, 0.1], 252 | reduce_fn_kwargs_paths=['linear_pool_1024'], 253 | lens_train_samples=[50], 254 | train_families=10000, 255 | epochs=10, 256 | measurements=2, 257 | encoder_fn_name='transformer', 258 | encoder_fn_kwargs_path='medium_transformer_kwargs', 259 | reduce_fn_name='linear_max_pool', 260 | lens_batch_size=64, 261 | knn_batch_size=64, 262 | use_transformer=True, 263 | use_bert=True, 264 | restore_transformer_dir=None, 265 | load_gcs_bucket=load_gcs_bucket, 266 | save_gcs_bucket=save_gcs_bucket) 267 | 268 | params += create_params( 269 | encoder_lrs=[0.0], 270 | lens_lrs=[1e-4, 5e-5, 1e-5], 271 | predictor_lrs=[1e-3, 5e-4, 1e-4], 272 | encoder_wds=[0.0], 273 | lens_wds=[0.05, 0.1, 0.2], 274 | predictor_wds=[0.15, 0.2, 0.25], 275 | reduce_fn_kwargs_paths=['linear_pool_1024'], 276 | lens_train_samples=[50], 277 | train_families=10000, 278 | epochs=10, 279 | measurements=2, 280 | encoder_fn_name='transformer', 281 | encoder_fn_kwargs_path='medium_transformer_kwargs', 282 | reduce_fn_name='linear_max_pool', 283 | lens_batch_size=64, 284 | knn_batch_size=64, 285 | use_transformer=True, 286 | use_bert=True, 287 | restore_transformer_dir= 288 | 'gs://sequin-public/transformer_models/medium_trembl_bert/', 289 | load_gcs_bucket=load_gcs_bucket, 290 | save_gcs_bucket=save_gcs_bucket) 291 | 292 | params += create_params(encoder_lrs=[1e-3, 5e-4, 1e-4], 293 | lens_lrs=[1e-4, 5e-5, 1e-5], 294 | predictor_lrs=[1e-3, 5e-4, 1e-4], 295 | encoder_wds=[0.05, 0.1, 0.2], 296 | lens_wds=[0.05, 0.1, 0.2], 297 | predictor_wds=[0.0, 0.05, 0.1], 298 | reduce_fn_kwargs_paths=['linear_pool_1024'], 299 | lens_train_samples=[50], 300 | train_families=10000, 301 | epochs=10, 302 | measurements=2, 303 | encoder_fn_name='cnn_one_hot', 304 | encoder_fn_kwargs_path='2-layer_cnn_kwargs', 305 | reduce_fn_name='linear_max_pool', 306 | lens_batch_size=512, 307 | knn_batch_size=512, 308 | use_transformer=False, 309 | use_bert=False, 310 | restore_transformer_dir=None, 311 | load_gcs_bucket=load_gcs_bucket, 312 | save_gcs_bucket=save_gcs_bucket) 313 | 314 | params += create_params(encoder_lrs=[1e-3, 5e-4, 1e-4, 5e-5], 315 | lens_lrs=[5e-5, 1e-5, 5e-6], 316 | predictor_lrs=[1e-4, 5e-5, 1e-5, 5e-6], 317 | encoder_wds=[0.2, 0.3], 318 | lens_wds=[0.05, 0.1], 319 | predictor_wds=[0.0, 0.05, 0.1], 320 | reduce_fn_kwargs_paths=['linear_pool_1024'], 321 | lens_train_samples=[50], 322 | train_families=10000, 323 | epochs=10, 324 | measurements=2, 325 | encoder_fn_name='cnn_one_hot', 326 | encoder_fn_kwargs_path='2-layer_cnn_kwargs', 327 | reduce_fn_name='linear_max_pool', 328 | lens_batch_size=512, 329 | knn_batch_size=512, 330 | use_transformer=False, 331 | use_bert=False, 332 | restore_transformer_dir=None, 333 | load_gcs_bucket=load_gcs_bucket, 334 | save_gcs_bucket=save_gcs_bucket) 335 | 336 | # 10000 train families medium transformer random keys sweep 337 | params += create_params(encoder_lrs=[0.0], 338 | lens_lrs=[1e-5], 339 | predictor_lrs=[1e-3], 340 | encoder_wds=[0.0], 341 | lens_wds=[0.05], 342 | predictor_wds=[0.0], 343 | reduce_fn_kwargs_paths=['linear_pool_1024'], 344 | lens_train_samples=[50], 345 | train_families=10000, 346 | epochs=10, 347 | measurements=2, 348 | encoder_fn_name='transformer', 349 | encoder_fn_kwargs_path='medium_transformer_kwargs', 350 | reduce_fn_name='linear_max_pool', 351 | lens_batch_size=64, 352 | knn_batch_size=64, 353 | use_transformer=True, 354 | use_bert=True, 355 | load_gcs_bucket=load_gcs_bucket, 356 | save_gcs_bucket=save_gcs_bucket, 357 | model_random_keys=range(10)) 358 | 359 | params += create_params( 360 | encoder_lrs=[0.0], 361 | lens_lrs=[1e-5], 362 | predictor_lrs=[1e-3], 363 | encoder_wds=[0.0], 364 | lens_wds=[0.05], 365 | predictor_wds=[0.0], 366 | reduce_fn_kwargs_paths=['linear_pool_1024'], 367 | lens_train_samples=[50], 368 | train_families=10000, 369 | epochs=10, 370 | measurements=2, 371 | encoder_fn_name='transformer', 372 | encoder_fn_kwargs_path='medium_transformer_kwargs', 373 | reduce_fn_name='linear_max_pool', 374 | lens_batch_size=64, 375 | knn_batch_size=64, 376 | use_transformer=True, 377 | use_bert=True, 378 | restore_transformer_dir= 379 | 'gs://sequin-public/transformer_models/medium_trembl_bert/', 380 | load_gcs_bucket=load_gcs_bucket, 381 | save_gcs_bucket=save_gcs_bucket, 382 | model_random_keys=range(10)) 383 | 384 | params += create_params(encoder_lrs=[0.0], 385 | lens_lrs=[5e-5], 386 | predictor_lrs=[5e-4], 387 | encoder_wds=[0.0], 388 | lens_wds=[0.2], 389 | predictor_wds=[0.2], 390 | reduce_fn_kwargs_paths=['linear_pool_1024'], 391 | lens_train_samples=[50], 392 | train_families=10000, 393 | epochs=10, 394 | measurements=2, 395 | encoder_fn_name='transformer', 396 | encoder_fn_kwargs_path='medium_transformer_kwargs', 397 | reduce_fn_name='linear_max_pool', 398 | lens_batch_size=64, 399 | knn_batch_size=64, 400 | use_transformer=True, 401 | use_bert=True, 402 | load_gcs_bucket=load_gcs_bucket, 403 | save_gcs_bucket=save_gcs_bucket, 404 | model_random_keys=range(10)) 405 | 406 | params += create_params( 407 | encoder_lrs=[0.0], 408 | lens_lrs=[5e-5], 409 | predictor_lrs=[5e-4], 410 | encoder_wds=[0.0], 411 | lens_wds=[0.2], 412 | predictor_wds=[0.2], 413 | reduce_fn_kwargs_paths=['linear_pool_1024'], 414 | lens_train_samples=[50], 415 | train_families=10000, 416 | epochs=10, 417 | measurements=2, 418 | encoder_fn_name='transformer', 419 | encoder_fn_kwargs_path='medium_transformer_kwargs', 420 | reduce_fn_name='linear_max_pool', 421 | lens_batch_size=64, 422 | knn_batch_size=64, 423 | use_transformer=True, 424 | use_bert=True, 425 | restore_transformer_dir= 426 | 'gs://sequin-public/transformer_models/medium_trembl_bert/', 427 | load_gcs_bucket=load_gcs_bucket, 428 | save_gcs_bucket=save_gcs_bucket, 429 | model_random_keys=range(10)) 430 | 431 | # 10000 train families small transformer random keys sweep 432 | params += create_params(encoder_lrs=[0.0], 433 | lens_lrs=[1e-3, 1e-4, 1e-5, 1e-6], 434 | predictor_lrs=[1e-3, 1e-4, 1e-5, 1e-6], 435 | encoder_wds=[0.0], 436 | lens_wds=[0.0, 0.1, 0.2], 437 | predictor_wds=[0.0, 0.1, 0.2], 438 | reduce_fn_kwargs_paths=['linear_pool_1024'], 439 | lens_train_samples=[50], 440 | train_families=10000, 441 | epochs=10, 442 | measurements=1, 443 | encoder_fn_name='transformer', 444 | encoder_fn_kwargs_path='small_transformer_kwargs', 445 | reduce_fn_name='linear_max_pool', 446 | lens_batch_size=64, 447 | knn_batch_size=64, 448 | use_transformer=True, 449 | use_bert=True, 450 | load_gcs_bucket=load_gcs_bucket, 451 | save_gcs_bucket=save_gcs_bucket) 452 | 453 | params += create_params( 454 | encoder_lrs=[0.0], 455 | lens_lrs=[1e-3, 1e-4, 1e-5, 1e-6], 456 | predictor_lrs=[1e-3, 1e-4, 1e-5, 1e-6], 457 | encoder_wds=[0.0], 458 | lens_wds=[0.0, 0.1, 0.2], 459 | predictor_wds=[0.0, 0.1, 0.2], 460 | reduce_fn_kwargs_paths=['linear_pool_1024'], 461 | lens_train_samples=[50], 462 | train_families=10000, 463 | epochs=10, 464 | measurements=1, 465 | encoder_fn_name='transformer', 466 | encoder_fn_kwargs_path='small_transformer_kwargs', 467 | reduce_fn_name='linear_max_pool', 468 | lens_batch_size=64, 469 | knn_batch_size=64, 470 | use_transformer=True, 471 | use_bert=True, 472 | restore_transformer_dir= 473 | 'gs://sequin-public/transformer_models/small_trembl_bert/', 474 | load_gcs_bucket=load_gcs_bucket, 475 | save_gcs_bucket=save_gcs_bucket) 476 | 477 | params += create_params(encoder_lrs=[0.0], 478 | lens_lrs=[1e-4], 479 | predictor_lrs=[1e-3], 480 | encoder_wds=[0.0], 481 | lens_wds=[0.1], 482 | predictor_wds=[0.0], 483 | reduce_fn_kwargs_paths=['linear_pool_1024'], 484 | lens_train_samples=[50], 485 | train_families=10000, 486 | epochs=10, 487 | measurements=1, 488 | encoder_fn_name='transformer', 489 | encoder_fn_kwargs_path='small_transformer_kwargs', 490 | reduce_fn_name='linear_max_pool', 491 | lens_batch_size=64, 492 | knn_batch_size=64, 493 | use_transformer=True, 494 | use_bert=True, 495 | load_gcs_bucket=load_gcs_bucket, 496 | save_gcs_bucket=save_gcs_bucket, 497 | model_random_keys=range(10)) 498 | 499 | params += create_params( 500 | encoder_lrs=[0.0], 501 | lens_lrs=[1e-4], 502 | predictor_lrs=[1e-3], 503 | encoder_wds=[0.0], 504 | lens_wds=[0.1], 505 | predictor_wds=[0.0], 506 | reduce_fn_kwargs_paths=['linear_pool_1024'], 507 | lens_train_samples=[50], 508 | train_families=10000, 509 | epochs=10, 510 | measurements=1, 511 | encoder_fn_name='transformer', 512 | encoder_fn_kwargs_path='small_transformer_kwargs', 513 | reduce_fn_name='linear_max_pool', 514 | lens_batch_size=64, 515 | knn_batch_size=64, 516 | use_transformer=True, 517 | use_bert=True, 518 | load_gcs_bucket=load_gcs_bucket, 519 | save_gcs_bucket=save_gcs_bucket, 520 | restore_transformer_dir= 521 | 'gs://sequin-public/transformer_models/small_trembl_bert/', 522 | model_random_keys=range(10)) 523 | 524 | params += create_params(encoder_lrs=[0.0], 525 | lens_lrs=[1e-3], 526 | predictor_lrs=[1e-3], 527 | encoder_wds=[0.0], 528 | lens_wds=[0.1], 529 | predictor_wds=[0.0], 530 | reduce_fn_kwargs_paths=['linear_pool_1024'], 531 | lens_train_samples=[50], 532 | train_families=10000, 533 | epochs=10, 534 | measurements=1, 535 | encoder_fn_name='transformer', 536 | encoder_fn_kwargs_path='small_transformer_kwargs', 537 | reduce_fn_name='linear_max_pool', 538 | lens_batch_size=64, 539 | knn_batch_size=64, 540 | use_transformer=True, 541 | use_bert=True, 542 | load_gcs_bucket=load_gcs_bucket, 543 | save_gcs_bucket=save_gcs_bucket, 544 | model_random_keys=range(10)) 545 | 546 | params += create_params( 547 | encoder_lrs=[0.0], 548 | lens_lrs=[1e-3], 549 | predictor_lrs=[1e-3], 550 | encoder_wds=[0.0], 551 | lens_wds=[0.1], 552 | predictor_wds=[0.0], 553 | reduce_fn_kwargs_paths=['linear_pool_1024'], 554 | lens_train_samples=[50], 555 | train_families=10000, 556 | epochs=10, 557 | measurements=1, 558 | encoder_fn_name='transformer', 559 | encoder_fn_kwargs_path='small_transformer_kwargs', 560 | reduce_fn_name='linear_max_pool', 561 | lens_batch_size=64, 562 | knn_batch_size=64, 563 | use_transformer=True, 564 | use_bert=True, 565 | load_gcs_bucket=load_gcs_bucket, 566 | save_gcs_bucket=save_gcs_bucket, 567 | restore_transformer_dir= 568 | 'gs://sequin-public/transformer_models/small_trembl_bert/', 569 | model_random_keys=range(10)) 570 | 571 | params += create_params(encoder_lrs=[0.0], 572 | lens_lrs=[1e-4], 573 | predictor_lrs=[1e-3], 574 | encoder_wds=[0.0], 575 | lens_wds=[0.0], 576 | predictor_wds=[0.0], 577 | reduce_fn_kwargs_paths=['linear_pool_1024'], 578 | lens_train_samples=[50], 579 | train_families=10000, 580 | epochs=10, 581 | measurements=1, 582 | encoder_fn_name='transformer', 583 | encoder_fn_kwargs_path='small_transformer_kwargs', 584 | reduce_fn_name='linear_max_pool', 585 | lens_batch_size=64, 586 | knn_batch_size=64, 587 | use_transformer=True, 588 | use_bert=True, 589 | load_gcs_bucket=load_gcs_bucket, 590 | save_gcs_bucket=save_gcs_bucket, 591 | model_random_keys=range(10)) 592 | 593 | params += create_params( 594 | encoder_lrs=[0.0], 595 | lens_lrs=[1e-4], 596 | predictor_lrs=[1e-3], 597 | encoder_wds=[0.0], 598 | lens_wds=[0.0], 599 | predictor_wds=[0.0], 600 | reduce_fn_kwargs_paths=['linear_pool_1024'], 601 | lens_train_samples=[50], 602 | train_families=10000, 603 | epochs=10, 604 | measurements=1, 605 | encoder_fn_name='transformer', 606 | encoder_fn_kwargs_path='small_transformer_kwargs', 607 | reduce_fn_name='linear_max_pool', 608 | lens_batch_size=64, 609 | knn_batch_size=64, 610 | use_transformer=True, 611 | use_bert=True, 612 | load_gcs_bucket=load_gcs_bucket, 613 | save_gcs_bucket=save_gcs_bucket, 614 | restore_transformer_dir= 615 | 'gs://sequin-public/transformer_models/small_trembl_bert/', 616 | model_random_keys=range(10)) 617 | 618 | params += create_params(encoder_lrs=[0.0], 619 | lens_lrs=[1e-4], 620 | predictor_lrs=[1e-3], 621 | encoder_wds=[0.0], 622 | lens_wds=[0.2], 623 | predictor_wds=[0.2], 624 | reduce_fn_kwargs_paths=['linear_pool_1024'], 625 | lens_train_samples=[50], 626 | train_families=10000, 627 | epochs=10, 628 | measurements=1, 629 | encoder_fn_name='transformer', 630 | encoder_fn_kwargs_path='small_transformer_kwargs', 631 | reduce_fn_name='linear_max_pool', 632 | lens_batch_size=64, 633 | knn_batch_size=64, 634 | use_transformer=True, 635 | use_bert=True, 636 | load_gcs_bucket=load_gcs_bucket, 637 | save_gcs_bucket=save_gcs_bucket, 638 | model_random_keys=range(10)) 639 | 640 | params += create_params( 641 | encoder_lrs=[0.0], 642 | lens_lrs=[1e-4], 643 | predictor_lrs=[1e-3], 644 | encoder_wds=[0.0], 645 | lens_wds=[0.2], 646 | predictor_wds=[0.2], 647 | reduce_fn_kwargs_paths=['linear_pool_1024'], 648 | lens_train_samples=[50], 649 | train_families=10000, 650 | epochs=10, 651 | measurements=1, 652 | encoder_fn_name='transformer', 653 | encoder_fn_kwargs_path='small_transformer_kwargs', 654 | reduce_fn_name='linear_max_pool', 655 | lens_batch_size=64, 656 | knn_batch_size=64, 657 | use_transformer=True, 658 | use_bert=True, 659 | load_gcs_bucket=load_gcs_bucket, 660 | save_gcs_bucket=save_gcs_bucket, 661 | restore_transformer_dir= 662 | 'gs://sequin-public/transformer_models/small_trembl_bert/', 663 | model_random_keys=range(10)) 664 | 665 | # 10000 train families large transformer random keys sweep 666 | for i in range(5): 667 | params += create_params( 668 | encoder_lrs=[0.0], 669 | lens_lrs=[1e-5], 670 | predictor_lrs=[1e-3], 671 | encoder_wds=[0.0], 672 | lens_wds=[0.05], 673 | predictor_wds=[0.0], 674 | reduce_fn_kwargs_paths=['linear_pool_1024'], 675 | lens_train_samples=[50], 676 | train_families=10000, 677 | epochs=10, 678 | measurements=1, 679 | encoder_fn_name='transformer', 680 | encoder_fn_kwargs_path='large_transformer_kwargs', 681 | reduce_fn_name='linear_max_pool', 682 | lens_batch_size=8, 683 | knn_batch_size=64, 684 | use_transformer=True, 685 | use_bert=True, 686 | load_gcs_bucket=load_gcs_bucket, 687 | save_gcs_bucket=save_gcs_bucket, 688 | model_random_keys=[i], 689 | save_model=True, 690 | save_model_dir=os.path.join('pfam_experiment_optimizers', 691 | 'large_0_' + str(i))) 692 | 693 | params += create_params( 694 | encoder_lrs=[0.0], 695 | lens_lrs=[1e-5], 696 | predictor_lrs=[1e-3], 697 | encoder_wds=[0.0], 698 | lens_wds=[0.05], 699 | predictor_wds=[0.0], 700 | reduce_fn_kwargs_paths=['linear_pool_1024'], 701 | lens_train_samples=[50], 702 | train_families=10000, 703 | epochs=10, 704 | measurements=1, 705 | encoder_fn_name='transformer', 706 | encoder_fn_kwargs_path='large_transformer_kwargs', 707 | reduce_fn_name='linear_max_pool', 708 | lens_batch_size=8, 709 | knn_batch_size=64, 710 | use_transformer=True, 711 | use_bert=True, 712 | restore_transformer_dir= 713 | 'gs://sequin-public/transformer_models/large_bert/', 714 | load_gcs_bucket=load_gcs_bucket, 715 | save_gcs_bucket=save_gcs_bucket, 716 | model_random_keys=[i], 717 | save_model=True, 718 | save_model_dir=os.path.join('pfam_experiment_optimizers', 719 | 'large_1_' + str(i))) 720 | 721 | params += create_params( 722 | encoder_lrs=[0.0], 723 | lens_lrs=[5e-5], 724 | predictor_lrs=[5e-4], 725 | encoder_wds=[0.0], 726 | lens_wds=[0.2], 727 | predictor_wds=[0.2], 728 | reduce_fn_kwargs_paths=['linear_pool_1024'], 729 | lens_train_samples=[50], 730 | train_families=10000, 731 | epochs=10, 732 | measurements=1, 733 | encoder_fn_name='transformer', 734 | encoder_fn_kwargs_path='large_transformer_kwargs', 735 | reduce_fn_name='linear_max_pool', 736 | lens_batch_size=8, 737 | knn_batch_size=64, 738 | use_transformer=True, 739 | use_bert=True, 740 | load_gcs_bucket=load_gcs_bucket, 741 | save_gcs_bucket=save_gcs_bucket, 742 | model_random_keys=[i], 743 | save_model=True, 744 | save_model_dir=os.path.join('pfam_experiment_optimizers', 745 | 'large_2_' + str(i))) 746 | 747 | params += create_params( 748 | encoder_lrs=[0.0], 749 | lens_lrs=[5e-5], 750 | predictor_lrs=[5e-4], 751 | encoder_wds=[0.0], 752 | lens_wds=[0.2], 753 | predictor_wds=[0.2], 754 | reduce_fn_kwargs_paths=['linear_pool_1024'], 755 | lens_train_samples=[50], 756 | train_families=10000, 757 | epochs=10, 758 | measurements=1, 759 | encoder_fn_name='transformer', 760 | encoder_fn_kwargs_path='large_transformer_kwargs', 761 | reduce_fn_name='linear_max_pool', 762 | lens_batch_size=8, 763 | knn_batch_size=64, 764 | use_transformer=True, 765 | use_bert=True, 766 | restore_transformer_dir= 767 | 'gs://sequin-public/transformer_models/large_bert/', 768 | load_gcs_bucket=load_gcs_bucket, 769 | save_gcs_bucket=save_gcs_bucket, 770 | model_random_keys=[i], 771 | save_model=True, 772 | save_model_dir=os.path.join('pfam_experiment_optimizers', 773 | 'large_3_' + str(i))) 774 | 775 | params += create_params( 776 | encoder_lrs=[0.0], 777 | lens_lrs=[1e-4], 778 | predictor_lrs=[1e-3], 779 | encoder_wds=[0.0], 780 | lens_wds=[0.1], 781 | predictor_wds=[0.0], 782 | reduce_fn_kwargs_paths=['linear_pool_1024'], 783 | lens_train_samples=[50], 784 | train_families=10000, 785 | epochs=10, 786 | measurements=1, 787 | encoder_fn_name='transformer', 788 | encoder_fn_kwargs_path='large_transformer_kwargs', 789 | reduce_fn_name='linear_max_pool', 790 | lens_batch_size=8, 791 | knn_batch_size=64, 792 | use_transformer=True, 793 | use_bert=True, 794 | load_gcs_bucket=load_gcs_bucket, 795 | save_gcs_bucket=save_gcs_bucket, 796 | model_random_keys=[i], 797 | save_model=True, 798 | save_model_dir=os.path.join('pfam_experiment_optimizers', 799 | 'large_4_' + str(i))) 800 | 801 | params += create_params( 802 | encoder_lrs=[0.0], 803 | lens_lrs=[1e-4], 804 | predictor_lrs=[1e-3], 805 | encoder_wds=[0.0], 806 | lens_wds=[0.1], 807 | predictor_wds=[0.0], 808 | reduce_fn_kwargs_paths=['linear_pool_1024'], 809 | lens_train_samples=[50], 810 | train_families=10000, 811 | epochs=10, 812 | measurements=1, 813 | encoder_fn_name='transformer', 814 | encoder_fn_kwargs_path='large_transformer_kwargs', 815 | reduce_fn_name='linear_max_pool', 816 | lens_batch_size=8, 817 | knn_batch_size=64, 818 | use_transformer=True, 819 | use_bert=True, 820 | load_gcs_bucket=load_gcs_bucket, 821 | save_gcs_bucket=save_gcs_bucket, 822 | restore_transformer_dir= 823 | 'gs://sequin-public/transformer_models/large_bert/', 824 | model_random_keys=[i], 825 | save_model=True, 826 | save_model_dir=os.path.join('pfam_experiment_optimizers', 827 | 'large_5_' + str(i))) 828 | 829 | params += create_params( 830 | encoder_lrs=[0.0], 831 | lens_lrs=[1e-3], 832 | predictor_lrs=[1e-3], 833 | encoder_wds=[0.0], 834 | lens_wds=[0.1], 835 | predictor_wds=[0.0], 836 | reduce_fn_kwargs_paths=['linear_pool_1024'], 837 | lens_train_samples=[50], 838 | train_families=10000, 839 | epochs=10, 840 | measurements=1, 841 | encoder_fn_name='transformer', 842 | encoder_fn_kwargs_path='large_transformer_kwargs', 843 | reduce_fn_name='linear_max_pool', 844 | lens_batch_size=8, 845 | knn_batch_size=64, 846 | use_transformer=True, 847 | use_bert=True, 848 | load_gcs_bucket=load_gcs_bucket, 849 | save_gcs_bucket=save_gcs_bucket, 850 | model_random_keys=[i], 851 | save_model=True, 852 | save_model_dir=os.path.join('pfam_experiment_optimizers', 853 | 'large_6_' + str(i))) 854 | 855 | params += create_params( 856 | encoder_lrs=[0.0], 857 | lens_lrs=[1e-3], 858 | predictor_lrs=[1e-3], 859 | encoder_wds=[0.0], 860 | lens_wds=[0.1], 861 | predictor_wds=[0.0], 862 | reduce_fn_kwargs_paths=['linear_pool_1024'], 863 | lens_train_samples=[50], 864 | train_families=10000, 865 | epochs=10, 866 | measurements=1, 867 | encoder_fn_name='transformer', 868 | encoder_fn_kwargs_path='large_transformer_kwargs', 869 | reduce_fn_name='linear_max_pool', 870 | lens_batch_size=8, 871 | knn_batch_size=64, 872 | use_transformer=True, 873 | use_bert=True, 874 | load_gcs_bucket=load_gcs_bucket, 875 | save_gcs_bucket=save_gcs_bucket, 876 | restore_transformer_dir= 877 | 'gs://sequin-public/transformer_models/large_bert/', 878 | model_random_keys=[i], 879 | save_model=True, 880 | save_model_dir=os.path.join('pfam_experiment_optimizers', 881 | 'large_7_' + str(i))) 882 | 883 | 884 | # 10000 train families random key sweep and save best models 885 | for i in range(7): 886 | params += create_params( 887 | encoder_lrs=[0.0], 888 | lens_lrs=[1e-5], 889 | predictor_lrs=[1e-3], 890 | encoder_wds=[0.0], 891 | lens_wds=[0.05], 892 | predictor_wds=[0.0], 893 | reduce_fn_kwargs_paths=['linear_pool_1024'], 894 | lens_train_samples=[50], 895 | train_families=10000, 896 | epochs=10, 897 | measurements=2, 898 | encoder_fn_name='transformer', 899 | encoder_fn_kwargs_path='medium_transformer_kwargs', 900 | reduce_fn_name='linear_max_pool', 901 | lens_batch_size=64, 902 | knn_batch_size=64, 903 | use_transformer=True, 904 | use_bert=True, 905 | load_gcs_bucket=load_gcs_bucket, 906 | save_gcs_bucket=save_gcs_bucket, 907 | model_random_keys=[i], 908 | save_model=True, 909 | save_model_dir=os.path.join('pfam_experiment_optimizers', 910 | 'medium_' + str(i))) 911 | 912 | params += create_params( 913 | encoder_lrs=[0.0], 914 | lens_lrs=[5e-5], 915 | predictor_lrs=[1e-3], 916 | encoder_wds=[0.0], 917 | lens_wds=[0.2], 918 | predictor_wds=[0.25], 919 | reduce_fn_kwargs_paths=['linear_pool_1024'], 920 | lens_train_samples=[50], 921 | train_families=10000, 922 | epochs=10, 923 | measurements=2, 924 | encoder_fn_name='transformer', 925 | encoder_fn_kwargs_path='medium_transformer_kwargs', 926 | reduce_fn_name='linear_max_pool', 927 | lens_batch_size=64, 928 | knn_batch_size=64, 929 | use_transformer=True, 930 | use_bert=True, 931 | restore_transformer_dir= 932 | 'gs://sequin-public/transformer_models/medium_trembl_bert/', 933 | load_gcs_bucket=load_gcs_bucket, 934 | save_gcs_bucket=save_gcs_bucket, 935 | model_random_keys=[i], 936 | save_model=True, 937 | save_model_dir=os.path.join('pfam_experiment_optimizers', 938 | 'medium_pt_' + str(i))) 939 | 940 | params += create_params( 941 | encoder_lrs=[0.0], 942 | lens_lrs=[1e-3], 943 | predictor_lrs=[1e-3], 944 | encoder_wds=[0.0], 945 | lens_wds=[0.1], 946 | predictor_wds=[0.0], 947 | reduce_fn_kwargs_paths=['linear_pool_1024'], 948 | lens_train_samples=[50], 949 | train_families=10000, 950 | epochs=10, 951 | measurements=1, 952 | encoder_fn_name='transformer', 953 | encoder_fn_kwargs_path='small_transformer_kwargs', 954 | reduce_fn_name='linear_max_pool', 955 | lens_batch_size=64, 956 | knn_batch_size=64, 957 | use_transformer=True, 958 | use_bert=True, 959 | load_gcs_bucket=load_gcs_bucket, 960 | save_gcs_bucket=save_gcs_bucket, 961 | model_random_keys=[i], 962 | save_model=True, 963 | save_model_dir=os.path.join('pfam_experiment_optimizers', 964 | 'small_' + str(i))) 965 | 966 | params += create_params( 967 | encoder_lrs=[0.0], 968 | lens_lrs=[1e-3], 969 | predictor_lrs=[1e-3], 970 | encoder_wds=[0.0], 971 | lens_wds=[0.2], 972 | predictor_wds=[0.2], 973 | reduce_fn_kwargs_paths=['linear_pool_1024'], 974 | lens_train_samples=[50], 975 | train_families=10000, 976 | epochs=10, 977 | measurements=1, 978 | encoder_fn_name='transformer', 979 | encoder_fn_kwargs_path='small_transformer_kwargs', 980 | reduce_fn_name='linear_max_pool', 981 | lens_batch_size=64, 982 | knn_batch_size=64, 983 | use_transformer=True, 984 | use_bert=True, 985 | restore_transformer_dir= 986 | 'gs://sequin-public/transformer_models/small_trembl_bert/', 987 | load_gcs_bucket=load_gcs_bucket, 988 | save_gcs_bucket=save_gcs_bucket, 989 | model_random_keys=[i], 990 | save_model=True, 991 | save_model_dir=os.path.join('pfam_experiment_optimizers', 992 | 'small_pt_' + str(i))) 993 | 994 | params += create_params( 995 | encoder_lrs=[1e-3], 996 | lens_lrs=[5e-5], 997 | predictor_lrs=[5e-5], 998 | encoder_wds=[0.3], 999 | lens_wds=[0.05], 1000 | predictor_wds=[0.0], 1001 | reduce_fn_kwargs_paths=['linear_pool_1024'], 1002 | lens_train_samples=[50], 1003 | train_families=10000, 1004 | epochs=10, 1005 | measurements=2, 1006 | encoder_fn_name='cnn_one_hot', 1007 | encoder_fn_kwargs_path='2-layer_cnn_kwargs', 1008 | reduce_fn_name='linear_max_pool', 1009 | lens_batch_size=512, 1010 | knn_batch_size=512, 1011 | use_transformer=False, 1012 | use_bert=False, 1013 | load_gcs_bucket=load_gcs_bucket, 1014 | save_gcs_bucket=save_gcs_bucket, 1015 | model_random_keys=[i], 1016 | save_model=True, 1017 | save_model_dir=os.path.join('pfam_experiment_optimizers', 1018 | 'cnn_' + str(i))) 1019 | 1020 | for i in range(7): 1021 | params += create_params( 1022 | encoder_lrs=[0.0], 1023 | lens_lrs=[1e-4], 1024 | predictor_lrs=[1e-3], 1025 | encoder_wds=[0.0], 1026 | lens_wds=[0.1], 1027 | predictor_wds=[0.0], 1028 | reduce_fn_kwargs_paths=['linear_pool_1024'], 1029 | lens_train_samples=[50], 1030 | train_families=10000, 1031 | epochs=10, 1032 | measurements=1, 1033 | encoder_fn_name='transformer', 1034 | encoder_fn_kwargs_path='small_transformer_kwargs', 1035 | reduce_fn_name='linear_max_pool', 1036 | lens_batch_size=64, 1037 | knn_batch_size=64, 1038 | use_transformer=True, 1039 | use_bert=True, 1040 | load_gcs_bucket=load_gcs_bucket, 1041 | save_gcs_bucket=save_gcs_bucket, 1042 | model_random_keys=[i], 1043 | save_model=True, 1044 | save_model_dir=os.path.join('pfam_experiment_optimizers', 1045 | 'small_' + str(i+7))) 1046 | 1047 | for i in range(10): 1048 | params += create_params( 1049 | encoder_lrs=[0.0], 1050 | lens_lrs=[1e-4], 1051 | predictor_lrs=[1e-3], 1052 | encoder_wds=[0.0], 1053 | lens_wds=[0.2], 1054 | predictor_wds=[0.2], 1055 | reduce_fn_kwargs_paths=['linear_pool_1024'], 1056 | lens_train_samples=[50], 1057 | train_families=10000, 1058 | epochs=10, 1059 | measurements=1, 1060 | encoder_fn_name='transformer', 1061 | encoder_fn_kwargs_path='small_transformer_kwargs', 1062 | reduce_fn_name='linear_max_pool', 1063 | lens_batch_size=64, 1064 | knn_batch_size=64, 1065 | use_transformer=True, 1066 | use_bert=True, 1067 | restore_transformer_dir= 1068 | 'gs://sequin-public/transformer_models/small_trembl_bert/', 1069 | load_gcs_bucket=load_gcs_bucket, 1070 | save_gcs_bucket=save_gcs_bucket, 1071 | model_random_keys=[i], 1072 | save_model=True, 1073 | save_model_dir=os.path.join('pfam_experiment_optimizers', 1074 | 'small_' + str(i+14))) 1075 | 1076 | params += create_params( 1077 | encoder_lrs=[0.0], 1078 | lens_lrs=[1e-5], 1079 | predictor_lrs=[1e-3], 1080 | encoder_wds=[0.0], 1081 | lens_wds=[0.05], 1082 | predictor_wds=[0.0], 1083 | reduce_fn_kwargs_paths=['linear_pool_1024'], 1084 | lens_train_samples=[50], 1085 | train_families=10000, 1086 | epochs=10, 1087 | measurements=1, 1088 | encoder_fn_name='transformer', 1089 | encoder_fn_kwargs_path='medium_transformer_kwargs', 1090 | reduce_fn_name='linear_max_pool', 1091 | lens_batch_size=64, 1092 | knn_batch_size=64, 1093 | use_transformer=True, 1094 | use_bert=True, 1095 | load_gcs_bucket=load_gcs_bucket, 1096 | save_gcs_bucket=save_gcs_bucket, 1097 | model_random_keys=[i], 1098 | save_model=True, 1099 | save_model_dir=os.path.join('pfam_experiment_optimizers', 1100 | 'medium_' + str(i+7))) 1101 | 1102 | params += create_params( 1103 | encoder_lrs=[0.0], 1104 | lens_lrs=[5e-5], 1105 | predictor_lrs=[1e-3], 1106 | encoder_wds=[0.0], 1107 | lens_wds=[0.2], 1108 | predictor_wds=[0.25], 1109 | reduce_fn_kwargs_paths=['linear_pool_1024'], 1110 | lens_train_samples=[50], 1111 | train_families=10000, 1112 | epochs=10, 1113 | measurements=1, 1114 | encoder_fn_name='transformer', 1115 | encoder_fn_kwargs_path='medium_transformer_kwargs', 1116 | reduce_fn_name='linear_max_pool', 1117 | lens_batch_size=64, 1118 | knn_batch_size=64, 1119 | use_transformer=True, 1120 | use_bert=True, 1121 | restore_transformer_dir= 1122 | 'gs://sequin-public/transformer_models/medium_trembl_bert/', 1123 | load_gcs_bucket=load_gcs_bucket, 1124 | save_gcs_bucket=save_gcs_bucket, 1125 | model_random_keys=[i], 1126 | save_model=True, 1127 | save_model_dir=os.path.join('pfam_experiment_optimizers', 1128 | 'medium_pt_' + str(i+7))) 1129 | 1130 | 1131 | frozen_param_dict_to_label = {} 1132 | label = 0 1133 | for param_dict in params: 1134 | if frozendict(param_dict) not in frozen_param_dict_to_label.keys(): 1135 | frozen_param_dict_to_label[frozendict(param_dict)] = label 1136 | label += 1 1137 | 1138 | unique_params = [] 1139 | for frozen_param_dict in frozen_param_dict_to_label.keys(): 1140 | param_dict = dict(frozen_param_dict) 1141 | param_dict.update( 1142 | {'label': frozen_param_dict_to_label[frozen_param_dict]}) 1143 | unique_params.append(param_dict) 1144 | unique_params = sorted(unique_params, key=lambda x: x['label']) 1145 | 1146 | def transform_label(param_dict): 1147 | param_dict['label'] = '%08d' % param_dict['label'] 1148 | return param_dict 1149 | 1150 | unique_params = [ 1151 | transform_label(param_dict) for param_dict in unique_params 1152 | ] 1153 | 1154 | with open('params_combinations.json', 'w') as f: 1155 | json.dump(unique_params, f) 1156 | 1157 | label_to_params = {} 1158 | for param_dict in unique_params: 1159 | label_to_params[param_dict['label']] = param_dict 1160 | 1161 | with open('label_to_params.json', 'w') as f: 1162 | json.dump(label_to_params, f) 1163 | 1164 | 1165 | if __name__ == '__main__': 1166 | main(load_gcs_bucket='neuralblast_public', 1167 | save_gcs_bucket='sequin-public') 1168 | --------------------------------------------------------------------------------