├── contextual_lenses
├── __init__.py
├── cloud_utils
│ ├── __init__.py
│ ├── cpu_init.sh
│ ├── gpu_init.sh
│ └── tpu_init.py
├── resources
│ ├── __init__.py
│ ├── reduce_fn_kwargs_resources
│ │ ├── pool.json
│ │ ├── linear_pool_256.json
│ │ ├── linear_pool_512.json
│ │ └── linear_pool_1024.json
│ └── encoder_fn_kwargs_resources
│ │ ├── 1-layer_cnn_kwargs.json
│ │ ├── large_transformer_kwargs.json
│ │ ├── medium_transformer_kwargs.json
│ │ ├── small_transformer_kwargs.json
│ │ └── 2-layer_cnn_kwargs.json
├── loss_fns.py
├── load_transformer.py
├── test
│ ├── test_loading.py
│ ├── test_pooling.py
│ └── test_learning.py
├── encoders.py
├── blast_baseline.py
├── contextual_lenses.py
├── pfam_utils.py
└── train_utils.py
├── .gitmodules
├── figures
├── 1-sample_test_knn_accuracy.png
├── 5-sample_test_knn_accuracy.png
├── 10-sample_test_knn_accuracy.png
└── 50-sample_test_knn_accuracy.png
├── demo.json
├── docs
├── contributing.md
└── code-of-conduct.md
├── setup.py
├── LICENSE
├── README.md
├── pfam_experiment.py
└── generate_params.py
/contextual_lenses/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/contextual_lenses/cloud_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/contextual_lenses/resources/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/contextual_lenses/resources/reduce_fn_kwargs_resources/pool.json:
--------------------------------------------------------------------------------
1 | {}
--------------------------------------------------------------------------------
/contextual_lenses/resources/reduce_fn_kwargs_resources/linear_pool_256.json:
--------------------------------------------------------------------------------
1 | {"rep_size": 256}
--------------------------------------------------------------------------------
/contextual_lenses/resources/reduce_fn_kwargs_resources/linear_pool_512.json:
--------------------------------------------------------------------------------
1 | {"rep_size": 512}
--------------------------------------------------------------------------------
/contextual_lenses/resources/reduce_fn_kwargs_resources/linear_pool_1024.json:
--------------------------------------------------------------------------------
1 | {"rep_size": 1024}
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "google-research"]
2 | path = google_research
3 | url = https://github.com/google-research/google-research.git
--------------------------------------------------------------------------------
/figures/1-sample_test_knn_accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/googleinterns/protein-embedding-retrieval/HEAD/figures/1-sample_test_knn_accuracy.png
--------------------------------------------------------------------------------
/figures/5-sample_test_knn_accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/googleinterns/protein-embedding-retrieval/HEAD/figures/5-sample_test_knn_accuracy.png
--------------------------------------------------------------------------------
/figures/10-sample_test_knn_accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/googleinterns/protein-embedding-retrieval/HEAD/figures/10-sample_test_knn_accuracy.png
--------------------------------------------------------------------------------
/figures/50-sample_test_knn_accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/googleinterns/protein-embedding-retrieval/HEAD/figures/50-sample_test_knn_accuracy.png
--------------------------------------------------------------------------------
/contextual_lenses/resources/encoder_fn_kwargs_resources/1-layer_cnn_kwargs.json:
--------------------------------------------------------------------------------
1 | {"n_layers": 1, "n_features": [1024], "n_kernel_sizes": [12], "n_kernel_dilations": null}
--------------------------------------------------------------------------------
/contextual_lenses/resources/encoder_fn_kwargs_resources/large_transformer_kwargs.json:
--------------------------------------------------------------------------------
1 | {"emb_dim": 512, "num_heads": 8, "qkv_dim": 512, "mlp_dim": 1024, "num_layers": 36}
--------------------------------------------------------------------------------
/contextual_lenses/resources/encoder_fn_kwargs_resources/medium_transformer_kwargs.json:
--------------------------------------------------------------------------------
1 | {"emb_dim": 512, "num_heads": 8, "qkv_dim": 512, "mlp_dim": 2048, "num_layers": 6}
--------------------------------------------------------------------------------
/contextual_lenses/resources/encoder_fn_kwargs_resources/small_transformer_kwargs.json:
--------------------------------------------------------------------------------
1 | {"emb_dim": 256, "num_heads": 4, "qkv_dim": 256, "mlp_dim": 1024, "num_layers": 2}
--------------------------------------------------------------------------------
/contextual_lenses/cloud_utils/cpu_init.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | pip install --upgrade pip
4 | pip install --upgrade jax jaxlib flax
5 | export TF_FORCE_GPU_ALLOW_GROWTH=true
--------------------------------------------------------------------------------
/contextual_lenses/resources/encoder_fn_kwargs_resources/2-layer_cnn_kwargs.json:
--------------------------------------------------------------------------------
1 | {"n_layers": 2, "n_features": [1024, 1024], "n_kernel_sizes": [5, 5], "n_kernel_dilations": null}
--------------------------------------------------------------------------------
/contextual_lenses/cloud_utils/gpu_init.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | pip install --upgrade pip
4 |
5 | PYTHON_VERSION=cp37
6 | CUDA_VERSION=cuda101
7 | PLATFORM=linux_x86_64
8 | BASE_URL='https://storage.googleapis.com/jax-releases'
9 |
10 | pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.48-$PYTHON_VERSION-none-$PLATFORM.whl
11 | pip install --upgrade jax flax
12 |
13 | export TF_FORCE_GPU_ALLOW_GROWTH=true
--------------------------------------------------------------------------------
/contextual_lenses/loss_fns.py:
--------------------------------------------------------------------------------
1 | """Loss functions
2 |
3 | Jax loss functions for computing gradient updates.
4 | """
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | from jax.config import config
9 | config.enable_omnistaging()
10 |
11 |
def mse_loss(Y, Y_hat):
    """Return the mean squared error between targets and predictions.

    Predictions with more than one dimension have axis 1 squeezed out first
    so that they line up with `Y`.
    """
    predictions = jnp.squeeze(Y_hat, axis=1) if len(Y_hat.shape) > 1 else Y_hat
    residuals = Y - predictions
    return jnp.mean(jnp.square(residuals))
21 |
22 |
def cross_entropy_loss(Y, Y_hat, num_classes):
    """Return the summed cross-entropy between labels and predicted logits.

    `Y_hat` is passed through a log-softmax, `Y` is one-hot encoded with
    `num_classes` classes, and the negative sum of their elementwise product
    is returned (a sum over all entries, not a mean).
    """
    log_probs = jax.nn.log_softmax(Y_hat)
    targets = jax.nn.one_hot(Y, num_classes=num_classes)
    return -jnp.sum(targets * log_probs)
35 |
--------------------------------------------------------------------------------
/contextual_lenses/cloud_utils/tpu_init.py:
--------------------------------------------------------------------------------
1 | """Function to connect VM instance to Cloud TPU."""
2 |
3 |
4 | import requests
5 | import yaml
6 | import subprocess as sp
7 | from jax.config import config
8 |
9 |
def connect_tpu(tpu_name=None):
    """Runs necessary commands to connect VM to Cloud TPU.

    Does nothing when `tpu_name` is None. Otherwise the TPU is described
    through the `gcloud` CLI, its driver endpoint is asked for the
    `tpu_driver_nightly` version, and the JAX backend flags are pointed at
    the TPU's gRPC address.

    Args:
      tpu_name: Name of the Cloud TPU to connect to, or None to skip.

    Raises:
      KeyError / TypeError: if the `gcloud` output does not parse into a
        mapping containing `ipAddress` and `port`.
    """

    if tpu_name is None:
        return

    # Local import: only this function needs it.
    import shlex

    # The command is run through a shell by `getoutput`, so quote the name to
    # keep unusual characters from being interpreted by the shell.
    command = 'gcloud compute tpus describe ' + shlex.quote(tpu_name)

    output = sp.getoutput(command)
    output = yaml.load(stream=output, Loader=yaml.SafeLoader)

    ip_address = output['ipAddress']
    # YAML may parse an unquoted port as an int; the address below is built
    # by string concatenation, so normalise it to str.
    port = str(output['port'])

    url = 'http://' + ip_address + ':8475/requestversion/tpu_driver_nightly'
    requests.post(url)

    config.FLAGS.jax_xla_backend = 'tpu_driver'
    config.FLAGS.jax_backend_target = "grpc://" + ip_address + ':' + port

    print('Successfully connected to TPU named \"' + tpu_name + '\"!')
29 |
--------------------------------------------------------------------------------
/demo.json:
--------------------------------------------------------------------------------
1 | {"encoder_fn_kwargs_path": "small_transformer_kwargs", "restore_transformer_dir": "gs://sequin-public/transformer_models/small_trembl_bert/", "lens_shuffle_seed": 0, "label": "demo", "train_families": 10000, "data_partitions_dirpath": "random_split/", "last_test_family": 16000, "load_model_step": 0, "knn_sample_random_state": 1, "predictor_wd": 0.2, "lens_wd": 0.2, "save_model": true, "model_random_key": 0, "encoder_wd": 0.0, "use_bert": true, "first_test_family": 15000, "load_model_dir": "", "save_gcs_bucket": "neuralblast_public", "predictor_lr": 0.001, "knn_shuffle_seed": 1, "load_model": false, "load_gcs_bucket": "neuralblast_public", "knn_batch_size": 64, "reduce_fn_name": "linear_max_pool", "epochs": 10, "results_save_dir": "demo/", "lens_train_samples": 50, "measurements": 1, "reduce_fn_kwargs_path": "linear_pool_1024", "save_model_dir": "demo/model/", "lens_sample_random_state": 0, "use_transformer": true, "lens_lr": 0.0001, "encoder_lr": 0.0, "encoder_fn_name": "transformer", "lens_batch_size": 64}
--------------------------------------------------------------------------------
/docs/contributing.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google/conduct/).
29 |
--------------------------------------------------------------------------------
/contextual_lenses/load_transformer.py:
--------------------------------------------------------------------------------
1 | """Utils for loading FlaxLM or FlaxBERT models or parameters."""
2 |
3 | import os
4 |
5 | import tensorflow as tf
6 |
7 | import gin
8 |
9 | from protein_lm import data, models
10 |
11 |
def load_transformer_model(ckpt_dir, model_cls, domain=None):
    """Loads a model from directory.

    Args:
      ckpt_dir: Directory holding `config.gin` and the checkpoint to load.
      model_cls: Callable invoked as `model_cls(domain=domain)`; the result
        must provide `load_checkpoint`.
      domain: Domain passed to `model_cls`; defaults to `data.protein_domain`.

    Returns:
      The constructed model after `load_checkpoint(ckpt_dir)` has been called.
    """

    if domain is None:
        domain = data.protein_domain

    config_path = os.path.join(ckpt_dir, 'config.gin')
    # NOTE(review): the model is built inside the 'load_model' gin scope here;
    # confirm this nesting against the original file's indentation.
    with gin.config_scope('load_model'):
        with tf.io.gfile.GFile(config_path) as f:
            # Bindings for configurables unknown to this process are skipped.
            gin.parse_config(f, skip_unknown=True)
        model = model_cls(domain=domain)
        model.load_checkpoint(ckpt_dir)

    return model
26 |
27 |
def load_transformer_params(ckpt_dir, model_cls, domain=None):
    """Loads a model from `ckpt_dir` and returns its parameters.

    The parameters are the `.params` of the unreplicated optimizer target of
    the loaded model.
    """

    loaded = load_transformer_model(ckpt_dir, model_cls, domain=domain)
    target = models.jax_utils.unreplicate(loaded._optimizer.target)
    return target.params
35 |
36 |
def load_transformer_encoder(ckpt_dir, model_cls, domain=None):
    """Loads a model from `ckpt_dir` and returns its encoder.

    The encoder is the unreplicated optimizer target of the loaded model
    (the whole target object, not only its parameters).
    """

    loaded = load_transformer_model(ckpt_dir, model_cls, domain=domain)
    return models.jax_utils.unreplicate(loaded._optimizer.target)
44 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
# This follows the style of Jaxlib installation here:
# https://github.com/google/jax#pip-installation
# `jax_artifact` combines these values into the GPU wheel location:
#   BASE_URL/CUDA_VERSION/jaxlib-<version>-PYTHON_VERSION-none-PLATFORM.whl
PYTHON_VERSION = "cp37"
CUDA_VERSION = "cuda101"  # alternatives: cuda90, cuda92, cuda100
PLATFORM = "manylinux2010_x86_64"  # alternatives: linux_x86_64
BASE_URL = "https://storage.googleapis.com/jax-releases"
9 |
10 |
def jax_artifact(version, gpu=False):
    """Return the pip requirement string for jaxlib at `version`.

    With `gpu=False` this is a plain `jaxlib==<version>` pin. With `gpu=True`
    it is a direct reference (`jaxlib @ <url>`) to the CUDA wheel described
    by the module-level BASE_URL, CUDA_VERSION, PYTHON_VERSION and PLATFORM.
    """
    if not gpu:
        return f"jaxlib=={version}"

    wheel_name = f"jaxlib-{version}-{PYTHON_VERSION}-none-{PLATFORM}.whl"
    return f"jaxlib @ {BASE_URL}/{CUDA_VERSION}/{wheel_name}"
18 |
def readme():
    """Return the text of README.md, or None when the file does not exist.

    The file is read relative to the current working directory and decoded
    as UTF-8 explicitly, so the result does not depend on the platform's
    default encoding.
    """
    try:
        with open('README.md', encoding='utf-8') as rf:
            return rf.read()
    except FileNotFoundError:
        return None
25 |
# Pinned versions of the JAX stack this package is installed against.
JAXLIB_VERSION = "0.1.52"
JAX_VERSION = "0.1.75"
FLAX_VERSION = "0.2.0"

REQUIRED_PACKAGES = [
    "tensorflow",
    "pandas",
    "numpy",
    "gin-config",
    # "sklearn" is only a deprecated alias on PyPI; the real distribution
    # name is "scikit-learn".
    "scikit-learn",
    "dm-tree",
    "fs",
    "fs-gcsfs",
    f"jax=={JAX_VERSION}",
    f"flax=={FLAX_VERSION}",
]

setup(name='contextual_lenses',
      version='1.0',
      description='Protein contextual lenses.',
      long_description=readme(),
      # README.md is Markdown; declare it so package indexes render it.
      long_description_content_type='text/markdown',
      author='Amir Shanehsazzadeh',
      author_email='amirshanehsaz@google.com',
      # `exclude` must be a sequence of patterns: ('docs') is just the string
      # 'docs', which gets iterated character by character.
      packages=find_packages(exclude=('docs',)),
      install_requires=REQUIRED_PACKAGES,
      # jaxlib comes from an extra so the CPU or GPU build can be chosen:
      # `pip install .[cpu]` or `pip install .[gpu]`.
      extras_require={
          "cpu": [jax_artifact(JAXLIB_VERSION, gpu=False)],
          "gpu": [jax_artifact(JAXLIB_VERSION, gpu=True)],
      })
55 |
--------------------------------------------------------------------------------
/contextual_lenses/test/test_loading.py:
--------------------------------------------------------------------------------
1 | """Tests for saving and restoring checkpoints."""
2 |
3 |
4 | import os
5 |
6 | from absl.testing import parameterized
7 | from absl.testing import absltest
8 |
9 | from flax.training import checkpoints
10 |
11 | from contextual_lenses import mean_pool, max_pool
12 |
13 | from train_utils import create_optimizer, create_representation_model
14 |
15 | from encoders import one_hot_encoder
16 |
17 |
# Test cases: each dict supplies the keyword arguments of
# `TestLoading.test_loading`.
test1 = dict(
    encoder_fn=one_hot_encoder,
    encoder_fn_kwargs={},
    reduce_fn=mean_pool,
    reduce_fn_kwargs={},
)

test2 = dict(
    encoder_fn=one_hot_encoder,
    encoder_fn_kwargs={},
    reduce_fn=max_pool,
    reduce_fn_kwargs={},
)

tests = (test1, test2)
42 |
43 |
class TestLoading(parameterized.TestCase):
    """Abstract method for testing saving and restoring optimizer checkpoints."""

    @parameterized.parameters(*tests)
    def test_loading(self, encoder_fn, encoder_fn_kwargs, reduce_fn,
                     reduce_fn_kwargs):
        """Saves an optimizer checkpoint, restores it and compares the two.

        The restored optimizer must have the same step and the same
        parameter values, key for key, as the optimizer that was saved.
        """
        # Local import: only this test needs it.
        import tempfile

        model = create_representation_model(
            encoder_fn=encoder_fn,
            encoder_fn_kwargs=encoder_fn_kwargs,
            reduce_fn=reduce_fn,
            reduce_fn_kwargs=reduce_fn_kwargs,
            num_categories=21,
            output_features=1)

        optimizer = create_optimizer(model, learning_rate=1e-3, weight_decay=0.)

        params = optimizer.target.params
        step = optimizer.state.step

        # A fresh temporary directory avoids failing when a stale 'tmp/'
        # already exists and is removed even if saving or restoring raises.
        with tempfile.TemporaryDirectory() as ckpt_dir:
            checkpoints.save_checkpoint(ckpt_dir=ckpt_dir,
                                        target=optimizer,
                                        step=step)
            loaded_optimizer = checkpoints.restore_checkpoint(
                ckpt_dir=ckpt_dir, target=optimizer)

        loaded_params = loaded_optimizer.target.params
        loaded_step = loaded_optimizer.state.step

        self.assertEqual(step, loaded_step)

        self.assertEqual(list(params.keys()), list(loaded_params.keys()))
        for key in params:
            sub_params = params[key]
            loaded_sub_params = loaded_params[key]
            self.assertEqual(list(sub_params.keys()),
                             list(loaded_sub_params.keys()))
            for sub_key in sub_params:
                # Array comparison is elementwise; require every entry to match.
                self.assertTrue(
                    (sub_params[sub_key] == loaded_sub_params[sub_key]).all())
84 |
85 |
86 | if __name__ == '__main__':
87 | absltest.main()
88 |
89 |
--------------------------------------------------------------------------------
/contextual_lenses/encoders.py:
--------------------------------------------------------------------------------
1 | """Encoder functions
2 |
3 | Fixed and learnable transformations for embedding sequences.
4 | """
5 |
6 | import flax
7 | from flax import nn
8 |
9 | import jax
10 | from jax import lax
11 | import jax.nn
12 | import jax.numpy as jnp
13 | from jax.config import config
14 | config.enable_omnistaging()
15 |
16 | import numpy as np
17 |
18 | from operator import itemgetter
19 |
20 |
def one_hot_encoder(batch_inds, num_categories):
    """One-hot encodes `batch_inds` into `num_categories` classes via jax.nn."""

    return jax.nn.one_hot(batch_inds, num_classes=num_categories)
27 |
28 |
class CNN(nn.Module):
    """A simple 1D CNN model."""

    def apply(self, x, n_layers, n_features, n_kernel_sizes,
              n_kernel_dilations):
        """Applies `n_layers` Conv + ReLU layers along axis 1 of `x`.

        Layer `i` uses `n_features[i]` features, a kernel of size
        `n_kernel_sizes[i]` and dilation `n_kernel_dilations[i]`; a dilation
        of 1 is used for every layer when `n_kernel_dilations` is None.
        """
        dilations = ([1] * n_layers
                     if n_kernel_dilations is None else n_kernel_dilations)

        # Insert a singleton axis so the (k, 1) kernels below only span axis 1.
        out = jnp.expand_dims(x, axis=2)

        for idx in range(n_layers):
            out = nn.Conv(out,
                          features=n_features[idx],
                          kernel_size=(n_kernel_sizes[idx], 1),
                          kernel_dilation=(dilations[idx], 1))
            out = nn.relu(out)

        return jnp.squeeze(out, axis=2)
52 |
53 |
def cnn_one_hot_encoder(batch_inds,
                        num_categories,
                        n_layers,
                        n_features,
                        n_kernel_sizes,
                        n_kernel_dilations=None):
    """One-hot encodes `batch_inds`, then runs the result through `CNN`."""

    encoded = one_hot_encoder(batch_inds, num_categories)
    return CNN(encoded, n_layers, n_features, n_kernel_sizes,
               n_kernel_dilations)
67 |
68 |
def encoder_fn_name_to_fn(encoder_fn_name):
    """Returns encoder_fn corresponding to encoder_fn_name.

    Args:
      encoder_fn_name: None, 'transformer', 'one_hot', 'cnn_one_hot',
        'one_hot_pos_emb' or 'cnn_one_hot_pos_emb'.

    Returns:
      None for None / 'transformer', otherwise the encoder function.

    Raises:
      ValueError: if the name is not one of the recognised names.
      NotImplementedError: if the name is recognised but its encoder function
        is not defined in this module. The positional-embedding encoders are
        referenced here but not defined in this file, so asking for them
        previously failed with a bare NameError.
    """

    if encoder_fn_name is None or encoder_fn_name == 'transformer':
        return None

    # Recognised names mapped to the module-level function implementing them.
    encoder_fn_names = {
        'one_hot': 'one_hot_encoder',
        'cnn_one_hot': 'cnn_one_hot_encoder',
        'one_hot_pos_emb': 'one_hot_pos_emb_encoder',
        'cnn_one_hot_pos_emb': 'cnn_one_hot_pos_emb_encoder',
    }

    fn_name = None
    if isinstance(encoder_fn_name, str):
        fn_name = encoder_fn_names.get(encoder_fn_name)
    if fn_name is None:
        raise ValueError('Incorrect encoder name specified.')

    # Resolve lazily so a missing encoder only affects requests for it.
    encoder_fn = globals().get(fn_name)
    if encoder_fn is None:
        raise NotImplementedError(
            'Encoder "%s" is recognised but %s is not defined in this module.'
            % (encoder_fn_name, fn_name))

    return encoder_fn
86 |
--------------------------------------------------------------------------------
/docs/code-of-conduct.md:
--------------------------------------------------------------------------------
1 | # Google Open Source Community Guidelines
2 |
3 | At Google, we recognize and celebrate the creativity and collaboration of open
4 | source contributors and the diversity of skills, experiences, cultures, and
5 | opinions they bring to the projects and communities they participate in.
6 |
7 | Every one of Google's open source projects and communities are inclusive
8 | environments, based on treating all individuals respectfully, regardless of
9 | gender identity and expression, sexual orientation, disabilities,
10 | neurodiversity, physical appearance, body size, ethnicity, nationality, race,
11 | age, religion, or similar personal characteristic.
12 |
13 | We value diverse opinions, but we value respectful behavior more.
14 |
15 | Respectful behavior includes:
16 |
17 | * Being considerate, kind, constructive, and helpful.
18 | * Not engaging in demeaning, discriminatory, harassing, hateful, sexualized, or
19 | physically threatening behavior, speech, and imagery.
20 | * Not engaging in unwanted physical contact.
21 |
22 | Some Google open source projects [may adopt][] an explicit project code of
23 | conduct, which may have additional detailed expectations for participants. Most
24 | of those projects will use our [modified Contributor Covenant][].
25 |
26 | [may adopt]: https://opensource.google/docs/releasing/preparing/#conduct
27 | [modified Contributor Covenant]: https://opensource.google/docs/releasing/template/CODE_OF_CONDUCT/
28 |
29 | ## Resolve peacefully
30 |
31 | We do not believe that all conflict is necessarily bad; healthy debate and
32 | disagreement often yields positive results. However, it is never okay to be
33 | disrespectful.
34 |
35 | If you see someone behaving disrespectfully, you are encouraged to address the
36 | behavior directly with those involved. Many issues can be resolved quickly and
37 | easily, and this gives people more control over the outcome of their dispute.
38 | If you are unable to resolve the matter for any reason, or if the behavior is
39 | threatening or harassing, report it. We are dedicated to providing an
40 | environment where participants feel welcome and safe.
41 |
42 | ## Reporting problems
43 |
44 | Some Google open source projects may adopt a project-specific code of conduct.
45 | In those cases, a Google employee will be identified as the Project Steward,
46 | who will receive and handle reports of code of conduct violations. In the event
47 | that a project hasn’t identified a Project Steward, you can report problems by
48 | emailing opensource@google.com.
49 |
50 | We will investigate every complaint, but you may not receive a direct response.
51 | We will use our discretion in determining when and how to follow up on reported
52 | incidents, which may range from not taking action to permanent expulsion from
53 | the project and project-sponsored spaces. We will notify the accused of the
54 | report and provide them an opportunity to discuss it before any action is
55 | taken. The identity of the reporter will be omitted from the details of the
56 | report supplied to the accused. In potentially harmful situations, such as
57 | ongoing harassment or threats to anyone's safety, we may take action without
58 | notice.
59 |
60 | *This document was adapted from the [IndieWeb Code of Conduct][] and can also
61 | be found at <https://opensource.google/conduct/>.*
62 |
63 | [IndieWeb Code of Conduct]: https://indieweb.org/code-of-conduct
64 |
--------------------------------------------------------------------------------
/contextual_lenses/blast_baseline.py:
--------------------------------------------------------------------------------
1 | """Computes accuracy of 1 nearest neighbor classification using BLAST.
2 |
3 | Example usage:
4 | blast_baseline.py \
5 | --train_file=./resources/knn_data/5-samples_train_knn_data_families_15001-16000.csv \
6 | --test_file=./resources/knn_data/test_knn_data_families_15001-16000.csv
7 | """
8 |
9 | import os
10 | import subprocess
11 | import tempfile
12 |
13 | from absl import app
14 | from absl import flags
15 | import numpy as np
16 | import pandas as pd
17 |
18 | # Need to install locally from https://github.com/google-research/proteinfer
19 | from proteinfer import baseline_utils
20 |
21 | # We also assume that blast is installed locally.
22 | # sudo apt-get install ncbi-blast+
23 |
24 |
# Command-line flags: paths of the train and test CSV files read by `_load`.
flags.DEFINE_string('train_file', '', 'Input train csv file.')
flags.DEFINE_string('test_file', '', 'Input test csv file.')


FLAGS = flags.FLAGS


# Extra options inserted into every `blastp` command built in
# `BlastClassifier.predict`.
_BLAST_FLAGS = '-outfmt 6 -max_hsps 1 -num_threads 10 -num_alignments 1'
33 |
34 |
35 | def _get_header(row):
36 | accession = row.accession.replace('/', '_').replace('-', '_')
37 | return '>accession="%s"\tlabels="%s"' % (accession, row.label)
38 |
39 |
def _get_fasta_entry(row):
    """Returns the two-line FASTA entry (header, then sequence) for `row`."""
    lines = [_get_header(row), row.sequence]
    return '\n'.join(lines)
43 |
44 |
def _write_fasta(df, output_file):
    """Writes one FASTA entry per row of `df` to `output_file`.

    Entries are separated by newlines; no trailing newline is written.
    """
    fasta_entries = df.apply(_get_fasta_entry, axis=1)
    with open(output_file, 'w') as handle:
        handle.write('\n'.join(fasta_entries))
49 |
50 |
51 | def _run_cmd(cmd_string):
52 | subprocess.run(cmd_string.split(' '), check=True)
53 |
54 |
class BlastClassifier(object):
    """Stateful wrapper for BLAST system calls.

    Construction writes the training rows to a temporary FASTA file and
    builds a protein BLAST database from it with `makeblastdb`; `predict`
    queries that database with `blastp`. The temporary files are removed when
    the instance is garbage collected.
    """

    def __init__(self, df):
        """Builds the BLAST database from the rows of `df`.

        Raises:
          subprocess.CalledProcessError: if `makeblastdb` exits non-zero.
        """
        # Define the path attributes first so __del__ is safe even when
        # construction fails part-way through.
        self._train_fasta = None
        self._blast_db = None

        self._train_fasta = self._make_temp_path()
        self._blast_db = self._make_temp_path()
        _write_fasta(df, self._train_fasta)
        print(self._train_fasta)
        self._train_df = baseline_utils.load_ground_truth(self._train_fasta)
        self._label_vocab = df.label.unique()
        cmd = 'makeblastdb -in %s -dbtype prot -out %s' % (self._train_fasta,
                                                           self._blast_db)
        _run_cmd(cmd)

    def __del__(self):
        """Removes the training FASTA file and the BLAST database files."""
        self._remove_if_exists(getattr(self, '_train_fasta', None))
        self._remove_db_files(getattr(self, '_blast_db', None))
        self._train_fasta = None
        self._blast_db = None

    @staticmethod
    def _make_temp_path():
        """Creates an empty temporary file and returns its path."""
        fd, path = tempfile.mkstemp()
        # mkstemp returns an open descriptor; close it so it is not leaked.
        os.close(fd)
        return path

    @staticmethod
    def _remove_if_exists(path):
        """Deletes `path`; a None or already-missing path is ignored."""
        if path is None:
            return
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    @classmethod
    def _remove_db_files(cls, db_path):
        """Deletes `db_path` and every sibling file named `<db_path>.*`.

        makeblastdb uses the -out path as a prefix for its index files
        (e.g. .phr/.pin/.psq), so those are cleaned up as well.
        """
        if db_path is None:
            return
        directory, prefix = os.path.split(db_path)
        if not os.path.isdir(directory):
            return
        for name in os.listdir(directory):
            if name == prefix or name.startswith(prefix + '.'):
                cls._remove_if_exists(os.path.join(directory, name))

    def predict(self, df):
        """Predicts labels by propagating labels from BLAST top hit.

        Args:
          df: Rows to classify; every label must occur in the training data.

        Returns:
          The frame produced by `baseline_utils.load_blast_output`.

        Raises:
          subprocess.CalledProcessError: if `blastp` exits non-zero.
          AssertionError: if `df` contains a label unseen during training.
        """

        query_fasta = self._make_temp_path()
        blast_output = self._make_temp_path()
        try:
            _write_fasta(df, query_fasta)
            cmd = 'blastp -query %s -db %s %s -out %s' % (
                query_fasta, self._blast_db, _BLAST_FLAGS, blast_output)
            _run_cmd(cmd)

            assert df.label.isin(self._label_vocab).all()
            query_df = baseline_utils.load_ground_truth(query_fasta)
            results_df = baseline_utils.load_blast_output(
                blast_output, self._label_vocab, self._train_df, query_df)
        finally:
            # Remove the scratch files even when BLAST or parsing fails.
            self._remove_if_exists(query_fasta)
            self._remove_if_exists(blast_output)
        return results_df
93 |
94 |
95 | def _get_label(label_set):
96 | if label_set:
97 | assert len(label_set) == 1
98 | return next(iter(label_set))
99 | else:
100 | return None
101 |
102 |
def _compute_accuracy(df):
    """Returns the fraction of rows whose predicted label equals the true one.

    Both `predicted_label` and `true_label` are reduced with `_get_label`
    before being compared.
    """
    predicted = df.predicted_label.apply(_get_label)
    actual = df.true_label.apply(_get_label)
    matches = predicted == actual
    return np.mean(matches)
107 |
108 |
109 | def _load(filename):
110 | df = pd.read_csv(filename)
111 | df.rename(columns=dict(sequence_name='accession'), inplace=True)
112 | return df
113 |
114 |
def main(argv):
    """Builds a BLAST classifier from the train file and prints test accuracy."""
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    train_frame = _load(FLAGS.train_file)
    test_frame = _load(FLAGS.test_file)

    classifier = BlastClassifier(df=train_frame)
    predictions = classifier.predict(test_frame)

    print('Accuracy = %f' % _compute_accuracy(predictions))
127 |
128 |
129 | if __name__ == '__main__':
130 | app.run(main)
131 |
--------------------------------------------------------------------------------
/contextual_lenses/contextual_lenses.py:
--------------------------------------------------------------------------------
1 | """Contextual lenses
2 |
3 | Creates sequence length independent representation of embedded sequences
4 | Original paper: https://arxiv.org/pdf/2002.08866.pdf.
5 | """
6 |
7 | import flax
8 | from flax import nn
9 |
10 | import jax
11 | import jax.nn
12 | import jax.numpy as jnp
13 | from jax.config import config
14 | config.enable_omnistaging()
15 |
16 | import numpy as np
17 |
18 | from operator import itemgetter
19 |
20 |
def max_pool(x, padding_mask=None, pad_constant=1e8):
    """Takes the maximum over axis -2 of `x`, ignoring masked positions.

    When `padding_mask` is given, `x` is multiplied by it and
    `-pad_constant` is added wherever the mask is 0, so that those positions
    lose the maximum.
    """

    if padding_mask is None:
        return jnp.max(x, axis=-2)

    penalty = -pad_constant * (1 - padding_mask)
    masked = x * padding_mask + penalty
    return jnp.max(masked, axis=-2)
32 |
33 |
def mean_pool(x, padding_mask=None):
    """Takes the mean over axis -2 of `x`, ignoring masked positions.

    With a `padding_mask`, the masked sum over axis -2 is divided by the sum
    of the mask over the same axis; without one, a plain mean is taken.
    """

    if padding_mask is None:
        return jnp.mean(x, axis=-2)

    total = jnp.sum(x * padding_mask, axis=-2)
    count = jnp.sum(padding_mask, axis=-2)
    return total / count
44 |
45 |
def linear_max_pool(x, rep_size, padding_mask=None):
    """Dense projection to `rep_size` + ReLU, followed by `max_pool`.

    The padding mask, if any, is applied by `max_pool`.
    """

    projected = nn.Dense(x,
                         rep_size,
                         kernel_init=nn.initializers.xavier_uniform(),
                         bias_init=nn.initializers.normal(stddev=1e-6))
    activated = nn.relu(projected)
    return max_pool(activated, padding_mask=padding_mask)
61 |
62 |
def linear_mean_pool(x, rep_size, padding_mask=None):
    """Dense projection to `rep_size` + ReLU, followed by `mean_pool`.

    The padding mask, if any, is applied by `mean_pool`.
    """

    projected = nn.Dense(x,
                         rep_size,
                         kernel_init=nn.initializers.xavier_uniform(),
                         bias_init=nn.initializers.normal(stddev=1e-6))
    activated = nn.relu(projected)
    return mean_pool(activated, padding_mask=padding_mask)
78 |
79 |
class GatedConv(nn.Module):
    """Gated Convolutional lens followed by max pooling,
    see original paper for details.
    """
    def apply(self,
              x,
              rep_size,
              m_layers,
              m_features,
              m_kernel_sizes,
              conv_rep_size,
              padding_mask=None):
        """Computes the gated-convolution representation of `x`.

        Args:
          x: Embedded sequences; assumes (batch, length, features) -- TODO
            confirm against callers.
          rep_size: Size of the final representation (passed to
            `linear_max_pool`).
          m_layers: Number of convolution layers in each of the H and G stacks.
          m_features: Per-layer `(H_features, G_features)` pairs; read for
            every layer except the last.
          m_kernel_sizes: Per-layer `(H_kernel_size, G_kernel_size)` pairs;
            read for every layer.
          conv_rep_size: Width of the two input projections and of the last
            convolution layer.
          padding_mask: Optional mask forwarded to `linear_max_pool`.

        Returns:
          The result of `linear_max_pool` applied to `H * G + G_0`.
        """

        # Two separate Dense + ReLU projections of the input.
        H_0 = nn.relu(nn.Dense(x, conv_rep_size))
        G_0 = nn.relu(nn.Dense(x, conv_rep_size))
        # Singleton axis so the (k, 1) kernels below only span axis 1.
        H, G = jnp.expand_dims(H_0, axis=2), jnp.expand_dims(G_0, axis=2)

        for layer in range(1, m_layers + 1):

            # The last layer always maps back to conv_rep_size so that H * G
            # can be added to G_0 below.
            if layer < m_layers:
                H_features, G_features = m_features[layer - 1]
            else:
                H_features, G_features = conv_rep_size, conv_rep_size

            H_kernel_size, G_kernel_size = m_kernel_sizes[layer - 1]

            H = nn.Conv(H, features=H_features, kernel_size=(H_kernel_size, 1))
            G = nn.Conv(G, features=G_features, kernel_size=(G_kernel_size, 1))

            # Hidden layers use ReLU; the last layer uses tanh for H and a
            # sigmoid gate for G.
            if layer < m_layers:
                H = nn.relu(H)
                G = nn.relu(G)
            else:
                H = nn.tanh(H)
                G = nn.sigmoid(G)

        H, G = jnp.squeeze(H, axis=2), jnp.squeeze(G, axis=2)

        # Gated output plus a skip connection from the G_0 projection.
        F = H * G + G_0

        rep = linear_max_pool(F, padding_mask=padding_mask, rep_size=rep_size)

        return rep
123 |
124 |
def gated_conv(x,
               rep_size,
               m_layers,
               m_features,
               m_kernel_sizes,
               conv_rep_size,
               padding_mask=None):
    """Function-style wrapper around `GatedConv` for use as a lens."""

    lens_kwargs = dict(rep_size=rep_size,
                       m_features=m_features,
                       m_layers=m_layers,
                       m_kernel_sizes=m_kernel_sizes,
                       conv_rep_size=conv_rep_size,
                       padding_mask=padding_mask)
    return GatedConv(x, **lens_kwargs)
143 |
144 |
def reduce_fn_name_to_fn(reduce_fn_name):
    """Maps a lens name to the lens function of the same name.

    Raises:
      ValueError: if `reduce_fn_name` is not a known lens name.
    """

    if reduce_fn_name == 'mean_pool':
        return mean_pool
    if reduce_fn_name == 'max_pool':
        return max_pool
    if reduce_fn_name == 'linear_mean_pool':
        return linear_mean_pool
    if reduce_fn_name == 'linear_max_pool':
        return linear_max_pool
    if reduce_fn_name == 'gated_conv':
        return gated_conv

    raise ValueError('Incorrect lens name specified.')
162 |
--------------------------------------------------------------------------------
/contextual_lenses/test/test_pooling.py:
--------------------------------------------------------------------------------
1 | """Tests for mean and max pool reduce functions."""
2 |
3 |
4 | import jax
5 | import jax.numpy as jnp
6 |
7 | import numpy as np
8 |
9 | from absl.testing import parameterized
10 | from absl.testing import absltest
11 |
12 | from contextual_lenses import mean_pool, max_pool
13 |
14 |
# Test cases.
# Each case is a dict with the input batch 'x', an optional 'padding_mask'
# (None, or an array with one 0/1 entry per position of 'x') and the expected
# results 'mean_pool_rep' and 'max_pool_rep'.

# test1: no padding mask.
test1 = {
    'x': jnp.array([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
                    [[0, -1, 2, -3], [4, -5, 6, -7], [8, -9, 10, -11]]]),
    'padding_mask': None,
    'mean_pool_rep': jnp.array([[4, 5, 6, 7],
                                [4, -5, 6, -7]]),
    'max_pool_rep': jnp.array([[8, 9, 10, 11],
                               [8, -1, 10, -3]])
}

# test2: same data as test1 with an explicit all-ones mask, so the expected
# values match test1. The mask must carry one row of entries per batch element;
# without the separating comma and inner brackets the literal is parsed as a
# list subscript and collapses to a single (3, 1) mask.
test2 = {
    'x': jnp.array([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
                    [[0, -1, 2, -3], [4, -5, 6, -7], [8, -9, 10, -11]]]),
    'padding_mask': jnp.array([[[1], [1], [1]],
                               [[1], [1], [1]]]),
    'mean_pool_rep': jnp.array([[4, 5, 6, 7],
                                [4, -5, 6, -7]]),
    'max_pool_rep': jnp.array([[8, 9, 10, 11],
                               [8, -1, 10, -3]])
}

# test3: last position of the second batch element is masked out.
test3 = {
    'x': jnp.array([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
                    [[0, -1, 2, -3], [4, -5, 6, -7], [8, -9, 10, -11]]]),
    'padding_mask': jnp.array([[[1], [1], [1]],
                               [[1], [1], [0]]]),
    'mean_pool_rep': jnp.array([[4, 5, 6, 7],
                                [2, -3, 4, -5]]),
    'max_pool_rep': jnp.array([[8, 9, 10, 11],
                               [4, -1, 6, -3]])
}

# test4: masked-out positions in the middle and at both ends.
test4 = {
    'x': jnp.array([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
                    [[0, -1, 2, -3], [4, -5, 6, -7], [8, -9, 10, -11]]]),
    'padding_mask': jnp.array([[[1], [0], [1]],
                               [[0], [1], [0]]]),
    'mean_pool_rep': jnp.array([[4, 5, 6, 7],
                                [4, -5, 6, -7]]),
    'max_pool_rep': jnp.array([[8, 9, 10, 11],
                               [4, -5, 6, -7]])
}

# test5: three-feature input, no padding mask.
test5 = {
    'x': jnp.array([[[5, 2, -5], [1, -3, 4], [-3, -8, 1]],
                    [[-2, -7, 4], [9, -4, 5], [-1, 2, -3]]]),
    'padding_mask': None,
    'mean_pool_rep': jnp.array([[1, -3, 0],
                                [2, -3, 2]]),
    'max_pool_rep': jnp.array([[5, 2, 4],
                               [9, 2, 5]])
}

# test6: same data as test5 with an explicit all-ones mask.
test6 = {
    'x': jnp.array([[[5, 2, -5], [1, -3, 4], [-3, -8, 1]],
                    [[-2, -7, 4], [9, -4, 5], [-1, 2, -3]]]),
    'padding_mask': jnp.array([[[1], [1], [1]],
                               [[1], [1], [1]]]),
    'mean_pool_rep': jnp.array([[1, -3, 0],
                                [2, -3, 2]]),
    'max_pool_rep': jnp.array([[5, 2, 4],
                               [9, 2, 5]])
}

# test7: leading position masked in the first element; only the first
# position kept in the second.
test7 = {
    'x': jnp.array([[[6, 2, -5], [1, -2, 4], [-3, -8, 0]],
                    [[-2, -7, 4], [10, -4, 5], [-1, 3, -3]]]),
    'padding_mask': jnp.array([[[0], [1], [1]],
                               [[1], [0], [0]]]),
    'mean_pool_rep': jnp.array([[-1, -5, 2],
                                [-2, -7, 4]]),
    'max_pool_rep': jnp.array([[1, -2, 4],
                               [-2, -7, 4]])
}

# test8: same data as test7 with different masks; the expected means are
# non-integer.
test8 = {
    'x': jnp.array([[[6, 2, -5], [1, -2, 4], [-3, -8, 0]],
                    [[-2, -7, 4], [10, -4, 5], [-1, 3, -3]]]),
    'padding_mask': jnp.array([[[1], [0], [1]],
                               [[1], [1], [0]]]),
    'mean_pool_rep': jnp.array([[3/2, -3, -5/2],
                                [4, -11/2, 9/2]]),
    'max_pool_rep': jnp.array([[6, 2, 0],
                               [10, -4, 5]])
}

tests = (test1,
         test2,
         test3,
         test4,
         test5,
         test6,
         test7,
         test8)
110 |
111 |
class TestPooling(parameterized.TestCase):
    """Parameterized tests for the mean_pool and max_pool reduce functions.

    Every test method runs once per entry of `tests`. The entry's keys are
    passed as keyword arguments; the ones a method does not use are absorbed
    by **unused_kwargs.
    """

    @parameterized.parameters(*tests)
    def test_mean_pool(self, x, padding_mask, mean_pool_rep, **unused_kwargs):
        """mean_pool(x, padding_mask) must equal the fixture's mean_pool_rep."""
        self.assertTrue(
            jnp.array_equal(mean_pool(x, padding_mask), mean_pool_rep))

    @parameterized.parameters(*tests)
    def test_max_pool(self, x, padding_mask, max_pool_rep, **unused_kwargs):
        """max_pool(x, padding_mask) must equal the fixture's max_pool_rep."""
        self.assertTrue(
            jnp.array_equal(max_pool(x, padding_mask), max_pool_rep))

    @parameterized.parameters(*tests)
    def test_max_pool_greater_than_mean_pool(self, x, padding_mask,
                                             **unused_kwargs):
        """The max-pooled value is never below the mean-pooled value."""
        # Use the fixture's mask instead of always passing None; otherwise the
        # masked cases repeat the unmasked check on the same data.
        pooled_max = max_pool(x, padding_mask=padding_mask)
        pooled_mean = mean_pool(x, padding_mask=padding_mask)
        self.assertTrue((pooled_max >= pooled_mean).all())
132 |
133 |
# Run the absl test runner when this file is executed as a script.
if __name__ == '__main__':
    absltest.main()
136 |
--------------------------------------------------------------------------------
/contextual_lenses/pfam_utils.py:
--------------------------------------------------------------------------------
1 | """Utils for Pfam family classification experiments."""
2 |
3 | import os
4 |
5 | import jax
6 | import jax.numpy as jnp
7 |
8 | import numpy as np
9 |
10 | import tensorflow as tf
11 |
12 | import pandas as pd
13 |
14 | import matplotlib.pyplot as plt
15 |
16 | import scipy.stats
17 |
18 | import sklearn.metrics as metrics
19 | from sklearn.neighbors import KNeighborsClassifier as knn
20 |
21 | from pkg_resources import resource_filename
22 |
23 | from fs_gcsfs import GCSFS
24 |
25 | from google_research.protein_lm import domains
26 |
27 | from contextual_lenses.train_utils import create_data_iterator
28 |
29 | from contextual_lenses.loss_fns import cross_entropy_loss
30 |
31 |
32 | # Data preprocessing.
33 | # Original code source: https://www.kaggle.com/drewbryant/starter-pfam-seed-random-split.
def read_all_shards(partition, data_dir, bucket_name):
    """Loads every file listed under data_dir/partition into one dataframe.

    Args:
      partition: name of the sub-directory of data_dir to read.
      data_dir: directory inside the bucket that holds the partitions.
      bucket_name: name of the GCS bucket to open.

    Returns:
      The concatenation of the dataframes parsed from each listed file.
    """

    bucket_fs = GCSFS(bucket_name)
    partition_dir = os.path.join(data_dir, partition)

    frames = []
    for shard_name in bucket_fs.listdir(partition_dir):
        shard_path = os.path.join(partition_dir, shard_name)
        with bucket_fs.open(shard_path) as shard_file:
            frames.append(pd.read_csv(shard_file, index_col=None))

    return pd.concat(frames)
44 |
45 |
def mod_family_accession(family_accession):
    """Reduces family accession to everything prior to the first '.'.

    An accession that contains no '.' is returned unchanged rather than
    raising ValueError.
    """

    # partition() yields the whole string as the head when '.' is absent.
    return family_accession.partition('.')[0]
50 |
51 |
# Pfam protein_lm domain.
# Variable-length discrete domain over a protein vocabulary with the
# anomalous-amino-acid, BOS, EOS, PAD and MASK options all enabled and
# length set to 512.
# NOTE(review): create_pfam_df truncates sequences to 512 residues before
# encoding, presumably to match this length -- confirm the two must stay in sync.
PFAM_PROTEIN_DOMAIN = domains.VariableLengthDiscreteDomain(
    vocab=domains.ProteinVocab(include_anomalous_amino_acids=True,
                               include_bos=True,
                               include_eos=True,
                               include_pad=True,
                               include_mask=True),
    length=512)


# Number of categories for one-hot encoding.
# NOTE(review): presumably the size of the vocabulary configured above --
# confirm against domains.ProteinVocab before changing either.
PFAM_NUM_CATEGORIES = 27
64 |
65 |
def residues_to_one_hot_inds(seq):
    """Encodes one residue sequence with PFAM_PROTEIN_DOMAIN.

    The sequence is encoded as a batch of one and the single encoded entry
    is returned.
    """

    encoded_batch = PFAM_PROTEIN_DOMAIN.encode([seq])

    return encoded_batch[0]
72 |
73 |
def get_family_ids():
    """Pfam family ids read from the packaged pfam_family_ids.txt resource.

    Returns:
      The lines of the file as returned by readlines(); line terminators
      are kept.
    """

    ids_path = resource_filename('contextual_lenses.resources',
                                 'pfam_family_ids.txt')
    # Context manager so the file handle is closed once the lines are read.
    with open(ids_path, 'r') as ids_file:
        family_ids = ids_file.readlines()

    return family_ids
82 |
def get_family_id_to_index():
    """Dictionary mapping family id to index.

    The index of a family id is the zero-based position of its line in the
    packaged pfam_family_ids.txt resource; newline characters are removed
    from the ids.
    """

    ids_path = resource_filename('contextual_lenses.resources',
                                 'pfam_family_ids.txt')
    # Context manager so the file handle is closed once the lines are read.
    with open(ids_path, 'r') as ids_file:
        family_id_to_index = {
            family_id.replace('\n', ''): i
            for i, family_id in enumerate(ids_file)
        }

    return family_id_to_index
94 |
95 |
def create_pfam_df(family_accessions,
                   test=False,
                   samples=None,
                   random_state=0,
                   data_partitions_dirpath='random_split/',
                   gcs_bucket='neuralblast_public'):
    """Processes Pfam data into a featurized dataframe with samples many entries per family.

    Args:
      family_accessions: collection of accessions (version suffix removed)
        to keep; rows of any other family are dropped.
      test: if True read the 'test' partition, otherwise the 'train' one.
      samples: if not None, rows are shuffled and at most this many rows are
        kept per family.
      random_state: seed of the shuffle performed before sampling.
      data_partitions_dirpath: directory in the bucket holding the partitions.
      gcs_bucket: name of the GCS bucket to read from.

    Returns:
      Dataframe with the added columns 'mod_family_accession', 'index'
      (family index looked up by family_id) and 'one_hot_inds' (encoded
      sequence, truncated to 512 residues before encoding).
    """

    family_id_to_index = get_family_id_to_index()

    partition = 'test' if test else 'train'
    pfam_df = read_all_shards(partition=partition,
                              data_dir=data_partitions_dirpath,
                              bucket_name=gcs_bucket)

    pfam_df['mod_family_accession'] = pfam_df.family_accession.apply(
        mod_family_accession)
    # Copy the filtered rows so the column assignments below write to an
    # independent frame instead of a slice of the unfiltered one.
    pfam_df = pfam_df[pfam_df.mod_family_accession.isin(
        family_accessions)].copy()
    pfam_df['index'] = pfam_df.family_id.apply(lambda x: family_id_to_index[x])

    pfam_df['one_hot_inds'] = pfam_df.sequence.apply(
        lambda x: residues_to_one_hot_inds(x[:512]))

    if samples is not None:
        # Shuffle all rows, then keep the first `samples` rows of each family.
        pfam_df = pfam_df.sample(frac=1,
                                 replace=False,
                                 random_state=random_state)
        pfam_df = pfam_df.groupby('mod_family_accession').head(
            samples).reset_index()

    return pfam_df
131 |
132 |
def create_pfam_seq_batches(family_accessions,
                            batch_size,
                            test=False,
                            samples=None,
                            epochs=1,
                            drop_remainder=False,
                            buffer_size=None,
                            shuffle_seed=0,
                            sample_random_state=0,
                            data_partitions_dirpath='random_split/',
                            gcs_bucket='neuralblast_public',
                            as_numpy=False):
    """Builds a data iterator over the encoded Pfam sequences.

    The dataframe comes from create_pfam_df and is handed to
    create_data_iterator with add_outputs=False.
    NOTE(review): add_outputs=False presumably makes the iterator yield
    inputs only -- confirm in train_utils.create_data_iterator.
    """

    sequence_df = create_pfam_df(
        family_accessions,
        test=test,
        samples=samples,
        random_state=sample_random_state,
        data_partitions_dirpath=data_partitions_dirpath,
        gcs_bucket=gcs_bucket)

    return create_data_iterator(df=sequence_df,
                                input_col='one_hot_inds',
                                output_col='index',
                                batch_size=batch_size,
                                epochs=epochs,
                                buffer_size=buffer_size,
                                seed=shuffle_seed,
                                drop_remainder=drop_remainder,
                                add_outputs=False,
                                as_numpy=as_numpy)
166 |
167 |
def create_pfam_batches(family_accessions,
                        batch_size,
                        test=False,
                        samples=None,
                        epochs=1,
                        drop_remainder=False,
                        buffer_size=None,
                        shuffle_seed=0,
                        sample_random_state=0,
                        data_partitions_dirpath='random_split/',
                        gcs_bucket='neuralblast_public',
                        as_numpy=True):
    """Builds a data iterator over Pfam (sequence, family index) data.

    Returns:
      A pair (batches, indexes): the iterator produced by
      create_data_iterator, and the values of the dataframe's 'index' column
      in dataframe row order.

    NOTE(review): the returned indexes follow dataframe order, so they only
    line up with the batches if the iterator preserves that order; callers in
    this file pass buffer_size=1, presumably for that reason -- confirm.
    """

    batch_df = create_pfam_df(
        family_accessions,
        test=test,
        samples=samples,
        random_state=sample_random_state,
        data_partitions_dirpath=data_partitions_dirpath,
        gcs_bucket=gcs_bucket)

    family_indexes = batch_df['index'].values

    batch_iterator = create_data_iterator(df=batch_df,
                                          input_col='one_hot_inds',
                                          output_col='index',
                                          batch_size=batch_size,
                                          epochs=epochs,
                                          buffer_size=buffer_size,
                                          seed=shuffle_seed,
                                          drop_remainder=drop_remainder,
                                          as_numpy=as_numpy)

    return batch_iterator, family_indexes
202 |
203 |
204 | # Model evaluation.
def pfam_evaluate(predict_fn,
                  test_family_accessions,
                  title,
                  loss_fn_kwargs,
                  batch_size=512,
                  data_partitions_dirpath='random_split/',
                  gcs_bucket='neuralblast_public'):
    """Computes predicted family ids and measures performance in cross entropy and accuracy.

    Args:
      predict_fn: callable applied to each input batch; its output is passed
        to cross_entropy_loss and arg-maxed over axis 1 for the prediction.
      test_family_accessions: families whose test-partition rows are used.
      title: label stored in the results under 'title'.
      loss_fn_kwargs: extra keyword arguments for cross_entropy_loss.
      batch_size: batch size of the test iterator.
      data_partitions_dirpath: directory in the bucket holding the partitions.
      gcs_bucket: name of the GCS bucket to read from.

    Returns:
      A pair (results, pred_indexes). results holds 'title', 'cross_entropy'
      (the per-batch losses added up over all batches) and 'accuracy';
      pred_indexes is the array of predicted family indexes.
    """

    test_batches, test_indexes = create_pfam_batches(
        family_accessions=test_family_accessions,
        batch_size=batch_size,
        test=True,
        buffer_size=1,
        gcs_bucket=gcs_bucket,
        data_partitions_dirpath=data_partitions_dirpath)

    pred_indexes = []
    cross_entropy = 0.

    for X, Y in test_batches:

        Y_hat = predict_fn(X)

        cross_entropy += cross_entropy_loss(Y, Y_hat, **loss_fn_kwargs)

        # Convert each batch of predictions to NumPy once instead of
        # iterating over the JAX array element by element.
        pred_indexes.extend(np.asarray(jnp.argmax(Y_hat, axis=1)))

    pred_indexes = np.array(pred_indexes)

    acc = metrics.accuracy_score(test_indexes, pred_indexes)

    results = {
        'title': title,
        'cross_entropy': cross_entropy,
        'accuracy': acc,
    }

    return results, pred_indexes
248 |
249 |
def compute_embeddings(encoder, data_batches):
    """Embeds every input in data_batches with encoder and stacks the results.

    Each batch is unpacked as an (inputs, outputs) pair; only the inputs are
    used. The encoder output of each batch is converted to a NumPy array and
    its rows are collected, in batch order, into one array.
    """

    embedded_rows = []
    for X, _ in data_batches:
        embedded_rows.extend(np.array(encoder(X)))

    return np.array(embedded_rows)
262 |
263 |
def pfam_nearest_neighbors_classification(
        encoder,
        family_accessions,
        batch_size=512,
        n_neighbors=1,
        train_samples=None,
        test_samples=None,
        shuffle_seed=0,
        sample_random_state=0,
        data_partitions_dirpath='random_split/',
        gcs_bucket='neuralblast_public'):
    """Fits a KNN classifier on encoder embeddings and scores it on Pfam families.

    Train and test batches are built for the same family_accessions, embedded
    with encoder, and a KNeighborsClassifier fitted on the train embeddings
    predicts the family index of each test embedding.

    Returns:
      A triple (results, knn_predictions, knn_classifier), where results
      holds the accuracy under '<n_neighbors>-nn accuracy' together with
      'train_samples' and 'test_samples'.
    """

    # Settings shared by the train and the test iterator.
    shared_kwargs = dict(family_accessions=family_accessions,
                         batch_size=batch_size,
                         buffer_size=1,
                         shuffle_seed=shuffle_seed,
                         sample_random_state=sample_random_state,
                         data_partitions_dirpath=data_partitions_dirpath,
                         gcs_bucket=gcs_bucket)

    train_batches, train_indexes = create_pfam_batches(samples=train_samples,
                                                       **shared_kwargs)
    test_batches, test_indexes = create_pfam_batches(test=True,
                                                     samples=test_samples,
                                                     **shared_kwargs)

    train_vectors = compute_embeddings(encoder, train_batches)
    test_vectors = compute_embeddings(encoder, test_batches)

    knn_classifier = knn(n_neighbors=n_neighbors)
    knn_classifier.fit(train_vectors, train_indexes)
    knn_predictions = knn_classifier.predict(test_vectors)

    results = {
        str(n_neighbors) + "-nn accuracy":
            metrics.accuracy_score(test_indexes, knn_predictions),
        'train_samples': train_samples,
        'test_samples': test_samples
    }

    return results, knn_predictions, knn_classifier
313 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Protein Embedding Search
2 |
3 | **This is not an officially supported Google product.**
4 |
5 | Protein database search tools such as [BLAST](https://blast.ncbi.nlm.nih.gov/Blast.cgi) are instrumental for research in the life sciences. However, they are slow and are based on surface-level sequence similarity. We are exploring using neural networks to improve the speed and accuracy of finding relevant sequences from these databases.
6 |
7 | More specifically, we are aiming to learn fixed-length protein embeddings using [contextual lenses](https://arxiv.org/pdf/2002.08866.pdf). Generally speaking, a sequence-level protein representation, such as a one-hot encoding, is an array of the form (sequence_length, n) where n is the amino acid embedding dimension. A contextual lens is a (learnable) map from the (sequence_length, n)-array to an (m,)-vector where m is independent of sequence_length. Embeddings are constructed using an encoder function followed by a contextual lens. To learn these embeddings a downstream prediction task is performed using a single dense layer. Gradients are backpropagated through all 3 components of the architecture (encoder, lens, predictor) using the Adam optimizer with variable (potentially zero) learning rates and weight decays per component.
8 |
9 | ### Encoders
10 | - [One-hot](https://github.com/googleinterns/protein-embedding-retrieval/blob/master/contextual_lenses/encoders.py#L21): non-learnable
11 | - [CNN](https://github.com/googleinterns/protein-embedding-retrieval/blob/master/contextual_lenses/encoders.py#L46): learnable
12 | - [Transformer](https://github.com/google-research/google-research/blob/master/protein_lm/models.py#L870): learnable and pretrainable
13 |
14 | ### Lenses
15 | - [Mean/Max-Pool](https://github.com/googleinterns/protein-embedding-retrieval/blob/master/contextual_lenses/contextual_lenses.py#L21): non-learnable
16 | - [Linear-Mean/Max-Pool](https://github.com/googleinterns/protein-embedding-retrieval/blob/master/contextual_lenses/contextual_lenses.py#L46): learnable
17 | - [GatedConvolution](https://github.com/googleinterns/protein-embedding-retrieval/blob/master/contextual_lenses/contextual_lenses.py#L125): learnable and self-attentive
18 |
19 | ## TAPE Protein Engineering Tasks
20 | [TAPE](https://arxiv.org/pdf/1906.08230.pdf) proposes two protein engineering tasks: fluorescence prediction and stability prediction. We implement our lens architectures on these tasks in a [Google Colab notebook](https://github.com/googleinterns/protein-embedding-retrieval/blob/master/cnn_protein_landscapes.ipynb). We find that for the fluorescence task both linear regression on the one-hot encodings and 1-layer convolution compete with and outperform the best pretrained language models in TAPE. Likewise, we find that for the stability task 3-layer convolution competes with and outperforms TAPE's models. See below for our results compared to TAPE's results (bold represents best performance).
21 |
22 | ### Fluorescence
23 | MSE is mean squared error and rho represents Spearman's rank correlation coefficient.
24 | | Model Type | Model | Full Test Set (MSE, rho) | Bright Mode (MSE, rho) | Dark Mode (MSE, rho) |
25 | | ---------- | ----- | :----------------------: | :--------------------: | :------------------: |
26 | | Baseline | Linear Regression | (0.35, **0.69**) | (0.09, **0.68**) | (0.33, **0.05**) |
27 | | Lens Architecture | 1-Layer CNN + MaxPool | (0.26, **0.69**) | (0.09, 0.65) | (0.29, **0.05**) |
28 | | Lens Architecture | 1-Layer CNN + LinearMaxPool | (0.23, **0.69**) | (0.12, 0.66) | (0.28, **0.05**) |
29 | | TAPE | Best of all models | (**0.19**, 0.68) | (**0.07**, 0.63) | (**0.22**, **0.05**)|
30 |
31 |
32 | ### Stability
33 | Accuracy (Acc) is measured using the parent protein as a decision boundary and labeling mutations as beneficial if predicted stability is greater than predicted parent stability and deleterious if the opposite is true. rho represents Spearman's rank correlation coefficient. The letters A and B represent the alpha and beta topologies, respectively.
34 | | Model Type | Model | Full Test Set (rho, Acc) | AAA (rho, Acc) | ABBA (rho, Acc) | BABB (rho, Acc) | BBABB (rho, Acc) |
35 | | ---------- | ----- | :---------------------------: | :-----------------: | :------------------: | :------------------: | :-------------------: |
36 | | Baseline | Linear Regression | (0.49, 0.60) | (0.21, 0.66) | (-0.03, 0.6) | (0.51, 0.64) | (0.38, 0.61) |
37 | | Lens Architecture | 3-Layer CNN + MaxPool | (0.76, 0.75) | (0.69, **0.71**) | (0.37, 0.70) | (0.50, 0.72) | (0.60, 0.68) |
38 | | Lens Architecture | Dilated 3-Layer CNN + MaxPool | (0.75, 0.73) | (0.67, 0.69) | (0.49, 0.69) | (0.61, 0.70) | (0.53, 0.64) |
39 | | Lens Architecture | 3-Layer CNN + LinearMaxPool | (0.71, **0.77**) | (0.59, 0.69) | (**0.52**, 0.77) | (0.55, **0.73**) | (0.60, **0.70**) |
40 | | Lens Architecture | Ensemble (Average) of above CNN models | (**0.79**, **0.77**) | (0.67, **0.71**) | (**0.53**, 0.75) | (0.65, **0.74**) | (0.60, **0.70**) |
41 | | TAPE | Best of all models | (0.73, 0.70) | (**0.72**, 0.70) | (0.48, **0.79**) | (**0.68**, 0.71) | (**0.67**, **0.70**) |
42 |
43 |
44 | ## Downstream Task
45 | The downstream task we use to train embeddings is Pfam family classification. We pick an encoder and a lens and train the architecture to predict a protein's family using only its primary sequence. We train on 10000 families in the data set and measure **Lens Accuracy**: the accuracy achieved on the *test set of train families* by the architecture trained for family prediction on the *train set of train families*. We then take the embeddings from this trained model and use them to do family prediction on 1000 holdout families with KNN (using 1 neighbor). This test allows us to assess the extent of transfer learning by seeing how much the embeddings have learned about the holdout families from the train families. In theory, a perfect model would map all proteins that are members of the same family to a single vector. To test for this we measure **n-Sample Test KNN Accuracy**: The accuracy achieved on the *test set of test families* by a KNN classifier trained on the embeddings (from our architecture) of the *train set of test families* using *at most n samples per family*.
46 |
47 | ### Pretraining
48 | We also measure the effect that pretraining has on the performance of a language model encoder. There has been a great deal of interest in measuring the degree to which pretraining protein language models improves their performance on downstream tasks. TAPE investigates this and proposes baselines. Our results indicate that pretraining offers a substantial boost in performance on the family classification task. We use transformer language models, specifically BERT models similar to the [ProGen model](https://www.biorxiv.org/content/10.1101/2020.03.07.982272v2.full.pdf) and the [models used by FAIR](https://www.biorxiv.org/content/10.1101/622803v2.full.pdf). Our [models](https://github.com/google-research/google-research/tree/master/protein_lm) are implemented in jax/flax and pretrained on the [TrEMBL protein corpus](https://www.uniprot.org/statistics/TrEMBL).
49 |
50 | ## Results
51 | In the table below we show the accuracies achieved using KNN on the model embeddings as well as KNN using BLAST's weighted edit distance. We show a simple 2-layer CNN, 3 different size language models both with and without pretraining, and the [Blundell CNN model](https://www.biorxiv.org/content/10.1101/626507v4.full.pdf). All bolded numbers represent better performance compared to BLAST. The key takeaways are the performance of pretrained language models on 1-sample classification and the substantial performance boost from pretraining said models.
52 |
53 | | Model | 1-Sample Accuracy | 5-Sample Accuracy | 10-Sample Accuracy | 50-Sample Accuracy |
54 | |--------------------------------------------|:-----------------:|:-----------------:|:------------------:|:------------------:|
55 | | BLAST | 0.860750 | 0.978355 | 0.991342 | 0.996392 |
56 | | 2-layer CNN | 0.687815 | 0.870944 | 0.914924 | 0.956741 |
57 | | Small Transformer | 0.769286 | 0.920692 | 0.952415 | 0.974045 |
58 | | Pretrained Small Transformer | **0.873828** | 0.968998 | 0.979813 | 0.992790 |
59 | | Medium Transformer | 0.778659 | 0.921413 | 0.956741 | 0.981255 |
60 | | Pretrained Medium Transformer | **0.863775** | 0.968277 | 0.984859 | 0.994232 |
61 | | Large Transformer | 0.749820 | 0.894737 | 0.937996 | 0.970440 |
62 | | Pretrained Large Transformer | **0.865898** | 0.974045 | 0.984859 | 0.995674 |
63 | | Blundell Lens-Family CNN | **0.877345** | **0.980519** | **0.992063** | 0.993506 |
64 | | Blundell Full-Family CNN** | **0.923521** | **0.984848** | **0.992785** | 0.995671 |
65 | | Blundell Full-Family CNN** w/ Whitening*** | **0.940837** | **0.988456** | **0.996392** | **0.996392** |
66 |
67 | ** The Full-Family Blundell CNN is not performing transfer learning. It was trained on families that appear in the KNN task.
68 |
69 | *** Whitened embeddings are obtained by performing PCA on the embeddings of all Pfam seed sequences and applying the corresponding whitening transformation to the KNN train and test sequences.
70 |
71 | Below we show plots of the top 10 n-Sample Test KNN Accuracies vs. Lens Accuracies for different models and for n = 1, 5, 10, 50. The key takeaways are the noticeable boost pretraining the language models provides, the fact that Lens Accuracy is not a perfect predictor of Test KNN Accuracy, and the independence of performance and transformer size.
72 |
73 | 
74 |
75 | 
76 |
77 | 
78 |
79 | 
80 |
81 | ## Quickstart
82 | To clone this project run
83 | ```
84 | git clone --recurse-submodules https://github.com/googleinterns/protein-embedding-retrieval.git
85 | ```
86 |
87 | Once the project is cloned, the first step is to install [Caliban](https://github.com/google/caliban). We use Caliban for running individual jobs and parallelizing many jobs on GCP (Google Cloud Platform).
88 |
89 | For a simple demo on your machine (recommended only if it is equipped with a GPU) run
90 | ```
91 | caliban run --experiment_config demo.json pfam_experiment.py
92 | ```
93 | The demo takes ~3 hours with an Nvidia Tesla P100 GPU (the Caliban default). By default this will load data from and save data to the 'neuralblast_public' GCS bucket. You can change this by modifying the values of 'load_gcs_bucket' and 'save_gcs_bucket' in demo.json.
94 |
95 | To run on cloud first connect to a GCP project (equipped with GPU resources) by running
96 | ```
97 | gcloud init
98 | ```
99 | If your GCP project is named MY_PROJECT run
100 | ```
101 | export PROJECT_ID=MY_PROJECT
102 | ```
103 | Finally, run
104 | ```
105 | caliban cloud --experiment_config demo.json pfam_experiment.py
106 | ```
107 |
108 |
109 | ## Reproducing our Results
110 | To reproduce our results you first need to connect to a GCP project, ideally one with a large number of GPUs, and clone the project. Then take the [generate_params.py](https://github.com/googleinterns/protein-embedding-retrieval/blob/master/generate_params.py) script, and modify the variable 'save_gcs_bucket' in the call to main() using the GCS bucket you wish to save to (and potentially do the same for the one you want to load from 'load_gcs_bucket'). Run the script to generate the appropriate parameter combinations and run
111 | ```
112 | caliban cloud --experiment_config params_combinations.json pfam_experiment.py
113 | ```
114 |
115 | ## Source Code Headers
116 |
117 | Every file containing source code must include copyright and license
118 | information. This includes any JS/CSS files that you might be serving out to
119 | browsers. (This is to help well-intentioned people avoid accidental copying that
120 | doesn't comply with the license.)
121 |
122 | Apache header:
123 |
124 | Copyright 2020 Google LLC
125 |
126 | Licensed under the Apache License, Version 2.0 (the "License");
127 | you may not use this file except in compliance with the License.
128 | You may obtain a copy of the License at
129 |
130 | https://www.apache.org/licenses/LICENSE-2.0
131 |
132 | Unless required by applicable law or agreed to in writing, software
133 | distributed under the License is distributed on an "AS IS" BASIS,
134 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
135 | See the License for the specific language governing permissions and
136 | limitations under the License.
137 |
--------------------------------------------------------------------------------
/contextual_lenses/test/test_learning.py:
--------------------------------------------------------------------------------
1 | """End-to-end learning tests using MSE loss for contextual_lenses.py."""
2 |
3 |
4 | import jax
5 | import jax.numpy as jnp
6 |
7 | import flax
8 | from flax import nn
9 |
10 | import numpy as np
11 |
12 | from absl.testing import parameterized
13 | from absl.testing import absltest
14 |
15 | from contextual_lenses import mean_pool, max_pool, \
16 | linear_max_pool, linear_mean_pool, gated_conv
17 |
18 | from train_utils import create_optimizer, train_step, \
19 | create_representation_model
20 |
21 | from encoders import one_hot_encoder, cnn_one_hot_encoder, \
22 | one_hot_pos_emb_encoder, cnn_one_hot_pos_emb_encoder
23 |
24 | from loss_fns import mse_loss
25 |
26 |
def generate_random_sequences(batch_size=3, seq_len=12, num_categories=21):
    """Returns a (batch_size, seq_len) array of random category indices.

    The global NumPy seed is reset to 0 on every call, so repeated calls with
    the same arguments yield the same array. Indices are drawn from
    [0, num_categories - 2]: the upper bound passed to randint is exclusive,
    so the last index (num_categories - 1) is never generated.
    """

    np.random.seed(0)
    shape = (batch_size, seq_len)
    random_indices = np.random.randint(0, num_categories - 1, size=shape)

    return jnp.array(random_indices)
34 |
35 |
def generate_random_targets(batch_size=3):
    """Returns a length-batch_size array of standard-normal scalar targets.

    The global NumPy seed is reset to 0 on every call, so the targets are
    the same for every call with a given batch_size.
    """

    np.random.seed(0)
    targets = np.random.normal(size=(batch_size,))

    return jnp.array(targets)
43 |
44 |
def train(model, input_data, output_data, learning_rate=1e-3, epochs=1000):
    """Fits model to training data and returns the final train loss.

    Runs `epochs` full-batch train_step updates with mse_loss, then
    evaluates the trained model on the same inputs.

    Args:
        model: Model to optimize; passed to create_optimizer.
        input_data: Batch of inputs fed to the model at every step.
        output_data: Targets compared against the squeezed predictions.
        learning_rate: Learning rate handed to create_optimizer.
        epochs: Number of train_step updates to apply.

    Returns:
        Mean squared error between the final predictions and output_data.
    """

    # Pass a float here: create_optimizer validates the types of its
    # hyperparameters, and an int 0 does not pass a strict float check.
    optimizer = create_optimizer(model,
                                 learning_rate=learning_rate,
                                 weight_decay=0.0)

    loss_fn_kwargs = {}

    for _ in range(epochs):
        optimizer = train_step(optimizer, input_data, output_data, mse_loss,
                               loss_fn_kwargs)

    # Predictions are squeezed along axis 1 to line up with output_data.
    preds = jnp.squeeze(optimizer.target(input_data), axis=1)

    train_loss = jnp.mean(jnp.square(preds - output_data))

    return train_loss
60 |
61 |
# Test cases:
#
# Every case is a dict of keyword arguments for TestLearning.test_learning
# plus the 'testcase_name' consumed by parameterized.named_parameters.


def _make_case(testcase_name, encoder_fn, encoder_fn_kwargs, reduce_fn,
               reduce_fn_kwargs, learning_rate=1e-3, epochs=100):
    """Assembles one named test-case dict (loss_threshold is always 1e-4)."""
    return {
        'testcase_name': testcase_name,
        'encoder_fn': encoder_fn,
        'encoder_fn_kwargs': encoder_fn_kwargs,
        'reduce_fn': reduce_fn,
        'reduce_fn_kwargs': reduce_fn_kwargs,
        'learning_rate': learning_rate,
        'epochs': epochs,
        'loss_threshold': 1e-4,
    }


def _cnn_kwargs(n_features, kernel_size):
    """Encoder kwargs for a single-layer CNN."""
    return {
        'n_layers': 1,
        'n_features': [n_features],
        'n_kernel_sizes': [kernel_size],
    }


def _pos_emb_kwargs():
    """Encoder kwargs for positional embeddings."""
    return {
        'max_len': 512,
        'posemb_init': nn.initializers.normal(stddev=1e-6),
    }


def _cnn_pos_emb_kwargs(n_features, kernel_size):
    """Encoder kwargs for a single-layer CNN with positional embeddings."""
    return {**_cnn_kwargs(n_features, kernel_size), **_pos_emb_kwargs()}


def _gated_conv_kwargs():
    """Lens kwargs shared by every gated_conv case."""
    return {
        'rep_size': 256,
        'm_layers': 3,
        'm_features': [[512, 512], [512, 512]],
        'm_kernel_sizes': [[12, 12], [10, 10], [8, 8]],
        'conv_rep_size': 256,
    }


# One-hot:
test1 = _make_case('max_pool', one_hot_encoder, {}, max_pool, {},
                   learning_rate=1e-2, epochs=1000)
test2 = _make_case('mean_pool', one_hot_encoder, {}, mean_pool, {},
                   learning_rate=1e-2, epochs=1000)
test3 = _make_case('linear_max_pool', one_hot_encoder, {},
                   linear_max_pool, {'rep_size': 512})
test4 = _make_case('linear_mean_pool', one_hot_encoder, {},
                   linear_mean_pool, {'rep_size': 2048})
test5 = _make_case('gated_conv', one_hot_encoder, {},
                   gated_conv, _gated_conv_kwargs())

# CNN + one-hot:
test6 = _make_case('cnn_max_pool', cnn_one_hot_encoder,
                   _cnn_kwargs(256, 3), max_pool, {})
test7 = _make_case('cnn_mean_pool', cnn_one_hot_encoder,
                   _cnn_kwargs(512, 5), mean_pool, {})
test8 = _make_case('cnn_linear_max_pool', cnn_one_hot_encoder,
                   _cnn_kwargs(32, 3), linear_max_pool, {'rep_size': 256})
test9 = _make_case('cnn_linear_mean_pool', cnn_one_hot_encoder,
                   _cnn_kwargs(512, 5), linear_mean_pool, {'rep_size': 256})
test10 = _make_case('cnn_gated_conv', cnn_one_hot_encoder,
                    _cnn_kwargs(32, 3), gated_conv, _gated_conv_kwargs())

# One-hot + positional embeddings:
test11 = _make_case('pos_emb_max_pool', one_hot_pos_emb_encoder,
                    _pos_emb_kwargs(), max_pool, {},
                    learning_rate=1e-2, epochs=1000)
test12 = _make_case('pos_emb_mean_pool', one_hot_pos_emb_encoder,
                    _pos_emb_kwargs(), mean_pool, {},
                    learning_rate=1e-2, epochs=1000)
test13 = _make_case('pos_emb_linear_max_pool', one_hot_pos_emb_encoder,
                    _pos_emb_kwargs(), linear_max_pool, {'rep_size': 256})
test14 = _make_case('pos_emb_linear_mean_pool', one_hot_pos_emb_encoder,
                    _pos_emb_kwargs(), linear_mean_pool, {'rep_size': 4096})
test15 = _make_case('pos_emb_gated_conv', one_hot_pos_emb_encoder,
                    _pos_emb_kwargs(), gated_conv, _gated_conv_kwargs())

# CNN + one-hot + positional embeddings:
test16 = _make_case('cnn_pos_emb_max_pool', cnn_one_hot_pos_emb_encoder,
                    _cnn_pos_emb_kwargs(256, 3), max_pool, {})
test17 = _make_case('cnn_pos_emb_mean_pool', cnn_one_hot_pos_emb_encoder,
                    _cnn_pos_emb_kwargs(512, 5), mean_pool, {})
test18 = _make_case('cnn_pos_emb_linear_max_pool',
                    cnn_one_hot_pos_emb_encoder,
                    _cnn_pos_emb_kwargs(32, 3),
                    linear_max_pool, {'rep_size': 256})
test19 = _make_case('cnn_pos_emb_linear_mean_pool',
                    cnn_one_hot_pos_emb_encoder,
                    _cnn_pos_emb_kwargs(512, 5),
                    linear_mean_pool, {'rep_size': 256})
test20 = _make_case('cnn_pos_emb_gated_conv', cnn_one_hot_pos_emb_encoder,
                    _cnn_pos_emb_kwargs(32, 3),
                    gated_conv, _gated_conv_kwargs())


tests = (
    test1, test2, test3, test4, test5,
    test6, test7, test8, test9, test10,
    test11, test12, test13, test14, test15,
    test16, test17, test18, test19, test20,
)
453 |
454 |
class TestLearning(parameterized.TestCase):
    """Checks that encoder/lens combinations can fit a tiny synthetic task."""

    @parameterized.named_parameters(*tests)
    def test_learning(self, encoder_fn, encoder_fn_kwargs, reduce_fn,
                      reduce_fn_kwargs, learning_rate=1e-3, epochs=1000,
                      loss_threshold=1e-4):
        """Trains on 3 random sequences and checks the final MSE.

        Builds a representation model from the given encoder and lens, fits
        it to random scalar targets with `train`, and requires the resulting
        train loss to be below loss_threshold.
        """

        input_data = generate_random_sequences(batch_size=3,
                                               seq_len=12,
                                               num_categories=21)
        output_data = generate_random_targets(batch_size=3)

        model = create_representation_model(
            encoder_fn=encoder_fn,
            encoder_fn_kwargs=encoder_fn_kwargs,
            reduce_fn=reduce_fn,
            reduce_fn_kwargs=reduce_fn_kwargs,
            num_categories=21,
            output_features=1)

        train_loss = train(model, input_data, output_data,
                           learning_rate=learning_rate, epochs=epochs)

        # assertLess reports both values on failure, unlike assertTrue(a < b).
        self.assertLess(float(train_loss), loss_threshold)
478 |
479 |
# Run the parameterized learning tests when this file is executed directly.
if __name__ == '__main__':
    absltest.main()
482 |
--------------------------------------------------------------------------------
/contextual_lenses/train_utils.py:
--------------------------------------------------------------------------------
1 | """Train utils
2 |
3 | General tools for instantiating and training models.
4 | """
5 |
6 | import flax
7 | from flax import nn
8 | from flax import optim
9 | from flax.training import checkpoints
10 | from flax.training import common_utils
11 |
12 | import jax
13 | from jax import random
14 | import jax.nn
15 | import jax.numpy as jnp
16 | from jax.config import config
17 | config.enable_omnistaging()
18 |
19 | import tensorflow as tf
20 |
21 | import numpy as np
22 |
23 | import functools
24 |
25 | import copy
26 |
27 | from google_research.protein_lm import models
28 |
29 |
30 | # Data batching.
def create_data_iterator(df,
                         input_col,
                         output_col,
                         batch_size,
                         epochs=1,
                         buffer_size=None,
                         seed=0,
                         drop_remainder=False,
                         add_outputs=True,
                         as_numpy=True):
    """Creates iterator of batches of (inputs) or (inputs, outputs).

    Args:
        df: Table indexed by column name; `len(df)` and `df[col].values` are
            used (presumably a pandas DataFrame - confirm against callers).
        input_col: Name of the column holding the inputs.
        output_col: Name of the column holding the outputs. Only read when
            add_outputs is True.
        batch_size: Number of examples per batch.
        epochs: Number of times the shuffled data is repeated.
        buffer_size: Shuffle buffer size; defaults to len(df).
        seed: Shuffle seed.
        drop_remainder: Whether to drop a final batch smaller than batch_size.
        add_outputs: If True batches are (inputs, outputs) pairs, otherwise
            inputs only.
        as_numpy: If True return `as_numpy_iterator()` of the dataset,
            otherwise return the tf.data.Dataset itself.

    Returns:
        The batched dataset, or its NumPy iterator when as_numpy is True.
    """

    if buffer_size is None:
        buffer_size = len(df)

    dataset = tf.data.Dataset.from_tensor_slices(list(df[input_col].values))

    # The output column is only touched when it is actually requested, so
    # input-only iteration does not require a valid output_col.
    if add_outputs:
        outputs = tf.data.Dataset.from_tensor_slices(df[output_col].values)
        dataset = tf.data.Dataset.zip((dataset, outputs))

    # Shuffle is applied before repeat, with reshuffling on each iteration.
    batches = dataset.shuffle(buffer_size=buffer_size,
                              seed=seed,
                              reshuffle_each_iteration=True)

    batches = batches.repeat(epochs).batch(batch_size=batch_size,
                                           drop_remainder=drop_remainder)

    if as_numpy:
        batches = batches.as_numpy_iterator()

    return batches
69 |
70 |
def path_inclusion_filter_fn(path, param, layer):
    """Reports whether `layer` occurs in `path`.

    `param` is part of the signature but is never inspected; only the
    membership test `layer in path` decides the result.
    """

    del param  # Unused.
    return layer in path
75 |
76 |
def create_optimizer(model, learning_rate, weight_decay, layers=None):
    """Instantiates an Adam optimizer or an Adam multi-optimizer.

    Args:
        model: Model the optimizer is created for.
        learning_rate: A single number when layers is None, otherwise a
            sequence with one learning rate per entry of layers.
        weight_decay: A single number when layers is None, otherwise a
            sequence with one weight decay per entry of layers.
        layers: Optional sequence of layer names. Each gets its own Adam
            optimizer applied to the parameters whose path contains that
            name. Layers with a learning rate that is not positive get no
            optimizer.

    Returns:
        The optimizer created for model.

    Raises:
        AssertionError: If the hyperparameters are not plain numbers (single
            optimizer), or their lengths differ from layers (multi-optimizer).
    """

    if layers is None:
        # Accept ints as well as floats so that e.g. weight_decay=0 is valid.
        assert (
            isinstance(learning_rate, (int, float))
            and isinstance(weight_decay, (int, float))
        ), 'Specify float values for model learning rate and weight decay!'
        optimizer_def = optim.Adam(learning_rate=learning_rate,
                                   weight_decay=weight_decay)

    else:
        assert (
            len(learning_rate) == len(weight_decay) == len(layers)
        ), 'Number of specified learning rates, weight decays, and layers must be equal!'
        optimizers = []
        for lr, wd, layer in zip(learning_rate, weight_decay, layers):
            # A non-positive learning rate leaves the layer without an
            # optimizer.
            if lr > 0:
                opt = optim.Adam(learning_rate=lr, weight_decay=wd)
                filter_fn = functools.partial(path_inclusion_filter_fn,
                                              layer=layer)
                traversal = optim.ModelParamTraversal(filter_fn)
                optimizers.append((traversal, opt))
        optimizer_def = optim.MultiOptimizer(*optimizers)

    return optimizer_def.create(model)
105 |
106 |
@functools.partial(jax.jit, static_argnums=(3, 4))
def train_step(optimizer, X, Y, loss_fn, loss_fn_kwargs):
    """Applies one gradient update to optimizer.target.

    The loss is loss_fn(Y, model(X), **loss_fn_kwargs), differentiated with
    respect to the model. loss_fn and loss_fn_kwargs are static arguments of
    the jitted function.

    Returns:
        The optimizer after applying the gradient.
    """
    def loss_of(model):
        predictions = model(X)
        return loss_fn(Y, predictions, **loss_fn_kwargs)

    gradients = jax.grad(loss_of)(optimizer.target)

    return optimizer.apply_gradient(gradients)
120 |
121 |
def get_p_train_step():
    """Returns train_step wrapped in jax.pmap over the 'batch' axis.

    The loss function and its kwargs (positions 3 and 4) are broadcast as
    static arguments instead of being mapped.
    """

    return jax.pmap(train_step,
                    axis_name='batch',
                    static_broadcasted_argnums=(3, 4))
130 |
131 |
def train(model,
          train_data,
          loss_fn,
          loss_fn_kwargs,
          learning_rate=1e-4,
          weight_decay=0.1,
          layers=None,
          restore_dir=None,
          save_dir=None,
          use_pmap=False):
    """Instantiates optimizer, applies train_step/p_train_step over training data.

    Args:
        model: Model to train.
        train_data: Iterable of (X, Y) batches.
        loss_fn: Loss function called as loss_fn(Y, Y_hat, **loss_fn_kwargs).
        loss_fn_kwargs: Keyword arguments forwarded to loss_fn.
        learning_rate: Learning rate(s) forwarded to create_optimizer.
        weight_decay: Weight decay(s) forwarded to create_optimizer.
        layers: Optional layer names forwarded to create_optimizer.
        restore_dir: If given, the optimizer is restored from a checkpoint in
            this directory before training.
        save_dir: If given, the trained optimizer is checkpointed here.
        use_pmap: If True, the optimizer is replicated, every batch is
            sharded, and the pmapped train step is used.

    Returns:
        The optimizer after training.
    """

    optimizer = create_optimizer(model,
                                 learning_rate=learning_rate,
                                 weight_decay=weight_decay,
                                 layers=layers)

    if restore_dir is not None:
        optimizer = checkpoints.restore_checkpoint(ckpt_dir=restore_dir,
                                                   target=optimizer)

    if use_pmap:
        p_train_step = get_p_train_step()
        optimizer = optimizer.replicate()

        for X, Y in train_data:
            X, Y = common_utils.shard(X), common_utils.shard(Y)
            optimizer = p_train_step(optimizer, X, Y, loss_fn, loss_fn_kwargs)

        optimizer = optimizer.unreplicate()

    else:
        for X, Y in train_data:
            optimizer = train_step(optimizer, X, Y, loss_fn, loss_fn_kwargs)

    if save_dir is not None:
        state = optimizer.state
        # A list of states yields a list of steps, one per sub-state.
        # NOTE(review): confirm save_checkpoint accepts a list as `step`.
        if isinstance(state, list):
            step = [sub_state.step for sub_state in state]
        else:
            step = state.step
        checkpoints.save_checkpoint(ckpt_dir=save_dir,
                                    target=optimizer,
                                    step=step)

    return optimizer
180 |
181 |
def _find_fn_name(fn_names, fn_ind, fn_label):
    """Returns the single name in fn_names that contains fn_ind."""
    matches = [fn_name for fn_name in fn_names if fn_ind in fn_name]
    if not matches:
        raise ValueError('No instance of %s detected among %s.' %
                         (fn_label, fn_names))
    if len(matches) > 1:
        # Report the second match, i.e. the first duplicate encountered.
        raise ValueError('Multiple instances of %s detected. %s' %
                         (fn_label, matches[1]))
    return matches[0]


def load_params(params,
                encoder_fn_params=None,
                reduce_fn_params=None,
                predict_fn_params=None):
    """Updates randomly initialized parameters using loaded parameters.

    The supplied parameter groups are numbered consecutively ('_0', '_1',
    ...) in the order encoder, lens (reduce_fn), predictor, skipping groups
    that are None. Each group replaces the entry of params whose key
    contains its index.

    Args:
        params: Mapping from layer name to that layer's parameters. It is
            deep-copied and never modified.
        encoder_fn_params: Optional replacement encoder parameters.
        reduce_fn_params: Optional replacement lens parameters.
        predict_fn_params: Optional replacement predictor parameters.

    Returns:
        A copy of params with the supplied groups substituted in.

    Raises:
        AssertionError: If more groups are supplied than params has entries.
        ValueError: If a supplied group matches no key of params, or matches
            more than one.
    """

    loaded_params = copy.deepcopy(params)
    fn_names = list(loaded_params.keys())

    candidates = (
        ('encoder_fn', encoder_fn_params),
        ('reduce_fn', reduce_fn_params),
        ('predict_fn', predict_fn_params),
    )
    supplied = [(fn_label, fn_params)
                for fn_label, fn_params in candidates
                if fn_params is not None]

    assert (len(fn_names) >= len(supplied)
            ), 'Model encoder and lens architecture incorrectly specified!'

    for position, (fn_label, fn_params) in enumerate(supplied):
        # Fail loudly on a missing layer; storing under a None key would
        # silently leave the model's real parameters untouched.
        fn_name = _find_fn_name(fn_names, '_%d' % position, fn_label)
        loaded_params[fn_name] = fn_params

    return loaded_params
247 |
248 |
class RepresentationModel(nn.Module):
    """Encoder followed by a lens (reduce_fn) and a dense prediction layer."""

    def apply(self,
              x,
              encoder_fn,
              encoder_fn_kwargs,
              reduce_fn,
              reduce_fn_kwargs,
              num_categories,
              output_features,
              output='prediction',
              use_transformer=False,
              padding_mask=None):
        """Encodes x, reduces it with the lens, and applies a dense layer.

        When padding_mask is None it is derived from x: positions whose index
        is below num_categories - 1 get 1, all others 0, with a trailing axis
        added. Both the embedding and the prediction are always computed;
        `output` ('embedding' or 'prediction') selects which is returned.
        """

        if padding_mask is None:
            below_last_index = jnp.where(x < num_categories - 1, 1, 0)
            padding_mask = jnp.expand_dims(below_last_index, axis=2)

        if use_transformer:
            encoded = encoder_fn(x)
        else:
            encoded = encoder_fn(x,
                                 num_categories=num_categories,
                                 **encoder_fn_kwargs)

        embedding = reduce_fn(encoded,
                              padding_mask=padding_mask,
                              **reduce_fn_kwargs)

        prediction = nn.Dense(embedding,
                              output_features,
                              kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=nn.initializers.normal(stddev=1e-6))

        available_outputs = {
            'embedding': embedding,
            'prediction': prediction,
        }

        return available_outputs[output]
291 |
292 |
def create_representation_model(encoder_fn,
                                encoder_fn_kwargs,
                                reduce_fn,
                                reduce_fn_kwargs,
                                num_categories,
                                output_features,
                                output='prediction',
                                key=random.PRNGKey(0),
                                encoder_fn_params=None,
                                reduce_fn_params=None,
                                predict_fn_params=None):
    """Builds a RepresentationModel with a non-transformer encoder.

    Parameters are initialized by shape from `key`. Any of
    encoder_fn_params, reduce_fn_params and predict_fn_params that is given
    is then substituted in through load_params.
    """

    # The same configuration drives the partial module and initialization.
    module_kwargs = dict(encoder_fn=encoder_fn,
                         encoder_fn_kwargs=encoder_fn_kwargs,
                         reduce_fn=reduce_fn,
                         reduce_fn_kwargs=reduce_fn_kwargs,
                         num_categories=num_categories,
                         output_features=output_features,
                         output=output,
                         use_transformer=False)

    module = RepresentationModel.partial(**module_kwargs)

    _, initial_params = RepresentationModel.init_by_shape(
        key, input_specs=[((1, 1), jnp.float32)], **module_kwargs)

    final_params = load_params(initial_params, encoder_fn_params,
                               reduce_fn_params, predict_fn_params)

    return nn.Model(module, final_params)
333 |
334 |
def create_transformer_representation_model(transformer_kwargs,
                                            reduce_fn,
                                            reduce_fn_kwargs,
                                            num_categories,
                                            output_features,
                                            bidirectional=False,
                                            output='prediction',
                                            key=random.PRNGKey(0),
                                            encoder_fn_params=None,
                                            reduce_fn_params=None,
                                            predict_fn_params=None):
    """Builds a RepresentationModel whose encoder is a Transformer.

    A models.FlaxBERT (bidirectional) or models.FlaxLM is constructed from
    transformer_kwargs. Its unreplicated model's module, with
    output_head='output_emb', is used as the encoder. Parameters are
    initialized by shape from `key`. Any of encoder_fn_params,
    reduce_fn_params and predict_fn_params that is given is then substituted
    in through load_params.
    """

    transformer_cls = models.FlaxBERT if bidirectional else models.FlaxLM
    transformer = transformer_cls(**transformer_kwargs)

    transformer_model = models.jax_utils.unreplicate(
        transformer._optimizer.target)
    transformer_encoder = transformer_model.module.partial(
        output_head='output_emb')

    # The same configuration drives the partial module and initialization.
    module_kwargs = dict(encoder_fn=transformer_encoder,
                         encoder_fn_kwargs={},
                         reduce_fn=reduce_fn,
                         reduce_fn_kwargs=reduce_fn_kwargs,
                         num_categories=num_categories,
                         output_features=output_features,
                         output=output,
                         use_transformer=True)

    module = RepresentationModel.partial(**module_kwargs)

    _, initial_params = RepresentationModel.init_by_shape(
        key, input_specs=[((1, 1), jnp.float32)], **module_kwargs)

    final_params = load_params(initial_params, encoder_fn_params,
                               reduce_fn_params, predict_fn_params)

    return nn.Model(module, final_params)
385 |
386 |
def architecture_to_layers(encoder_fn_name, reduce_fn_name):
    """Lists the names of the trainable layers for an encoder/lens pairing.

    Layer names carry a running index: the encoder (if trainable) takes 0,
    a trainable lens takes the next index, and the final 'Dense' predictor
    takes the one after that.

    Args:
        encoder_fn_name: None or 'transformer', 'one_hot', or 'cnn_one_hot'.
        reduce_fn_name: 'mean_pool', 'max_pool', 'linear_mean_pool',
            'linear_max_pool', or 'gated_conv'.

    Returns:
        A (layers, trainable_encoder) tuple: the list of layer names and
        whether the encoder contributes a trainable layer.

    Raises:
        ValueError: If either name is not one of the values listed above.
    """

    layers = []

    if encoder_fn_name is None or encoder_fn_name == 'transformer':
        layers.append('Transformer_0')
        trainable_encoder = True
    elif encoder_fn_name == 'cnn_one_hot':
        layers.append('CNN_0')
        trainable_encoder = True
    elif encoder_fn_name == 'one_hot':
        trainable_encoder = False
    else:
        raise ValueError('Incorrect encoder name specified.')

    if reduce_fn_name in ('linear_mean_pool', 'linear_max_pool'):
        lens_prefix = 'Dense'
    elif reduce_fn_name == 'gated_conv':
        lens_prefix = 'GatedConv'
    elif reduce_fn_name in ('mean_pool', 'max_pool'):
        lens_prefix = None
    else:
        raise ValueError('Incorrect lens name specified.')

    # The number of layers collected so far is the next running index.
    if lens_prefix is not None:
        layers.append('%s_%d' % (lens_prefix, len(layers)))

    layers.append('Dense_%d' % len(layers))

    return layers, trainable_encoder
431 |
--------------------------------------------------------------------------------
/pfam_experiment.py:
--------------------------------------------------------------------------------
1 | """Lens training + nearest neighbors classification pipeline."""
2 |
3 | import os
4 |
5 | import sys
6 | sys.path.insert(1, 'google_research/')
7 |
8 | import flax
9 | from flax import nn
10 | from flax.training import checkpoints
11 |
12 | import jax
13 | from jax import random
14 | import jax.numpy as jnp
15 | from jax.config import config
16 | config.enable_omnistaging()
17 |
18 | import numpy as np
19 |
20 | import pandas as pd
21 |
22 | import json
23 |
24 | import copy
25 |
26 | import time
27 |
28 | from pkg_resources import resource_filename
29 |
30 | from fs_gcsfs import GCSFS
31 |
32 | from google_research.protein_lm import domains, models
33 |
34 | from contextual_lenses.contextual_lenses import reduce_fn_name_to_fn
35 |
36 | from contextual_lenses.train_utils import create_optimizer, train, \
37 | create_representation_model, create_transformer_representation_model, \
38 | architecture_to_layers
39 |
40 | from contextual_lenses.encoders import encoder_fn_name_to_fn
41 |
42 | from contextual_lenses.loss_fns import cross_entropy_loss
43 |
44 | from contextual_lenses.pfam_utils import get_family_ids, PFAM_NUM_CATEGORIES, \
45 | pfam_evaluate, create_pfam_batches, pfam_nearest_neighbors_classification
46 |
47 | from contextual_lenses.load_transformer import load_transformer_params
48 |
49 | from absl import app, flags
50 |
# Define command-line flags. Every flag below is read in main() and recorded
# in the results datum that is written to GCS.
FLAGS = flags.FLAGS

# Model architecture: encoder and lens (reduce_fn) names plus the names of the
# packaged JSON kwargs files (given without the '.json' suffix).
flags.DEFINE_string(
    'encoder_fn_name', 'cnn_one_hot',
    "Name of encoder_fn to use. Must be 'transformer' if use_transformer is True."
)
# NOTE(review): the packaged encoder_fn_kwargs_resources directory does not
# appear to contain a cnn_kwargs.json (only 1-layer_cnn_kwargs,
# 2-layer_cnn_kwargs and the transformer configs) - confirm this default or
# always pass the flag explicitly.
flags.DEFINE_string('encoder_fn_kwargs_path', 'cnn_kwargs',
                    'Path to encoder_fn_kwargs.')
flags.DEFINE_string('reduce_fn_name', 'linear_max_pool',
                    'Name of reduce_fn to use.')
flags.DEFINE_string('reduce_fn_kwargs_path', 'linear_pool_1024',
                    'Path to reduce_fn_kwargs.')

# Training schedule and batch sizes.
flags.DEFINE_integer('epochs', 10, 'Number of epochs for lens training.')
flags.DEFINE_integer(
    'measurements', 1,
    'Number of times to interrupt lens training loop to take measurements (1 = no interruption).'
)
flags.DEFINE_integer('lens_batch_size', 64, 'Batch size for lens training.')
flags.DEFINE_integer('knn_batch_size', 64,
                     'Batch size for KNN vector computation.')

# Per-component learning rates.
flags.DEFINE_float('encoder_lr', 0.0, 'Encoder learning rate.')
flags.DEFINE_float('lens_lr', 1e-5, 'Lens learning rate.')
flags.DEFINE_float('predictor_lr', 1e-3, 'Predictor learning rate.')

# Per-component weight decays.
flags.DEFINE_float('encoder_wd', 0.0, 'Encoder weight decay.')
flags.DEFINE_float('lens_wd', 0.0, 'Lens weight decay.')
flags.DEFINE_float('predictor_wd', 0.0, 'Predictor weight decay.')

# Data selection: families 1..train_families train the lens; families
# first_test_family..last_test_family (inclusive) are used for KNN testing.
flags.DEFINE_integer('train_families', 10000,
                     'Number of families used to train lens.')
flags.DEFINE_integer('lens_train_samples', 50,
                     'Number of samples used to train lens.')
flags.DEFINE_integer('first_test_family', 15001, 'First family to test on.')
flags.DEFINE_integer('last_test_family', 16000, 'Last family to test on.')

# Random seeds.
flags.DEFINE_integer('lens_shuffle_seed', 0,
                     'Random seed used for lens training data batching.')
flags.DEFINE_integer('lens_sample_random_state', 0,
                     'Random state used for lens training data sampling.')
flags.DEFINE_integer('knn_shuffle_seed', 1,
                     'Random seed used for KNN data batching.')
flags.DEFINE_integer('knn_sample_random_state', 1,
                     'Random state used for KNN data sampling.')
flags.DEFINE_integer('model_random_key', 0,
                     'Random key used for model instantiation.')

# Transformer encoder options.
flags.DEFINE_boolean('use_transformer', False,
                     'Whether or not to use transformer encoder')
flags.DEFINE_boolean('use_bert', False,
                     'Whether or not to use bidirectional transformer.')
flags.DEFINE_string('restore_transformer_dir', None,
                    'Directory to load pretrained transformer from.')

# GCS locations for input data and experiment results.
flags.DEFINE_string('load_gcs_bucket', 'neuralblast_public',
                    'GCS bucket to load from.')
flags.DEFINE_string('data_partitions_dirpath', 'random_split/',
                    'Location of Pfam data in load GCS bucket.')

flags.DEFINE_string('save_gcs_bucket', 'sequin-public',
                    'GCS bucket to save to.')
flags.DEFINE_string('results_save_dir', '',
                    'Directory in save GCS bucket to save to.')

# Checkpoint loading / saving.
flags.DEFINE_boolean('load_model', False,
                     'Whether or not to load a trained model.')
flags.DEFINE_string('load_model_dir', '',
                    'Directory in load GCS bucket to load trained optimizer from.')
flags.DEFINE_integer(
    'load_model_step', 0,
    'Number of steps optimizer to be loaded has been trained for.')

flags.DEFINE_boolean('save_model', False,
                     'Whether or not to save trained model.')
flags.DEFINE_string('save_model_dir', '',
                    'Directory in save GCS bucket to save trained optimizer to.')

# Results are written to <results_save_dir>/<label>.csv.
flags.DEFINE_string('label', '', 'Label used to save experiment results.')
129 |
130 |
def _load_kwargs_resource(resource_dir, kwargs_name):
    """Loads the packaged JSON file <resource_dir>/<kwargs_name>.json."""

    kwargs_path = resource_filename(
        'contextual_lenses.resources',
        os.path.join(resource_dir, kwargs_name + '.json'))

    # Use a context manager so the file handle is closed deterministically.
    with open(kwargs_path) as kwargs_file:
        return json.load(kwargs_file)


def get_model_kwargs(encoder_fn_name, encoder_fn_kwargs_path, reduce_fn_name,
                     reduce_fn_kwargs_path):
    """Determines model components using string names.

    Args:
      encoder_fn_name: Name passed to encoder_fn_name_to_fn.
      encoder_fn_kwargs_path: Name (without '.json') of a file in the packaged
        encoder_fn_kwargs_resources directory.
      reduce_fn_name: Name passed to reduce_fn_name_to_fn.
      reduce_fn_kwargs_path: Name (without '.json') of a file in the packaged
        reduce_fn_kwargs_resources directory.

    Returns:
      Tuple (encoder_fn, encoder_fn_kwargs, reduce_fn, reduce_fn_kwargs,
      layers). The kwargs are the parsed JSON contents; layers comes from
      architecture_to_layers.
    """

    encoder_fn = encoder_fn_name_to_fn(encoder_fn_name)
    encoder_fn_kwargs = _load_kwargs_resource('encoder_fn_kwargs_resources',
                                              encoder_fn_kwargs_path)

    reduce_fn = reduce_fn_name_to_fn(reduce_fn_name)
    reduce_fn_kwargs = _load_kwargs_resource('reduce_fn_kwargs_resources',
                                             reduce_fn_kwargs_path)

    # architecture_to_layers also returns a trainable_encoder flag, which is
    # not part of this function's return value and is discarded here.
    layers, _ = architecture_to_layers(encoder_fn_name, reduce_fn_name)

    return encoder_fn, encoder_fn_kwargs, reduce_fn, reduce_fn_kwargs, layers
155 |
156 |
def create_model(encoder_fn,
                 encoder_fn_kwargs,
                 reduce_fn,
                 reduce_fn_kwargs,
                 layers,
                 output='prediction',
                 use_transformer=False,
                 use_bert=False,
                 restore_transformer_dir=None,
                 encoder_fn_params=None,
                 reduce_fn_params=None,
                 predict_fn_params=None,
                 random_key=0):
    """Builds an encoder --> lens --> predictor representation model.

    With use_transformer=False the model is built by
    create_representation_model from encoder_fn / encoder_fn_kwargs. With
    use_transformer=True it is built by
    create_transformer_representation_model, treating encoder_fn_kwargs as the
    transformer kwargs; the encoder parameters are then encoder_fn_params if
    given, otherwise parameters loaded from restore_transformer_dir if given,
    otherwise None.

    The layers argument is accepted but not used by this function.
    """

    # The predictor has one output feature per Pfam family id.
    num_families = len(get_family_ids())

    # Non-transformer encoders: build directly from the supplied encoder_fn.
    if not use_transformer:
        return create_representation_model(
            encoder_fn=encoder_fn,
            encoder_fn_kwargs=encoder_fn_kwargs,
            reduce_fn=reduce_fn,
            reduce_fn_kwargs=reduce_fn_kwargs,
            num_categories=PFAM_NUM_CATEGORIES,
            output_features=num_families,
            output=output,
            key=random.PRNGKey(random_key),
            encoder_fn_params=encoder_fn_params,
            reduce_fn_params=reduce_fn_params,
            predict_fn_params=predict_fn_params)

    # Model class is only needed when restoring pretrained parameters.
    model_cls = models.FlaxBERT if use_bert else models.FlaxLM

    # Explicit encoder parameters take precedence over a restore directory.
    if encoder_fn_params is not None:
        transformer_params = encoder_fn_params
    elif restore_transformer_dir is not None:
        transformer_params = load_transformer_params(restore_transformer_dir,
                                                     model_cls)
    else:
        transformer_params = None

    return create_transformer_representation_model(
        transformer_kwargs=encoder_fn_kwargs,
        reduce_fn=reduce_fn,
        reduce_fn_kwargs=reduce_fn_kwargs,
        num_categories=PFAM_NUM_CATEGORIES,
        output_features=num_families,
        bidirectional=use_bert,
        output=output,
        key=random.PRNGKey(random_key),
        encoder_fn_params=transformer_params,
        reduce_fn_params=reduce_fn_params,
        predict_fn_params=predict_fn_params)
219 |
220 |
def set_model_parameters(model, params):
    """Overwrites every entry of model.params with a deep copy from params.

    The top-level keys of params must equal those of model.params, otherwise an
    AssertionError is raised. model.params is modified in place and the same
    model object is returned.
    """

    # Copy first so the model never aliases the caller's parameter structure.
    new_params = copy.deepcopy(params)

    keys_match = model.params.keys() == new_params.keys()
    assert keys_match, 'Model parameters do not match!'

    for layer_name in model.params:
        model.params[layer_name] = new_params[layer_name]

    return model
233 |
234 |
def measure_nearest_neighbor_performance(accuracy_label, encoder,
                                         family_accessions, batch_size,
                                         train_samples, shuffle_seed,
                                         sample_random_state):
    """Runs nearest neighbors classification and returns its 1-nn accuracy.

    The data location (data_partitions_dirpath, load_gcs_bucket) is read from
    FLAGS. Returns a single-entry dict {accuracy_label: accuracy} that the
    caller can merge into its results datum.
    """

    knn_outputs = pfam_nearest_neighbors_classification(
        encoder=encoder,
        family_accessions=family_accessions,
        batch_size=batch_size,
        train_samples=train_samples,
        shuffle_seed=shuffle_seed,
        sample_random_state=sample_random_state,
        data_partitions_dirpath=FLAGS.data_partitions_dirpath,
        gcs_bucket=FLAGS.load_gcs_bucket)

    # Only the first returned element (the results mapping) is used.
    knn_results = knn_outputs[0]

    return {accuracy_label: knn_results['1-nn accuracy']}
256 |
257 |
def _family_accessions(first_family, last_family):
    """Returns 'PF%05d' accessions for first_family..last_family inclusive."""

    return [
        'PF%05d' % family_number
        for family_number in range(first_family, last_family + 1)
    ]


def _save_results(gcsfs, datum):
    """Prints datum and writes it as a one-row CSV to the results location."""

    print(datum)
    df = pd.DataFrame([datum])
    with gcsfs.open(os.path.join(FLAGS.results_save_dir, FLAGS.label + '.csv'),
                    'w') as gcs_file:
        df.to_csv(gcs_file, index=False)


def _save_checkpoint(optimizer, step):
    """Saves optimizer as a checkpoint at step in the save GCS bucket."""

    checkpoints.save_checkpoint(ckpt_dir=os.path.join(
        'gs://' + FLAGS.save_gcs_bucket, FLAGS.save_model_dir),
                                target=optimizer,
                                step=step)


def _measure_knn_accuracies(encoder,
                            lens_state,
                            train_family_accessions,
                            test_family_accessions,
                            knn_train_samples_,
                            label_suffix=''):
    """Measures train (1 sample) and test (each sample count) KNN accuracy.

    Returns a dict whose keys are inserted in measurement order: the train
    accuracy first, then one test accuracy per entry of knn_train_samples_.
    """

    accuracies = measure_nearest_neighbor_performance(
        accuracy_label='train_knn_accuracy_' + lens_state +
        '_lens_1_knn_train_samples' + label_suffix,
        encoder=encoder,
        family_accessions=train_family_accessions,
        batch_size=FLAGS.knn_batch_size,
        train_samples=1,
        shuffle_seed=FLAGS.knn_shuffle_seed,
        sample_random_state=FLAGS.knn_sample_random_state)

    for knn_train_samples in knn_train_samples_:
        accuracies.update(
            measure_nearest_neighbor_performance(
                accuracy_label='test_knn_accuracy_' + lens_state + '_lens_' +
                str(knn_train_samples) + '_knn_train_samples' + label_suffix,
                encoder=encoder,
                family_accessions=test_family_accessions,
                batch_size=FLAGS.knn_batch_size,
                train_samples=knn_train_samples,
                shuffle_seed=FLAGS.knn_shuffle_seed,
                sample_random_state=FLAGS.knn_sample_random_state))

    return accuracies


# Train lens and measure performance of lens and nearest neighbors classifier.
def main(_):
    """Runs the lens training + nearest neighbors classification experiment.

    Steps: validate flags, write the flag values to the results CSV, measure
    KNN accuracy with the untrained lens, optionally restore / save a
    checkpoint, then train in FLAGS.measurements rounds, recording lens and KNN
    metrics after each round. The final datum overwrites the results CSV and,
    if requested, the trained optimizer is checkpointed.
    """

    if FLAGS.use_transformer:
        assert (
            FLAGS.encoder_fn_name == 'transformer'
        ), 'encoder_fn_name must be transformer if use_transformer is True!'

    # Guard the modulo below: 0 would raise ZeroDivisionError and a negative
    # value would silently skip training.
    assert FLAGS.measurements > 0, 'Number of measurements must be positive!'
    assert (FLAGS.epochs % FLAGS.measurements == 0
            ), 'Number of measurements must divide number of epochs!'
    measurement_epochs = FLAGS.epochs // FLAGS.measurements

    assert FLAGS.results_save_dir != '', 'Specify results_save_dir!'

    assert FLAGS.label != '', 'Specify label!'

    if FLAGS.load_model:
        assert FLAGS.load_model_dir != '', 'Specify load_model_dir!'
        assert FLAGS.load_model_step > 0, 'Loaded model must have been trained for more than 0 steps.'

    if FLAGS.save_model:
        assert FLAGS.save_model_dir != '', 'Specify save_model_dir!'

    # Record every flag value; measurement results are appended to this dict
    # in order, which fixes the column order of the results CSV.
    datum = {
        'label': FLAGS.label,
        'encoder_fn_name': FLAGS.encoder_fn_name,
        'encoder_fn_kwargs_path': FLAGS.encoder_fn_kwargs_path,
        'reduce_fn_name': FLAGS.reduce_fn_name,
        'reduce_fn_kwargs_path': FLAGS.reduce_fn_kwargs_path,
        'epochs': FLAGS.epochs,
        'measurements': FLAGS.measurements,
        'lens_batch_size': FLAGS.lens_batch_size,
        'knn_batch_size': FLAGS.knn_batch_size,
        'encoder_lr': FLAGS.encoder_lr,
        'lens_lr': FLAGS.lens_lr,
        'predictor_lr': FLAGS.predictor_lr,
        'encoder_wd': FLAGS.encoder_wd,
        'lens_wd': FLAGS.lens_wd,
        'predictor_wd': FLAGS.predictor_wd,
        'train_families': FLAGS.train_families,
        'lens_train_samples': FLAGS.lens_train_samples,
        'first_test_family': FLAGS.first_test_family,
        'last_test_family': FLAGS.last_test_family,
        'lens_shuffle_seed': FLAGS.lens_shuffle_seed,
        'lens_sample_random_state': FLAGS.lens_sample_random_state,
        'knn_shuffle_seed': FLAGS.knn_shuffle_seed,
        'knn_sample_random_state': FLAGS.knn_sample_random_state,
        'model_random_key': FLAGS.model_random_key,
        'use_transformer': FLAGS.use_transformer,
        'use_bert': FLAGS.use_bert,
        'restore_transformer_dir': FLAGS.restore_transformer_dir,
        'load_gcs_bucket': FLAGS.load_gcs_bucket,
        'data_partitions_dirpath': FLAGS.data_partitions_dirpath,
        'save_gcs_bucket': FLAGS.save_gcs_bucket,
        'results_save_dir': FLAGS.results_save_dir,
        'load_model': FLAGS.load_model,
        'load_model_dir': FLAGS.load_model_dir,
        'load_model_step': FLAGS.load_model_step,
        'save_model': FLAGS.save_model,
        'save_model_dir': FLAGS.save_model_dir
    }

    gcsfs = GCSFS(FLAGS.save_gcs_bucket)

    # Write the configuration immediately; the same file is overwritten with
    # the full results at the end of the run.
    _save_results(gcsfs, datum)

    knn_train_samples_ = [1, 5, 10, 50]

    family_ids = get_family_ids()
    num_families = len(family_ids)
    loss_fn_kwargs = {'num_classes': num_families}

    lens_knn_train_family_accessions = _family_accessions(
        1, FLAGS.train_families)
    knn_test_family_accessions = _family_accessions(FLAGS.first_test_family,
                                                    FLAGS.last_test_family)

    encoder_fn, encoder_fn_kwargs, reduce_fn, reduce_fn_kwargs, layers = get_model_kwargs(
        encoder_fn_name=FLAGS.encoder_fn_name,
        encoder_fn_kwargs_path=FLAGS.encoder_fn_kwargs_path,
        reduce_fn_name=FLAGS.reduce_fn_name,
        reduce_fn_kwargs_path=FLAGS.reduce_fn_kwargs_path)

    # Model with output='embedding' is the encoder used for KNN measurements.
    embedding_model = create_model(
        encoder_fn=encoder_fn,
        encoder_fn_kwargs=encoder_fn_kwargs,
        reduce_fn=reduce_fn,
        reduce_fn_kwargs=reduce_fn_kwargs,
        layers=layers,
        output='embedding',
        use_transformer=FLAGS.use_transformer,
        use_bert=FLAGS.use_bert,
        restore_transformer_dir=FLAGS.restore_transformer_dir,
        random_key=FLAGS.model_random_key)

    # Baseline KNN accuracy before any lens training.
    datum.update(
        _measure_knn_accuracies(
            encoder=embedding_model,
            lens_state='untrained',
            train_family_accessions=lens_knn_train_family_accessions,
            test_family_accessions=knn_test_family_accessions,
            knn_train_samples_=knn_train_samples_))

    # Model with output='prediction' is the one that gets trained.
    model = create_model(encoder_fn=encoder_fn,
                         encoder_fn_kwargs=encoder_fn_kwargs,
                         reduce_fn=reduce_fn,
                         reduce_fn_kwargs=reduce_fn_kwargs,
                         layers=layers,
                         output='prediction',
                         use_transformer=FLAGS.use_transformer,
                         use_bert=FLAGS.use_bert,
                         restore_transformer_dir=FLAGS.restore_transformer_dir,
                         random_key=FLAGS.model_random_key)

    optimizer = create_optimizer(
        model=model,
        learning_rate=[FLAGS.encoder_lr, FLAGS.lens_lr, FLAGS.predictor_lr],
        weight_decay=[FLAGS.encoder_wd, FLAGS.lens_wd, FLAGS.predictor_wd],
        layers=layers)

    if FLAGS.load_model:
        optimizer = checkpoints.restore_checkpoint(ckpt_dir=os.path.join(
            'gs://' + FLAGS.load_gcs_bucket, FLAGS.load_model_dir),
                                                   target=optimizer,
                                                   step=FLAGS.load_model_step)

        # Keep the embedding model in sync with the restored parameters.
        trained_params = optimizer.target.params
        embedding_model = set_model_parameters(model=embedding_model,
                                               params=trained_params)

    # Checkpoint the starting state (step = load_model_step) before training.
    if FLAGS.save_model:
        _save_checkpoint(optimizer, FLAGS.load_model_step)

    for i in range(FLAGS.measurements):

        measurement_suffix = '_measurement_' + str(i)

        # A different shuffle seed per round; the sampling state is unchanged.
        train_batches, train_indexes = create_pfam_batches(
            family_accessions=lens_knn_train_family_accessions,
            batch_size=FLAGS.lens_batch_size,
            samples=FLAGS.lens_train_samples,
            epochs=measurement_epochs,
            drop_remainder=True,
            shuffle_seed=FLAGS.lens_shuffle_seed + i,
            sample_random_state=FLAGS.lens_sample_random_state)

        optimizer = train(
            model=optimizer.target,
            train_data=train_batches,
            loss_fn=cross_entropy_loss,
            loss_fn_kwargs=loss_fn_kwargs,
            learning_rate=[
                FLAGS.encoder_lr, FLAGS.lens_lr, FLAGS.predictor_lr
            ],
            weight_decay=[FLAGS.encoder_wd, FLAGS.lens_wd, FLAGS.predictor_wd],
            layers=layers)

        # Predictions are returned as well but only the metrics are recorded.
        results, _ = pfam_evaluate(
            predict_fn=optimizer.target,
            test_family_accessions=lens_knn_train_family_accessions,
            title=None,
            loss_fn_kwargs=loss_fn_kwargs,
            batch_size=FLAGS.lens_batch_size,
            data_partitions_dirpath=FLAGS.data_partitions_dirpath,
            gcs_bucket=FLAGS.load_gcs_bucket)

        datum['lens_accuracy' + measurement_suffix] = results['accuracy']
        datum['lens_cross_entropy' + measurement_suffix] = float(
            results['cross_entropy'])

        # Copy the freshly trained parameters into the embedding model before
        # measuring KNN accuracy.
        trained_params = optimizer.target.params
        embedding_model = set_model_parameters(model=embedding_model,
                                               params=trained_params)

        datum.update(
            _measure_knn_accuracies(
                encoder=embedding_model,
                lens_state='trained',
                train_family_accessions=lens_knn_train_family_accessions,
                test_family_accessions=knn_test_family_accessions,
                knn_train_samples_=knn_train_samples_,
                label_suffix=measurement_suffix))

    _save_results(gcsfs, datum)

    if FLAGS.save_model:
        _save_checkpoint(optimizer, FLAGS.load_model_step + FLAGS.epochs)


if __name__ == '__main__':
    app.run(main)
502 |
--------------------------------------------------------------------------------
/generate_params.py:
--------------------------------------------------------------------------------
1 | """Generate JSON files with hyperparameter combos for caliban."""
2 |
3 | import json
4 |
5 | import itertools
6 |
7 | import os
8 |
9 | import numpy as np
10 |
11 | from frozendict import frozendict
12 |
13 |
def create_params(encoder_lrs,
                  lens_lrs,
                  predictor_lrs,
                  encoder_wds,
                  lens_wds,
                  predictor_wds,
                  reduce_fn_kwargs_paths,
                  lens_train_samples,
                  train_families,
                  epochs,
                  measurements,
                  encoder_fn_name,
                  encoder_fn_kwargs_path,
                  reduce_fn_name,
                  lens_batch_size=64,
                  knn_batch_size=64,
                  use_transformer=False,
                  use_bert=False,
                  restore_transformer_dir=None,
                  model_random_keys=(0,),
                  first_test_family=15000,
                  last_test_family=16000,
                  load_gcs_bucket='neuralblast_public',
                  data_partitions_dirpath='random_split/',
                  save_gcs_bucket='sequin-public',
                  results_save_dir='pfam_experiment_data',
                  lens_shuffle_seed=0,
                  lens_sample_random_state=0,
                  knn_shuffle_seed=1,
                  knn_sample_random_state=1,
                  load_model=False,
                  load_model_dir='',
                  load_model_step=0,
                  save_model=False,
                  save_model_dir=''):
    """Generates parameters from lists of parameters.

    One parameter dict is produced for every element of the Cartesian product
    of the swept iterables: encoder_lrs, lens_lrs, predictor_lrs, encoder_wds,
    lens_wds, predictor_wds, reduce_fn_kwargs_paths, lens_train_samples and
    model_random_keys (last one varying fastest). All other arguments are
    copied unchanged into every dict.

    Returns:
      List of dicts. The 'restore_transformer_dir' key is only present when
      restore_transformer_dir is not None.
    """

    params = []

    sweep = itertools.product(encoder_lrs, lens_lrs, predictor_lrs,
                              encoder_wds, lens_wds, predictor_wds,
                              reduce_fn_kwargs_paths, lens_train_samples,
                              model_random_keys)

    # The per-combination sample count gets its own name so it does not shadow
    # the lens_train_samples iterable argument.
    for (encoder_lr, lens_lr, predictor_lr, encoder_wd, lens_wd, predictor_wd,
         reduce_fn_kwargs_path, lens_train_sample_count,
         model_random_key) in sweep:

        param_dict = {
            'encoder_fn_name': encoder_fn_name,
            'encoder_fn_kwargs_path': encoder_fn_kwargs_path,
            'reduce_fn_name': reduce_fn_name,
            'reduce_fn_kwargs_path': reduce_fn_kwargs_path,
            'epochs': epochs,
            'measurements': measurements,
            'lens_batch_size': lens_batch_size,
            'knn_batch_size': knn_batch_size,
            'encoder_lr': encoder_lr,
            'lens_lr': lens_lr,
            'predictor_lr': predictor_lr,
            'encoder_wd': encoder_wd,
            'lens_wd': lens_wd,
            'predictor_wd': predictor_wd,
            'train_families': train_families,
            'lens_train_samples': lens_train_sample_count,
            'first_test_family': first_test_family,
            'last_test_family': last_test_family,
            'lens_shuffle_seed': lens_shuffle_seed,
            'lens_sample_random_state': lens_sample_random_state,
            'knn_shuffle_seed': knn_shuffle_seed,
            'knn_sample_random_state': knn_sample_random_state,
            'model_random_key': model_random_key,
            'use_transformer': use_transformer,
            'use_bert': use_bert,
            'load_gcs_bucket': load_gcs_bucket,
            'data_partitions_dirpath': data_partitions_dirpath,
            'save_gcs_bucket': save_gcs_bucket,
            'results_save_dir': results_save_dir,
            'load_model': load_model,
            'load_model_dir': load_model_dir,
            'load_model_step': load_model_step,
            'save_model': save_model,
            'save_model_dir': save_model_dir
        }

        # Omit the key entirely when no pretrained transformer is requested.
        if restore_transformer_dir is not None:
            param_dict['restore_transformer_dir'] = restore_transformer_dir

        params.append(param_dict)

    return params
99 |
100 |
101 | # Generate parameters from different sets of parameter combinations.
102 | def main(load_gcs_bucket, save_gcs_bucket):
103 |
104 | params = []
105 |
106 | # 1000 train families sweep
107 | params += create_params(
108 | encoder_lrs=[0.0],
109 | lens_lrs=[1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
110 | predictor_lrs=[1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
111 | encoder_wds=[0.0],
112 | lens_wds=[0.0, 0.1],
113 | predictor_wds=[0.0, 0.1],
114 | reduce_fn_kwargs_paths=['linear_pool_256', 'linear_pool_1024'],
115 | lens_train_samples=[50],
116 | train_families=1000,
117 | epochs=50,
118 | measurements=1,
119 | encoder_fn_name='transformer',
120 | encoder_fn_kwargs_path='medium_transformer_kwargs',
121 | reduce_fn_name='linear_max_pool',
122 | lens_batch_size=64,
123 | knn_batch_size=64,
124 | use_transformer=True,
125 | use_bert=True,
126 | restore_transformer_dir=None,
127 | load_gcs_bucket=load_gcs_bucket,
128 | save_gcs_bucket=save_gcs_bucket)
129 |
130 | params += create_params(
131 | encoder_lrs=[0.0],
132 | lens_lrs=[1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
133 | predictor_lrs=[1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
134 | encoder_wds=[0.0],
135 | lens_wds=[0.0, 0.1],
136 | predictor_wds=[0.0, 0.1],
137 | reduce_fn_kwargs_paths=['linear_pool_256', 'linear_pool_1024'],
138 | lens_train_samples=[50],
139 | train_families=1000,
140 | epochs=50,
141 | measurements=1,
142 | encoder_fn_name='transformer',
143 | encoder_fn_kwargs_path='medium_transformer_kwargs',
144 | reduce_fn_name='linear_max_pool',
145 | lens_batch_size=64,
146 | knn_batch_size=64,
147 | use_transformer=True,
148 | use_bert=True,
149 | restore_transformer_dir=
150 | 'gs://sequin-public/transformer_models/medium_trembl_bert/',
151 | load_gcs_bucket=load_gcs_bucket,
152 | save_gcs_bucket=save_gcs_bucket)
153 |
154 | params += create_params(
155 | encoder_lrs=[1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
156 | lens_lrs=[1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
157 | predictor_lrs=[1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
158 | encoder_wds=[0.0, 0.1],
159 | lens_wds=[0.0, 0.1],
160 | predictor_wds=[0.0, 0.1],
161 | reduce_fn_kwargs_paths=['linear_pool_256', 'linear_pool_1024'],
162 | lens_train_samples=[50],
163 | train_families=1000,
164 | epochs=50,
165 | measurements=1,
166 | encoder_fn_name='cnn_one_hot',
167 | encoder_fn_kwargs_path='1-layer_cnn_kwargs',
168 | reduce_fn_name='linear_max_pool',
169 | lens_batch_size=256,
170 | knn_batch_size=256,
171 | use_transformer=False,
172 | use_bert=False,
173 | restore_transformer_dir=None,
174 | load_gcs_bucket=load_gcs_bucket,
175 | save_gcs_bucket=save_gcs_bucket)
176 |
177 | # 10000 train families sweep
178 | params += create_params(encoder_lrs=[0.0],
179 | lens_lrs=[1e-3, 5e-4, 1e-4, 5e-5],
180 | predictor_lrs=[1e-3, 5e-4, 1e-4, 5e-5],
181 | encoder_wds=[0.0],
182 | lens_wds=[0.0, 0.05, 0.1, 0.2],
183 | predictor_wds=[0.0, 0.05, 0.1, 0.2],
184 | reduce_fn_kwargs_paths=['linear_pool_1024'],
185 | lens_train_samples=[50],
186 | train_families=10000,
187 | epochs=10,
188 | measurements=2,
189 | encoder_fn_name='transformer',
190 | encoder_fn_kwargs_path='medium_transformer_kwargs',
191 | reduce_fn_name='linear_max_pool',
192 | lens_batch_size=64,
193 | knn_batch_size=64,
194 | use_transformer=True,
195 | use_bert=True,
196 | restore_transformer_dir=None,
197 | load_gcs_bucket=load_gcs_bucket,
198 | save_gcs_bucket=save_gcs_bucket)
199 |
200 | params += create_params(
201 | encoder_lrs=[0.0],
202 | lens_lrs=[1e-3, 5e-4, 1e-4, 5e-5],
203 | predictor_lrs=[1e-3, 5e-4, 1e-4, 5e-5],
204 | encoder_wds=[0.0],
205 | lens_wds=[0.0, 0.05, 0.1, 0.2],
206 | predictor_wds=[0.0, 0.05, 0.1, 0.2],
207 | reduce_fn_kwargs_paths=['linear_pool_1024'],
208 | lens_train_samples=[50],
209 | train_families=10000,
210 | epochs=10,
211 | measurements=2,
212 | encoder_fn_name='transformer',
213 | encoder_fn_kwargs_path='medium_transformer_kwargs',
214 | reduce_fn_name='linear_max_pool',
215 | lens_batch_size=64,
216 | knn_batch_size=64,
217 | use_transformer=True,
218 | use_bert=True,
219 | restore_transformer_dir=
220 | 'gs://sequin-public/transformer_models/medium_trembl_bert/',
221 | load_gcs_bucket=load_gcs_bucket,
222 | save_gcs_bucket=save_gcs_bucket)
223 |
224 | params += create_params(encoder_lrs=[1e-3, 1e-4, 1e-5],
225 | lens_lrs=[1e-3, 1e-4, 1e-5],
226 | predictor_lrs=[1e-3, 1e-4, 1e-5],
227 | encoder_wds=[0.0, 0.1, 0.2],
228 | lens_wds=[0.0, 0.1, 0.2],
229 | predictor_wds=[0.0, 0.1, 0.2],
230 | reduce_fn_kwargs_paths=['linear_pool_1024'],
231 | lens_train_samples=[50],
232 | train_families=10000,
233 | epochs=10,
234 | measurements=2,
235 | encoder_fn_name='cnn_one_hot',
236 | encoder_fn_kwargs_path='2-layer_cnn_kwargs',
237 | reduce_fn_name='linear_max_pool',
238 | lens_batch_size=512,
239 | knn_batch_size=512,
240 | use_transformer=False,
241 | use_bert=False,
242 | restore_transformer_dir=None,
243 | load_gcs_bucket=load_gcs_bucket,
244 | save_gcs_bucket=save_gcs_bucket)
245 |
246 | params += create_params(encoder_lrs=[0.0],
247 | lens_lrs=[1e-4, 5e-5, 1e-5],
248 | predictor_lrs=[1e-3, 5e-4, 1e-4],
249 | encoder_wds=[0.0],
250 | lens_wds=[0.05, 0.1, 0.2],
251 | predictor_wds=[0.0, 0.05, 0.1],
252 | reduce_fn_kwargs_paths=['linear_pool_1024'],
253 | lens_train_samples=[50],
254 | train_families=10000,
255 | epochs=10,
256 | measurements=2,
257 | encoder_fn_name='transformer',
258 | encoder_fn_kwargs_path='medium_transformer_kwargs',
259 | reduce_fn_name='linear_max_pool',
260 | lens_batch_size=64,
261 | knn_batch_size=64,
262 | use_transformer=True,
263 | use_bert=True,
264 | restore_transformer_dir=None,
265 | load_gcs_bucket=load_gcs_bucket,
266 | save_gcs_bucket=save_gcs_bucket)
267 |
268 | params += create_params(
269 | encoder_lrs=[0.0],
270 | lens_lrs=[1e-4, 5e-5, 1e-5],
271 | predictor_lrs=[1e-3, 5e-4, 1e-4],
272 | encoder_wds=[0.0],
273 | lens_wds=[0.05, 0.1, 0.2],
274 | predictor_wds=[0.15, 0.2, 0.25],
275 | reduce_fn_kwargs_paths=['linear_pool_1024'],
276 | lens_train_samples=[50],
277 | train_families=10000,
278 | epochs=10,
279 | measurements=2,
280 | encoder_fn_name='transformer',
281 | encoder_fn_kwargs_path='medium_transformer_kwargs',
282 | reduce_fn_name='linear_max_pool',
283 | lens_batch_size=64,
284 | knn_batch_size=64,
285 | use_transformer=True,
286 | use_bert=True,
287 | restore_transformer_dir=
288 | 'gs://sequin-public/transformer_models/medium_trembl_bert/',
289 | load_gcs_bucket=load_gcs_bucket,
290 | save_gcs_bucket=save_gcs_bucket)
291 |
292 | params += create_params(encoder_lrs=[1e-3, 5e-4, 1e-4],
293 | lens_lrs=[1e-4, 5e-5, 1e-5],
294 | predictor_lrs=[1e-3, 5e-4, 1e-4],
295 | encoder_wds=[0.05, 0.1, 0.2],
296 | lens_wds=[0.05, 0.1, 0.2],
297 | predictor_wds=[0.0, 0.05, 0.1],
298 | reduce_fn_kwargs_paths=['linear_pool_1024'],
299 | lens_train_samples=[50],
300 | train_families=10000,
301 | epochs=10,
302 | measurements=2,
303 | encoder_fn_name='cnn_one_hot',
304 | encoder_fn_kwargs_path='2-layer_cnn_kwargs',
305 | reduce_fn_name='linear_max_pool',
306 | lens_batch_size=512,
307 | knn_batch_size=512,
308 | use_transformer=False,
309 | use_bert=False,
310 | restore_transformer_dir=None,
311 | load_gcs_bucket=load_gcs_bucket,
312 | save_gcs_bucket=save_gcs_bucket)
313 |
314 | params += create_params(encoder_lrs=[1e-3, 5e-4, 1e-4, 5e-5],
315 | lens_lrs=[5e-5, 1e-5, 5e-6],
316 | predictor_lrs=[1e-4, 5e-5, 1e-5, 5e-6],
317 | encoder_wds=[0.2, 0.3],
318 | lens_wds=[0.05, 0.1],
319 | predictor_wds=[0.0, 0.05, 0.1],
320 | reduce_fn_kwargs_paths=['linear_pool_1024'],
321 | lens_train_samples=[50],
322 | train_families=10000,
323 | epochs=10,
324 | measurements=2,
325 | encoder_fn_name='cnn_one_hot',
326 | encoder_fn_kwargs_path='2-layer_cnn_kwargs',
327 | reduce_fn_name='linear_max_pool',
328 | lens_batch_size=512,
329 | knn_batch_size=512,
330 | use_transformer=False,
331 | use_bert=False,
332 | restore_transformer_dir=None,
333 | load_gcs_bucket=load_gcs_bucket,
334 | save_gcs_bucket=save_gcs_bucket)
335 |
336 | # 10000 train families medium transformer random keys sweep
337 | params += create_params(encoder_lrs=[0.0],
338 | lens_lrs=[1e-5],
339 | predictor_lrs=[1e-3],
340 | encoder_wds=[0.0],
341 | lens_wds=[0.05],
342 | predictor_wds=[0.0],
343 | reduce_fn_kwargs_paths=['linear_pool_1024'],
344 | lens_train_samples=[50],
345 | train_families=10000,
346 | epochs=10,
347 | measurements=2,
348 | encoder_fn_name='transformer',
349 | encoder_fn_kwargs_path='medium_transformer_kwargs',
350 | reduce_fn_name='linear_max_pool',
351 | lens_batch_size=64,
352 | knn_batch_size=64,
353 | use_transformer=True,
354 | use_bert=True,
355 | load_gcs_bucket=load_gcs_bucket,
356 | save_gcs_bucket=save_gcs_bucket,
357 | model_random_keys=range(10))
358 |
359 | params += create_params(
360 | encoder_lrs=[0.0],
361 | lens_lrs=[1e-5],
362 | predictor_lrs=[1e-3],
363 | encoder_wds=[0.0],
364 | lens_wds=[0.05],
365 | predictor_wds=[0.0],
366 | reduce_fn_kwargs_paths=['linear_pool_1024'],
367 | lens_train_samples=[50],
368 | train_families=10000,
369 | epochs=10,
370 | measurements=2,
371 | encoder_fn_name='transformer',
372 | encoder_fn_kwargs_path='medium_transformer_kwargs',
373 | reduce_fn_name='linear_max_pool',
374 | lens_batch_size=64,
375 | knn_batch_size=64,
376 | use_transformer=True,
377 | use_bert=True,
378 | restore_transformer_dir=
379 | 'gs://sequin-public/transformer_models/medium_trembl_bert/',
380 | load_gcs_bucket=load_gcs_bucket,
381 | save_gcs_bucket=save_gcs_bucket,
382 | model_random_keys=range(10))
383 |
384 | params += create_params(encoder_lrs=[0.0],
385 | lens_lrs=[5e-5],
386 | predictor_lrs=[5e-4],
387 | encoder_wds=[0.0],
388 | lens_wds=[0.2],
389 | predictor_wds=[0.2],
390 | reduce_fn_kwargs_paths=['linear_pool_1024'],
391 | lens_train_samples=[50],
392 | train_families=10000,
393 | epochs=10,
394 | measurements=2,
395 | encoder_fn_name='transformer',
396 | encoder_fn_kwargs_path='medium_transformer_kwargs',
397 | reduce_fn_name='linear_max_pool',
398 | lens_batch_size=64,
399 | knn_batch_size=64,
400 | use_transformer=True,
401 | use_bert=True,
402 | load_gcs_bucket=load_gcs_bucket,
403 | save_gcs_bucket=save_gcs_bucket,
404 | model_random_keys=range(10))
405 |
406 | params += create_params(
407 | encoder_lrs=[0.0],
408 | lens_lrs=[5e-5],
409 | predictor_lrs=[5e-4],
410 | encoder_wds=[0.0],
411 | lens_wds=[0.2],
412 | predictor_wds=[0.2],
413 | reduce_fn_kwargs_paths=['linear_pool_1024'],
414 | lens_train_samples=[50],
415 | train_families=10000,
416 | epochs=10,
417 | measurements=2,
418 | encoder_fn_name='transformer',
419 | encoder_fn_kwargs_path='medium_transformer_kwargs',
420 | reduce_fn_name='linear_max_pool',
421 | lens_batch_size=64,
422 | knn_batch_size=64,
423 | use_transformer=True,
424 | use_bert=True,
425 | restore_transformer_dir=
426 | 'gs://sequin-public/transformer_models/medium_trembl_bert/',
427 | load_gcs_bucket=load_gcs_bucket,
428 | save_gcs_bucket=save_gcs_bucket,
429 | model_random_keys=range(10))
430 |
431 | # 10000 train families small transformer random keys sweep
432 | params += create_params(encoder_lrs=[0.0],
433 | lens_lrs=[1e-3, 1e-4, 1e-5, 1e-6],
434 | predictor_lrs=[1e-3, 1e-4, 1e-5, 1e-6],
435 | encoder_wds=[0.0],
436 | lens_wds=[0.0, 0.1, 0.2],
437 | predictor_wds=[0.0, 0.1, 0.2],
438 | reduce_fn_kwargs_paths=['linear_pool_1024'],
439 | lens_train_samples=[50],
440 | train_families=10000,
441 | epochs=10,
442 | measurements=1,
443 | encoder_fn_name='transformer',
444 | encoder_fn_kwargs_path='small_transformer_kwargs',
445 | reduce_fn_name='linear_max_pool',
446 | lens_batch_size=64,
447 | knn_batch_size=64,
448 | use_transformer=True,
449 | use_bert=True,
450 | load_gcs_bucket=load_gcs_bucket,
451 | save_gcs_bucket=save_gcs_bucket)
452 |
453 | params += create_params(
454 | encoder_lrs=[0.0],
455 | lens_lrs=[1e-3, 1e-4, 1e-5, 1e-6],
456 | predictor_lrs=[1e-3, 1e-4, 1e-5, 1e-6],
457 | encoder_wds=[0.0],
458 | lens_wds=[0.0, 0.1, 0.2],
459 | predictor_wds=[0.0, 0.1, 0.2],
460 | reduce_fn_kwargs_paths=['linear_pool_1024'],
461 | lens_train_samples=[50],
462 | train_families=10000,
463 | epochs=10,
464 | measurements=1,
465 | encoder_fn_name='transformer',
466 | encoder_fn_kwargs_path='small_transformer_kwargs',
467 | reduce_fn_name='linear_max_pool',
468 | lens_batch_size=64,
469 | knn_batch_size=64,
470 | use_transformer=True,
471 | use_bert=True,
472 | restore_transformer_dir=
473 | 'gs://sequin-public/transformer_models/small_trembl_bert/',
474 | load_gcs_bucket=load_gcs_bucket,
475 | save_gcs_bucket=save_gcs_bucket)
476 |
477 | params += create_params(encoder_lrs=[0.0],
478 | lens_lrs=[1e-4],
479 | predictor_lrs=[1e-3],
480 | encoder_wds=[0.0],
481 | lens_wds=[0.1],
482 | predictor_wds=[0.0],
483 | reduce_fn_kwargs_paths=['linear_pool_1024'],
484 | lens_train_samples=[50],
485 | train_families=10000,
486 | epochs=10,
487 | measurements=1,
488 | encoder_fn_name='transformer',
489 | encoder_fn_kwargs_path='small_transformer_kwargs',
490 | reduce_fn_name='linear_max_pool',
491 | lens_batch_size=64,
492 | knn_batch_size=64,
493 | use_transformer=True,
494 | use_bert=True,
495 | load_gcs_bucket=load_gcs_bucket,
496 | save_gcs_bucket=save_gcs_bucket,
497 | model_random_keys=range(10))
498 |
499 | params += create_params(
500 | encoder_lrs=[0.0],
501 | lens_lrs=[1e-4],
502 | predictor_lrs=[1e-3],
503 | encoder_wds=[0.0],
504 | lens_wds=[0.1],
505 | predictor_wds=[0.0],
506 | reduce_fn_kwargs_paths=['linear_pool_1024'],
507 | lens_train_samples=[50],
508 | train_families=10000,
509 | epochs=10,
510 | measurements=1,
511 | encoder_fn_name='transformer',
512 | encoder_fn_kwargs_path='small_transformer_kwargs',
513 | reduce_fn_name='linear_max_pool',
514 | lens_batch_size=64,
515 | knn_batch_size=64,
516 | use_transformer=True,
517 | use_bert=True,
518 | load_gcs_bucket=load_gcs_bucket,
519 | save_gcs_bucket=save_gcs_bucket,
520 | restore_transformer_dir=
521 | 'gs://sequin-public/transformer_models/small_trembl_bert/',
522 | model_random_keys=range(10))
523 |
524 | params += create_params(encoder_lrs=[0.0],
525 | lens_lrs=[1e-3],
526 | predictor_lrs=[1e-3],
527 | encoder_wds=[0.0],
528 | lens_wds=[0.1],
529 | predictor_wds=[0.0],
530 | reduce_fn_kwargs_paths=['linear_pool_1024'],
531 | lens_train_samples=[50],
532 | train_families=10000,
533 | epochs=10,
534 | measurements=1,
535 | encoder_fn_name='transformer',
536 | encoder_fn_kwargs_path='small_transformer_kwargs',
537 | reduce_fn_name='linear_max_pool',
538 | lens_batch_size=64,
539 | knn_batch_size=64,
540 | use_transformer=True,
541 | use_bert=True,
542 | load_gcs_bucket=load_gcs_bucket,
543 | save_gcs_bucket=save_gcs_bucket,
544 | model_random_keys=range(10))
545 |
546 | params += create_params(
547 | encoder_lrs=[0.0],
548 | lens_lrs=[1e-3],
549 | predictor_lrs=[1e-3],
550 | encoder_wds=[0.0],
551 | lens_wds=[0.1],
552 | predictor_wds=[0.0],
553 | reduce_fn_kwargs_paths=['linear_pool_1024'],
554 | lens_train_samples=[50],
555 | train_families=10000,
556 | epochs=10,
557 | measurements=1,
558 | encoder_fn_name='transformer',
559 | encoder_fn_kwargs_path='small_transformer_kwargs',
560 | reduce_fn_name='linear_max_pool',
561 | lens_batch_size=64,
562 | knn_batch_size=64,
563 | use_transformer=True,
564 | use_bert=True,
565 | load_gcs_bucket=load_gcs_bucket,
566 | save_gcs_bucket=save_gcs_bucket,
567 | restore_transformer_dir=
568 | 'gs://sequin-public/transformer_models/small_trembl_bert/',
569 | model_random_keys=range(10))
570 |
571 | params += create_params(encoder_lrs=[0.0],
572 | lens_lrs=[1e-4],
573 | predictor_lrs=[1e-3],
574 | encoder_wds=[0.0],
575 | lens_wds=[0.0],
576 | predictor_wds=[0.0],
577 | reduce_fn_kwargs_paths=['linear_pool_1024'],
578 | lens_train_samples=[50],
579 | train_families=10000,
580 | epochs=10,
581 | measurements=1,
582 | encoder_fn_name='transformer',
583 | encoder_fn_kwargs_path='small_transformer_kwargs',
584 | reduce_fn_name='linear_max_pool',
585 | lens_batch_size=64,
586 | knn_batch_size=64,
587 | use_transformer=True,
588 | use_bert=True,
589 | load_gcs_bucket=load_gcs_bucket,
590 | save_gcs_bucket=save_gcs_bucket,
591 | model_random_keys=range(10))
592 |
593 | params += create_params(
594 | encoder_lrs=[0.0],
595 | lens_lrs=[1e-4],
596 | predictor_lrs=[1e-3],
597 | encoder_wds=[0.0],
598 | lens_wds=[0.0],
599 | predictor_wds=[0.0],
600 | reduce_fn_kwargs_paths=['linear_pool_1024'],
601 | lens_train_samples=[50],
602 | train_families=10000,
603 | epochs=10,
604 | measurements=1,
605 | encoder_fn_name='transformer',
606 | encoder_fn_kwargs_path='small_transformer_kwargs',
607 | reduce_fn_name='linear_max_pool',
608 | lens_batch_size=64,
609 | knn_batch_size=64,
610 | use_transformer=True,
611 | use_bert=True,
612 | load_gcs_bucket=load_gcs_bucket,
613 | save_gcs_bucket=save_gcs_bucket,
614 | restore_transformer_dir=
615 | 'gs://sequin-public/transformer_models/small_trembl_bert/',
616 | model_random_keys=range(10))
617 |
618 | params += create_params(encoder_lrs=[0.0],
619 | lens_lrs=[1e-4],
620 | predictor_lrs=[1e-3],
621 | encoder_wds=[0.0],
622 | lens_wds=[0.2],
623 | predictor_wds=[0.2],
624 | reduce_fn_kwargs_paths=['linear_pool_1024'],
625 | lens_train_samples=[50],
626 | train_families=10000,
627 | epochs=10,
628 | measurements=1,
629 | encoder_fn_name='transformer',
630 | encoder_fn_kwargs_path='small_transformer_kwargs',
631 | reduce_fn_name='linear_max_pool',
632 | lens_batch_size=64,
633 | knn_batch_size=64,
634 | use_transformer=True,
635 | use_bert=True,
636 | load_gcs_bucket=load_gcs_bucket,
637 | save_gcs_bucket=save_gcs_bucket,
638 | model_random_keys=range(10))
639 |
640 | params += create_params(
641 | encoder_lrs=[0.0],
642 | lens_lrs=[1e-4],
643 | predictor_lrs=[1e-3],
644 | encoder_wds=[0.0],
645 | lens_wds=[0.2],
646 | predictor_wds=[0.2],
647 | reduce_fn_kwargs_paths=['linear_pool_1024'],
648 | lens_train_samples=[50],
649 | train_families=10000,
650 | epochs=10,
651 | measurements=1,
652 | encoder_fn_name='transformer',
653 | encoder_fn_kwargs_path='small_transformer_kwargs',
654 | reduce_fn_name='linear_max_pool',
655 | lens_batch_size=64,
656 | knn_batch_size=64,
657 | use_transformer=True,
658 | use_bert=True,
659 | load_gcs_bucket=load_gcs_bucket,
660 | save_gcs_bucket=save_gcs_bucket,
661 | restore_transformer_dir=
662 | 'gs://sequin-public/transformer_models/small_trembl_bert/',
663 | model_random_keys=range(10))
664 |
665 | # 10000 train families large transformer random keys sweep
666 | for i in range(5):
667 | params += create_params(
668 | encoder_lrs=[0.0],
669 | lens_lrs=[1e-5],
670 | predictor_lrs=[1e-3],
671 | encoder_wds=[0.0],
672 | lens_wds=[0.05],
673 | predictor_wds=[0.0],
674 | reduce_fn_kwargs_paths=['linear_pool_1024'],
675 | lens_train_samples=[50],
676 | train_families=10000,
677 | epochs=10,
678 | measurements=1,
679 | encoder_fn_name='transformer',
680 | encoder_fn_kwargs_path='large_transformer_kwargs',
681 | reduce_fn_name='linear_max_pool',
682 | lens_batch_size=8,
683 | knn_batch_size=64,
684 | use_transformer=True,
685 | use_bert=True,
686 | load_gcs_bucket=load_gcs_bucket,
687 | save_gcs_bucket=save_gcs_bucket,
688 | model_random_keys=[i],
689 | save_model=True,
690 | save_model_dir=os.path.join('pfam_experiment_optimizers',
691 | 'large_0_' + str(i)))
692 |
693 | params += create_params(
694 | encoder_lrs=[0.0],
695 | lens_lrs=[1e-5],
696 | predictor_lrs=[1e-3],
697 | encoder_wds=[0.0],
698 | lens_wds=[0.05],
699 | predictor_wds=[0.0],
700 | reduce_fn_kwargs_paths=['linear_pool_1024'],
701 | lens_train_samples=[50],
702 | train_families=10000,
703 | epochs=10,
704 | measurements=1,
705 | encoder_fn_name='transformer',
706 | encoder_fn_kwargs_path='large_transformer_kwargs',
707 | reduce_fn_name='linear_max_pool',
708 | lens_batch_size=8,
709 | knn_batch_size=64,
710 | use_transformer=True,
711 | use_bert=True,
712 | restore_transformer_dir=
713 | 'gs://sequin-public/transformer_models/large_bert/',
714 | load_gcs_bucket=load_gcs_bucket,
715 | save_gcs_bucket=save_gcs_bucket,
716 | model_random_keys=[i],
717 | save_model=True,
718 | save_model_dir=os.path.join('pfam_experiment_optimizers',
719 | 'large_1_' + str(i)))
720 |
721 | params += create_params(
722 | encoder_lrs=[0.0],
723 | lens_lrs=[5e-5],
724 | predictor_lrs=[5e-4],
725 | encoder_wds=[0.0],
726 | lens_wds=[0.2],
727 | predictor_wds=[0.2],
728 | reduce_fn_kwargs_paths=['linear_pool_1024'],
729 | lens_train_samples=[50],
730 | train_families=10000,
731 | epochs=10,
732 | measurements=1,
733 | encoder_fn_name='transformer',
734 | encoder_fn_kwargs_path='large_transformer_kwargs',
735 | reduce_fn_name='linear_max_pool',
736 | lens_batch_size=8,
737 | knn_batch_size=64,
738 | use_transformer=True,
739 | use_bert=True,
740 | load_gcs_bucket=load_gcs_bucket,
741 | save_gcs_bucket=save_gcs_bucket,
742 | model_random_keys=[i],
743 | save_model=True,
744 | save_model_dir=os.path.join('pfam_experiment_optimizers',
745 | 'large_2_' + str(i)))
746 |
747 | params += create_params(
748 | encoder_lrs=[0.0],
749 | lens_lrs=[5e-5],
750 | predictor_lrs=[5e-4],
751 | encoder_wds=[0.0],
752 | lens_wds=[0.2],
753 | predictor_wds=[0.2],
754 | reduce_fn_kwargs_paths=['linear_pool_1024'],
755 | lens_train_samples=[50],
756 | train_families=10000,
757 | epochs=10,
758 | measurements=1,
759 | encoder_fn_name='transformer',
760 | encoder_fn_kwargs_path='large_transformer_kwargs',
761 | reduce_fn_name='linear_max_pool',
762 | lens_batch_size=8,
763 | knn_batch_size=64,
764 | use_transformer=True,
765 | use_bert=True,
766 | restore_transformer_dir=
767 | 'gs://sequin-public/transformer_models/large_bert/',
768 | load_gcs_bucket=load_gcs_bucket,
769 | save_gcs_bucket=save_gcs_bucket,
770 | model_random_keys=[i],
771 | save_model=True,
772 | save_model_dir=os.path.join('pfam_experiment_optimizers',
773 | 'large_3_' + str(i)))
774 |
775 | params += create_params(
776 | encoder_lrs=[0.0],
777 | lens_lrs=[1e-4],
778 | predictor_lrs=[1e-3],
779 | encoder_wds=[0.0],
780 | lens_wds=[0.1],
781 | predictor_wds=[0.0],
782 | reduce_fn_kwargs_paths=['linear_pool_1024'],
783 | lens_train_samples=[50],
784 | train_families=10000,
785 | epochs=10,
786 | measurements=1,
787 | encoder_fn_name='transformer',
788 | encoder_fn_kwargs_path='large_transformer_kwargs',
789 | reduce_fn_name='linear_max_pool',
790 | lens_batch_size=8,
791 | knn_batch_size=64,
792 | use_transformer=True,
793 | use_bert=True,
794 | load_gcs_bucket=load_gcs_bucket,
795 | save_gcs_bucket=save_gcs_bucket,
796 | model_random_keys=[i],
797 | save_model=True,
798 | save_model_dir=os.path.join('pfam_experiment_optimizers',
799 | 'large_4_' + str(i)))
800 |
801 | params += create_params(
802 | encoder_lrs=[0.0],
803 | lens_lrs=[1e-4],
804 | predictor_lrs=[1e-3],
805 | encoder_wds=[0.0],
806 | lens_wds=[0.1],
807 | predictor_wds=[0.0],
808 | reduce_fn_kwargs_paths=['linear_pool_1024'],
809 | lens_train_samples=[50],
810 | train_families=10000,
811 | epochs=10,
812 | measurements=1,
813 | encoder_fn_name='transformer',
814 | encoder_fn_kwargs_path='large_transformer_kwargs',
815 | reduce_fn_name='linear_max_pool',
816 | lens_batch_size=8,
817 | knn_batch_size=64,
818 | use_transformer=True,
819 | use_bert=True,
820 | load_gcs_bucket=load_gcs_bucket,
821 | save_gcs_bucket=save_gcs_bucket,
822 | restore_transformer_dir=
823 | 'gs://sequin-public/transformer_models/large_bert/',
824 | model_random_keys=[i],
825 | save_model=True,
826 | save_model_dir=os.path.join('pfam_experiment_optimizers',
827 | 'large_5_' + str(i)))
828 |
829 | params += create_params(
830 | encoder_lrs=[0.0],
831 | lens_lrs=[1e-3],
832 | predictor_lrs=[1e-3],
833 | encoder_wds=[0.0],
834 | lens_wds=[0.1],
835 | predictor_wds=[0.0],
836 | reduce_fn_kwargs_paths=['linear_pool_1024'],
837 | lens_train_samples=[50],
838 | train_families=10000,
839 | epochs=10,
840 | measurements=1,
841 | encoder_fn_name='transformer',
842 | encoder_fn_kwargs_path='large_transformer_kwargs',
843 | reduce_fn_name='linear_max_pool',
844 | lens_batch_size=8,
845 | knn_batch_size=64,
846 | use_transformer=True,
847 | use_bert=True,
848 | load_gcs_bucket=load_gcs_bucket,
849 | save_gcs_bucket=save_gcs_bucket,
850 | model_random_keys=[i],
851 | save_model=True,
852 | save_model_dir=os.path.join('pfam_experiment_optimizers',
853 | 'large_6_' + str(i)))
854 |
855 | params += create_params(
856 | encoder_lrs=[0.0],
857 | lens_lrs=[1e-3],
858 | predictor_lrs=[1e-3],
859 | encoder_wds=[0.0],
860 | lens_wds=[0.1],
861 | predictor_wds=[0.0],
862 | reduce_fn_kwargs_paths=['linear_pool_1024'],
863 | lens_train_samples=[50],
864 | train_families=10000,
865 | epochs=10,
866 | measurements=1,
867 | encoder_fn_name='transformer',
868 | encoder_fn_kwargs_path='large_transformer_kwargs',
869 | reduce_fn_name='linear_max_pool',
870 | lens_batch_size=8,
871 | knn_batch_size=64,
872 | use_transformer=True,
873 | use_bert=True,
874 | load_gcs_bucket=load_gcs_bucket,
875 | save_gcs_bucket=save_gcs_bucket,
876 | restore_transformer_dir=
877 | 'gs://sequin-public/transformer_models/large_bert/',
878 | model_random_keys=[i],
879 | save_model=True,
880 | save_model_dir=os.path.join('pfam_experiment_optimizers',
881 | 'large_7_' + str(i)))
882 |
883 |
884 | # 10000 train families random key sweep and save best models
885 | for i in range(7):
886 | params += create_params(
887 | encoder_lrs=[0.0],
888 | lens_lrs=[1e-5],
889 | predictor_lrs=[1e-3],
890 | encoder_wds=[0.0],
891 | lens_wds=[0.05],
892 | predictor_wds=[0.0],
893 | reduce_fn_kwargs_paths=['linear_pool_1024'],
894 | lens_train_samples=[50],
895 | train_families=10000,
896 | epochs=10,
897 | measurements=2,
898 | encoder_fn_name='transformer',
899 | encoder_fn_kwargs_path='medium_transformer_kwargs',
900 | reduce_fn_name='linear_max_pool',
901 | lens_batch_size=64,
902 | knn_batch_size=64,
903 | use_transformer=True,
904 | use_bert=True,
905 | load_gcs_bucket=load_gcs_bucket,
906 | save_gcs_bucket=save_gcs_bucket,
907 | model_random_keys=[i],
908 | save_model=True,
909 | save_model_dir=os.path.join('pfam_experiment_optimizers',
910 | 'medium_' + str(i)))
911 |
912 | params += create_params(
913 | encoder_lrs=[0.0],
914 | lens_lrs=[5e-5],
915 | predictor_lrs=[1e-3],
916 | encoder_wds=[0.0],
917 | lens_wds=[0.2],
918 | predictor_wds=[0.25],
919 | reduce_fn_kwargs_paths=['linear_pool_1024'],
920 | lens_train_samples=[50],
921 | train_families=10000,
922 | epochs=10,
923 | measurements=2,
924 | encoder_fn_name='transformer',
925 | encoder_fn_kwargs_path='medium_transformer_kwargs',
926 | reduce_fn_name='linear_max_pool',
927 | lens_batch_size=64,
928 | knn_batch_size=64,
929 | use_transformer=True,
930 | use_bert=True,
931 | restore_transformer_dir=
932 | 'gs://sequin-public/transformer_models/medium_trembl_bert/',
933 | load_gcs_bucket=load_gcs_bucket,
934 | save_gcs_bucket=save_gcs_bucket,
935 | model_random_keys=[i],
936 | save_model=True,
937 | save_model_dir=os.path.join('pfam_experiment_optimizers',
938 | 'medium_pt_' + str(i)))
939 |
940 | params += create_params(
941 | encoder_lrs=[0.0],
942 | lens_lrs=[1e-3],
943 | predictor_lrs=[1e-3],
944 | encoder_wds=[0.0],
945 | lens_wds=[0.1],
946 | predictor_wds=[0.0],
947 | reduce_fn_kwargs_paths=['linear_pool_1024'],
948 | lens_train_samples=[50],
949 | train_families=10000,
950 | epochs=10,
951 | measurements=1,
952 | encoder_fn_name='transformer',
953 | encoder_fn_kwargs_path='small_transformer_kwargs',
954 | reduce_fn_name='linear_max_pool',
955 | lens_batch_size=64,
956 | knn_batch_size=64,
957 | use_transformer=True,
958 | use_bert=True,
959 | load_gcs_bucket=load_gcs_bucket,
960 | save_gcs_bucket=save_gcs_bucket,
961 | model_random_keys=[i],
962 | save_model=True,
963 | save_model_dir=os.path.join('pfam_experiment_optimizers',
964 | 'small_' + str(i)))
965 |
966 | params += create_params(
967 | encoder_lrs=[0.0],
968 | lens_lrs=[1e-3],
969 | predictor_lrs=[1e-3],
970 | encoder_wds=[0.0],
971 | lens_wds=[0.2],
972 | predictor_wds=[0.2],
973 | reduce_fn_kwargs_paths=['linear_pool_1024'],
974 | lens_train_samples=[50],
975 | train_families=10000,
976 | epochs=10,
977 | measurements=1,
978 | encoder_fn_name='transformer',
979 | encoder_fn_kwargs_path='small_transformer_kwargs',
980 | reduce_fn_name='linear_max_pool',
981 | lens_batch_size=64,
982 | knn_batch_size=64,
983 | use_transformer=True,
984 | use_bert=True,
985 | restore_transformer_dir=
986 | 'gs://sequin-public/transformer_models/small_trembl_bert/',
987 | load_gcs_bucket=load_gcs_bucket,
988 | save_gcs_bucket=save_gcs_bucket,
989 | model_random_keys=[i],
990 | save_model=True,
991 | save_model_dir=os.path.join('pfam_experiment_optimizers',
992 | 'small_pt_' + str(i)))
993 |
994 | params += create_params(
995 | encoder_lrs=[1e-3],
996 | lens_lrs=[5e-5],
997 | predictor_lrs=[5e-5],
998 | encoder_wds=[0.3],
999 | lens_wds=[0.05],
1000 | predictor_wds=[0.0],
1001 | reduce_fn_kwargs_paths=['linear_pool_1024'],
1002 | lens_train_samples=[50],
1003 | train_families=10000,
1004 | epochs=10,
1005 | measurements=2,
1006 | encoder_fn_name='cnn_one_hot',
1007 | encoder_fn_kwargs_path='2-layer_cnn_kwargs',
1008 | reduce_fn_name='linear_max_pool',
1009 | lens_batch_size=512,
1010 | knn_batch_size=512,
1011 | use_transformer=False,
1012 | use_bert=False,
1013 | load_gcs_bucket=load_gcs_bucket,
1014 | save_gcs_bucket=save_gcs_bucket,
1015 | model_random_keys=[i],
1016 | save_model=True,
1017 | save_model_dir=os.path.join('pfam_experiment_optimizers',
1018 | 'cnn_' + str(i)))
1019 |
1020 | for i in range(7):
1021 | params += create_params(
1022 | encoder_lrs=[0.0],
1023 | lens_lrs=[1e-4],
1024 | predictor_lrs=[1e-3],
1025 | encoder_wds=[0.0],
1026 | lens_wds=[0.1],
1027 | predictor_wds=[0.0],
1028 | reduce_fn_kwargs_paths=['linear_pool_1024'],
1029 | lens_train_samples=[50],
1030 | train_families=10000,
1031 | epochs=10,
1032 | measurements=1,
1033 | encoder_fn_name='transformer',
1034 | encoder_fn_kwargs_path='small_transformer_kwargs',
1035 | reduce_fn_name='linear_max_pool',
1036 | lens_batch_size=64,
1037 | knn_batch_size=64,
1038 | use_transformer=True,
1039 | use_bert=True,
1040 | load_gcs_bucket=load_gcs_bucket,
1041 | save_gcs_bucket=save_gcs_bucket,
1042 | model_random_keys=[i],
1043 | save_model=True,
1044 | save_model_dir=os.path.join('pfam_experiment_optimizers',
1045 | 'small_' + str(i+7)))
1046 |
1047 | for i in range(10):
1048 | params += create_params(
1049 | encoder_lrs=[0.0],
1050 | lens_lrs=[1e-4],
1051 | predictor_lrs=[1e-3],
1052 | encoder_wds=[0.0],
1053 | lens_wds=[0.2],
1054 | predictor_wds=[0.2],
1055 | reduce_fn_kwargs_paths=['linear_pool_1024'],
1056 | lens_train_samples=[50],
1057 | train_families=10000,
1058 | epochs=10,
1059 | measurements=1,
1060 | encoder_fn_name='transformer',
1061 | encoder_fn_kwargs_path='small_transformer_kwargs',
1062 | reduce_fn_name='linear_max_pool',
1063 | lens_batch_size=64,
1064 | knn_batch_size=64,
1065 | use_transformer=True,
1066 | use_bert=True,
1067 | restore_transformer_dir=
1068 | 'gs://sequin-public/transformer_models/small_trembl_bert/',
1069 | load_gcs_bucket=load_gcs_bucket,
1070 | save_gcs_bucket=save_gcs_bucket,
1071 | model_random_keys=[i],
1072 | save_model=True,
1073 | save_model_dir=os.path.join('pfam_experiment_optimizers',
1074 | 'small_' + str(i+14)))
1075 |
1076 | params += create_params(
1077 | encoder_lrs=[0.0],
1078 | lens_lrs=[1e-5],
1079 | predictor_lrs=[1e-3],
1080 | encoder_wds=[0.0],
1081 | lens_wds=[0.05],
1082 | predictor_wds=[0.0],
1083 | reduce_fn_kwargs_paths=['linear_pool_1024'],
1084 | lens_train_samples=[50],
1085 | train_families=10000,
1086 | epochs=10,
1087 | measurements=1,
1088 | encoder_fn_name='transformer',
1089 | encoder_fn_kwargs_path='medium_transformer_kwargs',
1090 | reduce_fn_name='linear_max_pool',
1091 | lens_batch_size=64,
1092 | knn_batch_size=64,
1093 | use_transformer=True,
1094 | use_bert=True,
1095 | load_gcs_bucket=load_gcs_bucket,
1096 | save_gcs_bucket=save_gcs_bucket,
1097 | model_random_keys=[i],
1098 | save_model=True,
1099 | save_model_dir=os.path.join('pfam_experiment_optimizers',
1100 | 'medium_' + str(i+7)))
1101 |
1102 | params += create_params(
1103 | encoder_lrs=[0.0],
1104 | lens_lrs=[5e-5],
1105 | predictor_lrs=[1e-3],
1106 | encoder_wds=[0.0],
1107 | lens_wds=[0.2],
1108 | predictor_wds=[0.25],
1109 | reduce_fn_kwargs_paths=['linear_pool_1024'],
1110 | lens_train_samples=[50],
1111 | train_families=10000,
1112 | epochs=10,
1113 | measurements=1,
1114 | encoder_fn_name='transformer',
1115 | encoder_fn_kwargs_path='medium_transformer_kwargs',
1116 | reduce_fn_name='linear_max_pool',
1117 | lens_batch_size=64,
1118 | knn_batch_size=64,
1119 | use_transformer=True,
1120 | use_bert=True,
1121 | restore_transformer_dir=
1122 | 'gs://sequin-public/transformer_models/medium_trembl_bert/',
1123 | load_gcs_bucket=load_gcs_bucket,
1124 | save_gcs_bucket=save_gcs_bucket,
1125 | model_random_keys=[i],
1126 | save_model=True,
1127 | save_model_dir=os.path.join('pfam_experiment_optimizers',
1128 | 'medium_pt_' + str(i+7)))
1129 |
1130 |
1131 | frozen_param_dict_to_label = {}
1132 | label = 0
1133 | for param_dict in params:
1134 | if frozendict(param_dict) not in frozen_param_dict_to_label.keys():
1135 | frozen_param_dict_to_label[frozendict(param_dict)] = label
1136 | label += 1
1137 |
1138 | unique_params = []
1139 | for frozen_param_dict in frozen_param_dict_to_label.keys():
1140 | param_dict = dict(frozen_param_dict)
1141 | param_dict.update(
1142 | {'label': frozen_param_dict_to_label[frozen_param_dict]})
1143 | unique_params.append(param_dict)
1144 | unique_params = sorted(unique_params, key=lambda x: x['label'])
1145 |
1146 | def transform_label(param_dict):  # Rewrite param_dict['label'] as a fixed-width string, in place.
1147 | param_dict['label'] = '%08d' % param_dict['label']  # zero-pad to 8 digits, e.g. 7 -> '00000007'
1148 | return param_dict  # returns the same dict object it was given (mutated), not a copy
1149 |
1150 | unique_params = [
1151 | transform_label(param_dict) for param_dict in unique_params
1152 | ]
1153 |
1154 | with open('params_combinations.json', 'w') as f:
1155 | json.dump(unique_params, f)
1156 |
1157 | label_to_params = {}
1158 | for param_dict in unique_params:
1159 | label_to_params[param_dict['label']] = param_dict
1160 |
1161 | with open('label_to_params.json', 'w') as f:
1162 | json.dump(label_to_params, f)
1163 |
1164 |
1165 | if __name__ == '__main__':  # Call main() only when this file is executed as a script, not on import.
1166 | main(load_gcs_bucket='neuralblast_public',  # forwarded as load_gcs_bucket (presumably a GCS bucket name; confirm)
1167 | save_gcs_bucket='sequin-public')  # forwarded as save_gcs_bucket (presumably a GCS bucket name; confirm)
1168 |
--------------------------------------------------------------------------------