├── .gitignore ├── CONTRIBUTING.md ├── setup.py ├── blast ├── convert_to_fasta.py └── convert_blast.py ├── common └── jsonl_utils.py ├── eval ├── eval_auc.py ├── eval_utils.py ├── eval_micro_f1.py ├── eval_protein_f1.py ├── eval_pfam.py └── eval_pfam_utils.py ├── data ├── convert_clean.py ├── convert_pfam.py ├── convert_proteinfer.py └── convert_deepfri.py ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution; 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to https://cla.developers.google.com/ to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 
21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 26 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Setup for ProtEx.""" 17 | 18 | import setuptools 19 | 20 | REQUIRED_PACKAGES = ["absl-py", "tensorflow", "numpy", "scikit-learn"] 21 | 22 | 23 | setuptools.setup( 24 | name="protex", 25 | description="Code related to protein function prediction with ProtEx.", 26 | packages=setuptools.find_packages(), 27 | install_requires=REQUIRED_PACKAGES, 28 | license="Apache 2.0", 29 | ) 30 | -------------------------------------------------------------------------------- /blast/convert_to_fasta.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Convert from examples jsonl to fasta format. 17 | 18 | Example of protein in fasta file format: 19 | 20 | >accession="Q0WJ82" 21 | MEKTQSVFIRFIVNGSLVKQILIGLVAGIVLALVST... 22 | """ 23 | 24 | from absl import app 25 | from absl import flags 26 | import tensorflow as tf 27 | 28 | from common import jsonl_utils 29 | 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | 34 | flags.DEFINE_string("input", "", "Path to examples.") 35 | 36 | flags.DEFINE_string("output", "", "Fasta output file.") 37 | 38 | 39 | def main(unused_argv): 40 | examples = jsonl_utils.read(FLAGS.input) 41 | with tf.io.gfile.GFile(FLAGS.output, "w") as fp: 42 | for example in examples: 43 | fp.write(f'>accession="{example["accession"]}"\n') 44 | fp.write(f'{example["sequence"]}\n') 45 | 46 | 47 | if __name__ == "__main__": 48 | app.run(main) 49 | -------------------------------------------------------------------------------- /common/jsonl_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utilities for reading and writing jsonl files.""" 17 | 18 | import json 19 | 20 | from tensorflow.io import gfile 21 | 22 | 23 | def read(filepath, verbose=True): 24 | """Read jsonl file to a List of Dicts.""" 25 | data = [] 26 | with gfile.GFile(filepath, "r") as jsonl_file: 27 | for idx, line in enumerate(jsonl_file): 28 | if verbose and idx % 1000 == 0: 29 | # Print the index every 1000 lines. 30 | print("Processing line %s." % idx) 31 | try: 32 | data.append(json.loads(line)) 33 | except json.JSONDecodeError as e: 34 | print("Failed to parse line: `%s`" % line) 35 | raise e 36 | if verbose: 37 | print("Loaded %s lines from %s." % (len(data), filepath)) 38 | return data 39 | 40 | 41 | def write(filepath, rows): 42 | """Write a List of Dicts to jsonl file.""" 43 | with gfile.GFile(filepath, "w") as jsonl_file: 44 | for row in rows: 45 | line = "%s\n" % json.dumps(row) 46 | jsonl_file.write(line) 47 | print("Wrote %s lines to %s." % (len(rows), filepath)) 48 | -------------------------------------------------------------------------------- /eval/eval_auc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Computes and prints weighted AUC metric. 17 | 18 | This metric was used in "Enzyme Function Prediction using Contrastive Learning". 19 | """ 20 | 21 | from absl import app 22 | from absl import flags 23 | import sklearn.metrics 24 | 25 | from common import jsonl_utils 26 | from eval import eval_utils 27 | 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | 32 | flags.DEFINE_string("dataset", "", "Path to jsonl dataset file.") 33 | 34 | flags.DEFINE_string("predictions", "", "Path to jsonl predictions file.") 35 | 36 | 37 | def get_test_labels(dataset): 38 | all_labels = set() 39 | for row in dataset: 40 | for label in row["labels"]: 41 | all_labels.add(label) 42 | return all_labels 43 | 44 | 45 | def get_auc(true_labels, pred_scores): 46 | """Return AUC.""" 47 | return sklearn.metrics.roc_auc_score( 48 | true_labels, pred_scores, average="weighted" 49 | ) 50 | 51 | 52 | def main(unused_argv): 53 | predictions = jsonl_utils.read(FLAGS.predictions) 54 | dataset = jsonl_utils.read(FLAGS.dataset) 55 | # Only labels occurring in the test set are considered for this metric. 
56 | all_labels = get_test_labels(dataset) 57 | true_labels, pred_scores = eval_utils.preprocess_preds( 58 | dataset, predictions, all_labels 59 | ) 60 | auc = get_auc(true_labels, pred_scores) 61 | print(f"auc: {auc}") 62 | 63 | 64 | if __name__ == "__main__": 65 | app.run(main) 66 | -------------------------------------------------------------------------------- /data/convert_clean.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Converts CLEAN files to common format. 
17 | 18 | This dataset was proposed in the paper: 19 | "Enzyme Function Prediction using Contrastive Learning" 20 | https://www.science.org/doi/10.1126/science.adf2465 21 | 22 | The dataset files are available here: 23 | https://github.com/tttianhao/CLEAN/tree/main/app/data 24 | """ 25 | 26 | import csv 27 | 28 | from absl import app 29 | from absl import flags 30 | import tensorflow as tf 31 | 32 | from common import jsonl_utils 33 | 34 | 35 | FLAGS = flags.FLAGS 36 | 37 | 38 | flags.DEFINE_string("input", "", "Path to input file.") 39 | 40 | flags.DEFINE_string("output", "", "Output location for jsonl file.") 41 | 42 | 43 | def load_tsv(filepath): 44 | dicts = [] 45 | with tf.io.gfile.GFile(filepath, "r") as f: 46 | reader = csv.DictReader(f, delimiter="\t") 47 | for row_dict in reader: 48 | dicts.append(row_dict) 49 | return dicts 50 | 51 | 52 | def convert_to_example(row_dict): 53 | return { 54 | "accession": row_dict["Entry"], 55 | "sequence": row_dict["Sequence"], 56 | "labels": row_dict["EC number"].split(";"), 57 | } 58 | 59 | 60 | def main(unused_argv): 61 | rows = load_tsv(FLAGS.input) 62 | print(f"Loaded {len(rows)} rows.") 63 | examples = [convert_to_example(row_dict) for row_dict in rows] 64 | jsonl_utils.write(FLAGS.output, examples) 65 | 66 | 67 | if __name__ == "__main__": 68 | app.run(main) 69 | -------------------------------------------------------------------------------- /eval/eval_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Common utilities for evaluation.""" 17 | 18 | import collections 19 | 20 | import numpy as np 21 | 22 | 23 | NEG_INF = -1e9 24 | 25 | 26 | def preprocess_preds(dataset, scores, all_labels): 27 | """Return predictions and ground truth labels in common format.""" 28 | # Map of accession to map of label to score. 29 | accession_to_predictions = collections.defaultdict(dict) 30 | for row in scores: 31 | score = float(row["score"]) 32 | accession = row["inputs"]["accession"] 33 | label = row["inputs"]["label"] 34 | accession_to_predictions[accession][label] = score 35 | 36 | true_labels = [] 37 | pred_scores = [] 38 | for row in dataset: 39 | accession = row["accession"] 40 | predictions = accession_to_predictions[accession] 41 | gold_labels = set(row["labels"]) 42 | true_labels_row = [] 43 | pred_scores_row = [] 44 | for label in all_labels: 45 | true_label = 1 if label in gold_labels else 0 46 | pred_score = predictions.get(label, NEG_INF) 47 | true_labels_row.append(true_label) 48 | pred_scores_row.append(pred_score) 49 | true_labels.append(true_labels_row) 50 | pred_scores.append(pred_scores_row) 51 | 52 | return np.array(true_labels), np.array(pred_scores) 53 | 54 | 55 | def get_all_labels(dataset, predictions): 56 | """Return union of labels in predictions and dataset.""" 57 | all_labels = set() 58 | for row in dataset: 59 | for label in row["labels"]: 60 | all_labels.add(label) 61 | for row in predictions: 62 | 
all_labels.add(row["inputs"]["label"]) 63 | return all_labels 64 | -------------------------------------------------------------------------------- /eval/eval_micro_f1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Computes and prints maximum micro-averaged F1 score. 17 | 18 | This metric was used by "ProteInfer, deep networks for protein functional 19 | inference". 20 | """ 21 | 22 | from absl import app 23 | from absl import flags 24 | import numpy as np 25 | import sklearn.metrics 26 | 27 | from common import jsonl_utils 28 | from eval import eval_utils 29 | 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | 34 | flags.DEFINE_string("dataset", "", "Path to jsonl dataset file.") 35 | 36 | flags.DEFINE_string("predictions", "", "Path to jsonl predictions file.") 37 | 38 | 39 | def get_max_f1(true_labels, pred_scores): 40 | """Return maximum micro-averaged F1 score.""" 41 | true_labels = true_labels.flatten() 42 | pred_scores = pred_scores.flatten() 43 | precisions, recalls, thresholds = sklearn.metrics.precision_recall_curve( 44 | true_labels, pred_scores 45 | ) 46 | # The last values have no associated threshold. 47 | precisions = precisions[:-1] 48 | recalls = recalls[:-1] 49 | 50 | # Make denominator robust to zeros. 
51 | denominator = np.where(precisions + recalls == 0, 1, precisions + recalls) 52 | f1_scores = 2 * precisions * recalls / denominator 53 | max_f1_score_idx = np.argmax(f1_scores) 54 | max_threshold = thresholds[max_f1_score_idx] 55 | max_f1 = f1_scores[max_f1_score_idx] 56 | print(f"max_threshold: {max_threshold}") 57 | return max_f1 58 | 59 | 60 | def main(unused_argv): 61 | predictions = jsonl_utils.read(FLAGS.predictions) 62 | dataset = jsonl_utils.read(FLAGS.dataset) 63 | all_labels = eval_utils.get_all_labels(dataset, predictions) 64 | true_labels, pred_scores = eval_utils.preprocess_preds( 65 | dataset, predictions, all_labels 66 | ) 67 | max_f1 = get_max_f1(true_labels, pred_scores) 68 | print(f"max_f1: {max_f1}") 69 | 70 | 71 | if __name__ == "__main__": 72 | app.run(main) 73 | -------------------------------------------------------------------------------- /data/convert_pfam.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | r"""Converts the PFAM seed clustered split. 
17 | 18 | This data split is from the paper: 19 | "Using deep learning to annotate the protein universe" 20 | https://www.nature.com/articles/s41587-021-01179-w 21 | 22 | The data can be obtained with the following command: 23 | 24 | wget 25 | https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/clustered_split/pfam_clustered_split__train_dev_test.tar.gz 26 | 27 | tar xzf pfam_clustered_split__train_dev_test.tar.gz 28 | 29 | Sample command: 30 | python -m data.convert_pfam \ 31 | --input=/pfam_clustered_split/dev/* \ 32 | --output=/pfam_dev.jsonl 33 | """ 34 | 35 | from absl import app 36 | from absl import flags 37 | import pandas as pd 38 | import tensorflow as tf 39 | 40 | from common import jsonl_utils 41 | 42 | 43 | flags.DEFINE_string("input", "", "Path to PFAM input file pattern.") 44 | 45 | flags.DEFINE_string("output", "", "Output path for json file.") 46 | 47 | FLAGS = flags.FLAGS 48 | 49 | 50 | def read_df(file_pattern): 51 | file_paths = tf.io.gfile.glob(file_pattern) 52 | dfs = [] 53 | for file_path in file_paths: 54 | with tf.io.gfile.GFile(file_path, "r") as f: 55 | df = pd.read_csv(f) 56 | dfs.append(df) 57 | 58 | df = pd.concat(dfs).reset_index(drop=True) 59 | return df 60 | 61 | 62 | def convert_to_example(accession, sequence, label): 63 | return { 64 | "accession": accession, 65 | "sequence": sequence, 66 | "labels": [label], 67 | } 68 | 69 | 70 | def main(unused_argv): 71 | data_df = read_df(FLAGS.input) 72 | print(f"Loaded {data_df.shape[0]} rows.") 73 | 74 | accessions = data_df["sequence_name"].values.tolist() 75 | sequences = data_df["sequence"].values.tolist() 76 | labels = data_df["family_accession"].values.tolist() 77 | 78 | examples = [ 79 | convert_to_example(accession, sequence, label) 80 | for accession, sequence, label in zip(accessions, sequences, labels) 81 | ] 82 | jsonl_utils.write(FLAGS.output, examples) 83 | 84 | 85 | if __name__ == "__main__": 86 | app.run(main) 87 | 
-------------------------------------------------------------------------------- /blast/convert_blast.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | r"""Converts BLAST scores to the same format as T5X inference output. 17 | 18 | Output format should have rows like: 19 | {"inputs": {"accession": "Q0HZU6", "label": "GO:0008150"}, "score": -0.04761} 20 | """ 21 | 22 | import collections 23 | 24 | from absl import app 25 | from absl import flags 26 | import tensorflow as tf 27 | 28 | from common import jsonl_utils 29 | 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | flags.DEFINE_string("input", "", "Path to blast output tsv.") 34 | 35 | flags.DEFINE_string("database_records", "", "Path to database records.") 36 | 37 | flags.DEFINE_string("output", "", "Path to write generated scores.") 38 | 39 | flags.DEFINE_integer("topk", 1, "Limit to this many neighbors per accession.") 40 | 41 | 42 | def _extract_accession(original_string): 43 | return original_string.replace('accession="', "").replace('"', "") 44 | 45 | 46 | def _read_blast_tsv(path): 47 | """Load TSV file generated by BLAST.""" 48 | rows = [] 49 | with tf.io.gfile.GFile(path, "r") as tsv_file: 50 | for line in tsv_file: 51 | line = line.rstrip() 52 | cols = line.split("\t") 
53 | query_accession = _extract_accession(cols[0]) 54 | neighbor_accession = _extract_accession(cols[1]) 55 | # Column 11 is the bit score (i.e. alignment score). 56 | score = float(cols[11]) 57 | rows.append((query_accession, neighbor_accession, score)) 58 | print("Loaded %s rows from %s." % (len(rows), path)) 59 | return rows 60 | 61 | 62 | def _load_accession_to_labels_map(path): 63 | dataset = jsonl_utils.read(path) 64 | accession_to_labels = {} 65 | for record in dataset: 66 | accession_to_labels[record["accession"]] = record["label"] 67 | return accession_to_labels 68 | 69 | 70 | def _load_accession_to_neighbors_dict(path): 71 | accession_to_neighbors = collections.defaultdict(list) 72 | blast_rows = _read_blast_tsv(path) 73 | for query_accession, database_accession, score in blast_rows: 74 | accession_to_neighbors[query_accession].append((database_accession, score)) 75 | return accession_to_neighbors 76 | 77 | 78 | def main(unused_argv): 79 | accession_to_neighbors = _load_accession_to_neighbors_dict(FLAGS.input) 80 | accession_to_labels = _load_accession_to_labels_map(FLAGS.database_records) 81 | 82 | rows = [] 83 | for query_accession, neighbors in accession_to_neighbors.items(): 84 | for database_accession, score in neighbors[: FLAGS.topk]: 85 | labels = accession_to_labels[database_accession] 86 | for label in labels: 87 | json_dict = { 88 | "inputs": { 89 | "accession": query_accession, 90 | "label": label, 91 | }, 92 | "score": score, 93 | } 94 | rows.append(json_dict) 95 | jsonl_utils.write(FLAGS.output, rows) 96 | 97 | 98 | if __name__ == "__main__": 99 | app.run(main) 100 | -------------------------------------------------------------------------------- /data/convert_proteinfer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Converts ProteInfer file to common jsonl format. 17 | 18 | This dataset was proposed in the paper: 19 | "ProteInfer, deep networks for protein functional inference": 20 | https://google-research.github.io/proteinfer/ 21 | 22 | The dataset files are available here: 23 | https://console.cloud.google.com/storage/browser/brain-genomics-public/research/proteins/proteinfer/datasets/. 24 | """ 25 | 26 | import random 27 | 28 | from absl import app 29 | from absl import flags 30 | import tensorflow as tf 31 | 32 | from common import jsonl_utils 33 | 34 | 35 | FLAGS = flags.FLAGS 36 | 37 | 38 | flags.DEFINE_string("input", "", "Path to ProteInfer file.") 39 | 40 | flags.DEFINE_string("output", "", "Output location for jsonl file.") 41 | 42 | flags.DEFINE_enum("labels", "ec", ["ec", "go"], "Which labels to use.") 43 | 44 | flags.DEFINE_float("sample_frac", 1.0, "Sample a fraction of the input.") 45 | 46 | 47 | def load_tf_examples(path): 48 | filepaths = tf.io.gfile.glob(path) 49 | dataset = tf.data.TFRecordDataset(filepaths) 50 | records = [] 51 | for raw_record in dataset: 52 | record = tf.train.Example.FromString(raw_record.numpy()) 53 | records.append(record) 54 | return records 55 | 56 | 57 | def get_bytes_feature(example: tf.train.Example, key: str) -> bytes: 58 | return example.features.feature[key].bytes_list.value[0] 59 | 60 | 61 | def get_text_feature(example: tf.train.Example, key: str) -> str: 62 | return 
get_bytes_feature(example, key).decode("utf-8") 63 | 64 | 65 | def get_repeated_text_feature(example: tf.train.Example, key: str) -> list[str]: 66 | values = [] 67 | for value in example.features.feature[key].bytes_list.value: 68 | values.append(value.decode("utf-8")) 69 | return values 70 | 71 | 72 | def filter_labels(labels): 73 | if FLAGS.labels == "go": 74 | return [label for label in labels if label.startswith("GO:")] 75 | elif FLAGS.labels == "ec": 76 | return [label for label in labels if label.startswith("EC:")] 77 | else: 78 | raise ValueError("Unknown label type: %s" % FLAGS.labels) 79 | 80 | 81 | def load_examples(path): 82 | """Load tfrecord file.""" 83 | examples = [] 84 | tf_examples = load_tf_examples(path) 85 | for example in tf_examples: 86 | sequence = get_text_feature(example, "sequence") 87 | accession = get_text_feature(example, "id") 88 | labels = get_repeated_text_feature(example, "label") 89 | labels = filter_labels(labels) 90 | example = { 91 | "sequence": sequence, 92 | "accession": accession, 93 | "labels": labels, 94 | } 95 | examples.append(example) 96 | return examples 97 | 98 | 99 | def main(unused_argv): 100 | examples = [] 101 | for file_path in tf.io.gfile.glob(FLAGS.input): 102 | examples.extend(load_examples(file_path)) 103 | if FLAGS.sample_frac < 1.0: 104 | cutoff = int(FLAGS.sample_frac * len(examples)) 105 | random.shuffle(examples) 106 | examples = examples[:cutoff] 107 | jsonl_utils.write(FLAGS.output, examples) 108 | 109 | 110 | if __name__ == "__main__": 111 | app.run(main) 112 | -------------------------------------------------------------------------------- /data/convert_deepfri.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Converts PDB-based EC data. 17 | 18 | This data is from the paper: 19 | "Structure-based protein function prediction using graph convolutional networks" 20 | https://www.nature.com/articles/s41467-021-23303-9 21 | 22 | The data is available here: 23 | https://github.com/flatironinstitute/DeepFRI/tree/master/preprocessing/data 24 | """ 25 | 26 | from absl import app 27 | from absl import flags 28 | import tensorflow as tf 29 | 30 | from common import jsonl_utils 31 | 32 | 33 | FLAGS = flags.FLAGS 34 | 35 | 36 | flags.DEFINE_string("split", "", "Path to split file.") 37 | 38 | flags.DEFINE_string("sequences", "", "Path to sequences file.") 39 | 40 | flags.DEFINE_string("annotations", "", "Path to annotations file.") 41 | 42 | flags.DEFINE_string("output", "", "Path to write jsonl.") 43 | 44 | 45 | def read_txt(filepath): 46 | """Read newline separated text file.""" 47 | rows = [] 48 | with tf.io.gfile.GFile(filepath, "r") as tsv_file: 49 | for line in tsv_file: 50 | line = line.rstrip("\n") 51 | rows.append(line) 52 | print("Loaded %s rows from %s." 
% (len(rows), filepath)) 53 | return rows 54 | 55 | 56 | def read_fasta(filepath): 57 | """Parse a FASTA file, yielding (description, sequence) pairs with the leading ">" removed from each description.""" 58 | description = None 59 | sequence = "" 60 | with tf.io.gfile.GFile(filepath, "r") as fp: 61 | for line in fp: 62 | line = line.strip(" \t\n\r") 63 | if line.startswith(">"): 64 | if description is not None: 65 | yield description, sequence 66 | description = line[1:] 67 | sequence = "" 68 | else: 69 | sequence += line 70 | yield description, sequence 71 | 72 | 73 | def write_examples(accession_to_labels, accession_to_sequence): 74 | """Build one example dict per accession read from FLAGS.split and pass them to jsonl_utils.write with FLAGS.output.""" 75 | accessions = read_txt(FLAGS.split) 76 | 77 | examples = [] 78 | for accession in accessions: 79 | labels = accession_to_labels[accession] 80 | sequence = accession_to_sequence[accession] 81 | # Note these are not UniProt accessions, but we will store them as such 82 | # so that the data fields match other datasets. 83 | example = { 84 | "accession": accession, 85 | "sequence": sequence, 86 | "labels": labels, 87 | } 88 | examples.append(example) 89 | 90 | jsonl_utils.write(FLAGS.output, examples) 91 | 92 | 93 | def main(unused_argv): 94 | # Load sequences. 95 | sequences_tuples = read_fasta(FLAGS.sequences) 96 | 97 | # Create id to sequence map. 98 | accession_to_sequence = {} 99 | for header, sequence in sequences_tuples: 100 | accession, meta = header.split(" ") # pytype: disable=attribute-error 101 | if meta != "nrPDB": 102 | raise ValueError(meta) 103 | accession_to_sequence[accession] = sequence 104 | 105 | # Load EC annotations. 106 | annotations_rows = read_txt(FLAGS.annotations) 107 | 108 | accession_to_labels = {} 109 | # Skip 3 header rows. 
110 | for row in annotations_rows[3:]: 111 | accession, ec_list = row.split("\t") 112 | labels = ec_list.split(",") 113 | accession_to_labels[accession] = labels 114 | 115 | write_examples(accession_to_labels, accession_to_sequence) 116 | 117 | 118 | if __name__ == "__main__": 119 | app.run(main) 120 | -------------------------------------------------------------------------------- /eval/eval_protein_f1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Computes and prints maximum protein-centric F1 score. 17 | 18 | This is a commonly used metric for evaluating protein function prediction 19 | methods. Details can be found in "A large-scale evaluation of computational 20 | protein function prediction" (https://www.nature.com/articles/nmeth.2340). 21 | 22 | Note that more efficient implementations exist, such as this one that depends 23 | on PyTorch: 24 | https://github.com/DeepGraphLearning/torchdrug/blob/6066fbd82360abb5f270cba1eca560af01b8cc90/torchdrug/metrics/metric.py#L234 25 | 26 | However, our goal was to implement the metric in a way that is easier to 27 | understand and verify and without additional dependencies. 
28 | """ 29 | 30 | from absl import app 31 | from absl import flags 32 | import numpy as np 33 | 34 | from common import jsonl_utils 35 | from eval import eval_utils 36 | 37 | 38 | FLAGS = flags.FLAGS 39 | 40 | 41 | flags.DEFINE_string("dataset", "", "Path to jsonl dataset file.") 42 | 43 | flags.DEFINE_string("predictions", "", "Path to jsonl predictions file.") 44 | 45 | flags.DEFINE_integer( 46 | "precision", 47 | 2, 48 | "Round scores to this many decimals if >0." 49 | "Helps speed up computation by considering fewer thresholds.", 50 | ) 51 | 52 | 53 | def get_counts(y_true, y_pred): 54 | y_true_and_pred = y_true * y_pred 55 | tp = np.sum(y_true_and_pred, axis=1) 56 | fp = np.sum(y_pred, axis=1) - tp 57 | fn = np.sum(y_true, axis=1) - tp 58 | return tp, fp, fn 59 | 60 | 61 | def get_protein_centric_f1(y_true, y_pred): 62 | """Computes protein-centric F1: harmonic mean of precision averaged over rows with at least one prediction and recall averaged over all rows.""" 63 | tp, fp, fn = get_counts(y_true, y_pred) 64 | # If there are no predictions, then precision is undefined and does not count 65 | # towards the overall average, per the definition of protein-centric F1. 66 | # Set undefined values to 0. NOTE(review): recall below has no such guard, so a row with no true labels divides by zero; confirm inputs exclude such rows. 
67 | precision_num = np.divide( 68 | tp, tp + fp, out=np.zeros_like(tp, dtype=np.float32), where=tp != 0 69 | ) 70 | precision_denom = (tp + fp) > 0 71 | recall = tp / (tp + fn) 72 | precision_avg = np.sum(precision_num) / np.sum(precision_denom) 73 | recall_avg = np.mean(recall) 74 | f1 = 2 * precision_avg * recall_avg / (precision_avg + recall_avg) 75 | return f1 76 | 77 | 78 | def get_thresholds(preds): 79 | return np.sort(np.unique(preds.flatten())) 80 | 81 | 82 | def get_max_protein_centric_f1(pred_scores, target_labels): 83 | pred_scores = np.array(pred_scores) 84 | target_labels = np.array(target_labels) 85 | thresholds = get_thresholds(pred_scores) 86 | print(f"num thresholds: {len(thresholds)}") 87 | f1_scores = [] 88 | for threshold in thresholds: 89 | pred_labels = pred_scores >= threshold  # A label is predicted when its score reaches the threshold. 90 | f1_scores.append(get_protein_centric_f1(target_labels, pred_labels)) 91 | return max(f1_scores)  # NOTE(review): a NaN F1 (zero precision and recall) would make max() order-dependent; confirm this cannot occur. 92 | 93 | 94 | def main(unused_argv): 95 | predictions = jsonl_utils.read(FLAGS.predictions) 96 | dataset = jsonl_utils.read(FLAGS.dataset) 97 | all_labels = eval_utils.get_all_labels(dataset, predictions) 98 | true_labels, pred_scores = eval_utils.preprocess_preds( 99 | dataset, predictions, all_labels 100 | ) 101 | if FLAGS.precision > 0: 102 | pred_scores = np.around(pred_scores, decimals=FLAGS.precision) 103 | 104 | max_f1 = get_max_protein_centric_f1(pred_scores, true_labels) 105 | print(f"max_f1: {max_f1}") 106 | 107 | 108 | if __name__ == "__main__": 109 | app.run(main) 110 | -------------------------------------------------------------------------------- /eval/eval_pfam.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Evaluate PFAM predictions. 17 | 18 | Metrics used are: 19 | - Family accuracy. 20 | - Lifted clan accuracy. 21 | - Average per class family accuracy. 22 | 23 | The dataset should be in the jsonl format indicated by convert_pfam.py where 24 | each line is of the form: 25 | {"accession": "C6WKU9_ACTMD/45-114", 26 | "sequence": "ARNDCEF....", 27 | "labels": ["PF03793.19"]} 28 | 29 | The predictions should be in jsonl format where each line is of the form: 30 | {"accession": "C6WKU9_ACTMD/45-114", "predicted_label": "PF03793.19"} 31 | 32 | The family to clan mapping file can be found here: 33 | 34 | wget 35 | ftp://ftp.ebi.ac.uk/pub/databases/Pfam/releases/Pfam32.0/Pfam-A.clans.tsv.gz . 
36 | gzip -d Pfam-A.clans.tsv.gz 37 | """ 38 | 39 | import collections 40 | 41 | from absl import app 42 | from absl import flags 43 | import numpy as np 44 | import pandas as pd 45 | 46 | from common import jsonl_utils 47 | from eval import eval_pfam_utils 48 | 49 | FLAGS = flags.FLAGS 50 | 51 | 52 | flags.DEFINE_string('dataset', '', 'Path to jsonl dataset file.') 53 | 54 | flags.DEFINE_string('predictions', '', 'Path to jsonl predictions file.') 55 | 56 | flags.DEFINE_string('clan_mapping', None, 'Path to clan mapping file.') 57 | 58 | 59 | def convert_predictions_to_df(jsonl_data): 60 | accessions = [] 61 | predicted_labels = [] 62 | for row in jsonl_data: 63 | accessions.append(row['accession']) 64 | predicted_labels.append(row['predicted_label']) 65 | 66 | return pd.DataFrame( 67 | {'accession': accessions, 'predicted_label': predicted_labels} 68 | ) 69 | 70 | 71 | def convert_dataset_to_df(jsonl_data): 72 | accessions = [] 73 | labels = [] 74 | for row in jsonl_data: 75 | accessions.append(row['accession']) 76 | assert len(row['labels']) == 1  # Each example must carry exactly one label. 77 | labels.append(row['labels'][0]) 78 | 79 | return pd.DataFrame({'accession': accessions, 'true_label': labels}) 80 | 81 | 82 | def mean_per_class_accuracy(predictions_dataframe): 83 | """Compute accuracy of predictions, giving equal weight to all classes. 84 | 85 | Args: 86 | predictions_dataframe: pandas DataFrame with at least 2 columns, 87 | true_label and predicted_label. 88 | 89 | Returns: 90 | float. The average of all class-level accuracies. 
91 | """ 92 | grouped_predictions = collections.defaultdict(list) 93 | for row in predictions_dataframe.itertuples(): 94 | grouped_predictions[row.true_label].append(row.predicted_label) 95 | 96 | accuracy_per_class = { 97 | true_label: np.mean(predicted_label == np.array(true_label)) 98 | for true_label, predicted_label in grouped_predictions.items() 99 | } 100 | 101 | return np.mean(list(accuracy_per_class.values())) 102 | 103 | 104 | def main(unused_argv): 105 | predictions = jsonl_utils.read(FLAGS.predictions) 106 | dataset = jsonl_utils.read(FLAGS.dataset) 107 | 108 | prediction_df = convert_predictions_to_df(predictions) 109 | reference_df = convert_dataset_to_df(dataset) 110 | 111 | merged_df = prediction_df.merge( 112 | reference_df, on='accession', how='inner', validate='one_to_one'  # Inner join: only accessions present in both inputs are scored. 113 | ) 114 | 115 | family_accuracy = eval_pfam_utils.raw_unweighted_accuracy(merged_df) 116 | per_class_family_accuracy = eval_pfam_utils.mean_per_class_accuracy(merged_df) 117 | 118 | print('Family accuracy: %.1f' % (family_accuracy * 100)) 119 | print('Avg. Per-Family accuracy: %.1f' % (per_class_family_accuracy * 100)) 120 | 121 | if FLAGS.clan_mapping: 122 | lifted_clan_accuracy = eval_pfam_utils.get_unweighted_lifted_clan_accuracy( 123 | merged_df, FLAGS.clan_mapping 124 | ) 125 | print('Lifted Clan accuracy: %.1f' % (lifted_clan_accuracy * 100)) 126 | 127 | 128 | if __name__ == '__main__': 129 | app.run(main) 130 | -------------------------------------------------------------------------------- /eval/eval_pfam_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utilities for evaluating pfam predictions.""" 17 | 18 | import collections 19 | 20 | import numpy as np 21 | import pandas as pd 22 | import tensorflow as tf 23 | 24 | 25 | def read_pfam_clan_file(path: str) -> pd.DataFrame: 26 | """Parses pfam clan tsv file. 27 | 28 | Args: 29 | path: Path to tsv clan file. 30 | 31 | Returns: 32 | pd.DataFrame with columns family_accession (str), clan_accession (str), 33 | clan_description (str), family_name (str), family_description (str). 34 | """ 35 | with tf.io.gfile.GFile(path, 'r') as f: 36 | return pd.read_csv( 37 | f, 38 | names=[ 39 | 'family_accession', 40 | 'clan_accession', 41 | 'clan_description', 42 | 'family_name', 43 | 'family_description', 44 | ], 45 | sep='\t', 46 | # Some fields are missing, and we want to keep those 47 | # as empty strings instead of the default behavior, 48 | # which is to convert them to NaNs. 49 | keep_default_na=False, 50 | ) 51 | 52 | 53 | def family_to_clan_mapping(path: str) -> dict[str, str]: 54 | """Parse tsv contents, returning dict from pfam family to clan accession. 55 | 56 | Families without a clan will get their own clan in 57 | the returned dictionary, with clan name equal to the family accession (e.g. PF12345 58 | -> PF12345). 59 | 60 | Args: 61 | path: Path to tsv clan file. 62 | 63 | Returns: 64 | dict from string to string, e.g. {'PF12345': 'CL9999'}. 
65 | """ 66 | dataframe = read_pfam_clan_file(path) 67 | 68 | dataframe['clan_accession'] = dataframe.apply( 69 | axis='columns', 70 | func=lambda row: row.clan_accession # pylint: disable=g-long-lambda 71 | if row.clan_accession 72 | else row.family_accession, 73 | ) 74 | 75 | # Filter family names without clans (they are stored in the csv 76 | # as empty strings). If we're using lifted clan semantics, every family will 77 | # have a clan (see docstring). 78 | return dict( 79 | (family_id, clan_id) # pylint: disable=g-complex-comprehension 80 | for family_id, clan_id in zip( 81 | dataframe['family_accession'].values, 82 | dataframe['clan_accession'].values, 83 | ) 84 | if clan_id 85 | ) 86 | 87 | 88 | def raw_unweighted_accuracy( 89 | predictions_df: pd.DataFrame, 90 | true_label_column: str = 'true_label', 91 | predicted_label_column: str = 'predicted_label', 92 | ) -> float: 93 | """Compute accuracy, regardless of which class each prediction corresponds to. 94 | 95 | Args: 96 | predictions_df: pandas DataFrame with at least 2 columns, true_label and 97 | predicted_label. 98 | true_label_column: Column name of true labels. 99 | predicted_label_column: Column name of predicted labels. 100 | 101 | Returns: 102 | Fraction of rows whose true and predicted labels are equal. 103 | """ 104 | num_correct = ( 105 | predictions_df[true_label_column] 106 | == predictions_df[predicted_label_column] 107 | ).sum() 108 | total = len(predictions_df) 109 | return num_correct / total 110 | 111 | 112 | def mean_per_class_accuracy( 113 | predictions_df: pd.DataFrame, 114 | true_label_column: str = 'true_label', 115 | predicted_label_column: str = 'predicted_label', 116 | ) -> float: 117 | """Compute accuracy of predictions, giving equal weight to all classes. 118 | 119 | Args: 120 | predictions_df: pandas DataFrame with at least 2 columns, true_label and 121 | predicted_label. 122 | true_label_column: Column name of true labels. 123 | predicted_label_column: Column name of predicted labels. 
124 | 125 | Returns: 126 | The average of all class-level accuracies. 127 | """ 128 | grouped_predictions = collections.defaultdict(list) 129 | for _, row in predictions_df.iterrows(): 130 | grouped_predictions[row[true_label_column]].append( 131 | row[predicted_label_column] 132 | ) 133 | 134 | accuracy_per_class = { 135 | true_label: np.mean(predicted_label == np.array(true_label)) 136 | for true_label, predicted_label in grouped_predictions.items() 137 | } 138 | 139 | return np.mean(list(accuracy_per_class.values())) 140 | 141 | 142 | def get_unweighted_lifted_clan_accuracy( 143 | predictions_df: pd.DataFrame, 144 | clan_mapping_path: str, 145 | true_label_column: str = 'true_label', 146 | predicted_label_column: str = 'predicted_label', 147 | ) -> float: 148 | """Compute accuracy, where each label is mapped to its clan. 149 | 150 | Args: 151 | predictions_df: pandas DataFrame with at least 2 columns, true_label and 152 | predicted_label. Columns true_clan_label and predicted_clan_label are added to it in place. 153 | clan_mapping_path: Path to tsv clan file. 154 | true_label_column: Column name of true labels. 155 | predicted_label_column: Column name of predicted labels. 156 | 157 | Returns: 158 | Lifted Clan Accuracy. 
159 | """ 160 | 161 | def pfam_accession_helper(x): 162 | x_split = x.split('.')  # Separate the version suffix, e.g. 'PF03793.19' -> ['PF03793', '19']. 163 | assert len(x_split) == 2 164 | return x_split[0] 165 | 166 | family_to_clan = family_to_clan_mapping(clan_mapping_path) 167 | 168 | predictions_df['true_clan_label'] = predictions_df[true_label_column].apply( 169 | lambda x: family_to_clan.get( 170 | pfam_accession_helper(x), pfam_accession_helper(x) 171 | ) 172 | ) 173 | 174 | predictions_df['predicted_clan_label'] = predictions_df[ 175 | predicted_label_column 176 | ].apply( 177 | lambda x: family_to_clan.get( 178 | pfam_accession_helper(x), pfam_accession_helper(x) 179 | ) 180 | ) 181 | 182 | return raw_unweighted_accuracy( 183 | predictions_df, 184 | true_label_column='true_clan_label', 185 | predicted_label_column='predicted_clan_label', 186 | ) 187 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ProtEx 2 | 3 | This repository contains open source code related to the paper [ProtEx: A Retrieval-Augmented Approach for Protein Function Prediction](https://www.biorxiv.org/content/10.1101/2024.05.30.596539v1). 4 | 5 | ## Installation 6 | 7 | Clone the repository: 8 | 9 | ```shell 10 | git clone https://github.com/google-deepmind/protex.git 11 | ``` 12 | 13 | It is then recommended to setup a virtual environment. We provide an example 14 | using `conda`: 15 | 16 | ```shell 17 | conda create -n protex python=3.10 18 | conda activate protex 19 | ``` 20 | 21 | Then install dependencies specified in `setup.py`: 22 | 23 | ```shell 24 | pip install . 25 | ``` 26 | 27 | ## Overview 28 | 29 | The code, along with the released model predictions, support reproducing the main results from the paper. The code is organized as follows: 30 | 31 | * `blast/` - Contains conversion scripts for reproducing BLAST results. 32 | * `common/` - Some common utility libraries. 
33 | * `data/` - Contains conversion scripts for various datasets to a common format. 34 | * `eval/` - Contains tools for computing various evaluation metrics. 35 | 36 | We convert datasets to a common format consisting of newline-separated JSON files, where each line is a JSON object with the following keys: 37 | 38 | * `sequence` - String of protein sequence. 39 | * `accession` - String for unique identifier, e.g. UniProt accession. 40 | * `labels` - List of strings for labels, e.g. EC numbers. 41 | 42 | ## Usage Examples 43 | 44 | ### ProteInfer 45 | 46 | Here we provide a usage example focused on reproducing the results for the ProteInfer dataset for the clustered EC split. Conversion and evaluation scripts for other datasets 47 | can be found in `data/` and `eval/`, and usages are similar. 48 | 49 | The [original dataset](https://google-research.github.io/proteinfer/) is available on GCP at `gs://brain-genomics-public/research/proteins/proteinfer/datasets/swissprot/`. We can set our input to the path to the EC clustered test split: 50 | 51 | ```shell 52 | CLUSTERED_EC_TEST_TFR="gs://brain-genomics-public/research/proteins/proteinfer/datasets/swissprot/clustered/test.tfrecord" 53 | ``` 54 | 55 | We will assume that the variable `DATA_DIR` is set to a readable and writable 56 | directory, such as `DATA_DIR=/tmp/`. 57 | 58 | We can then run the data conversion script: 59 | 60 | ```shell 61 | CLUSTERED_EC_TEST_JSONL="${DATA_DIR}/proteinfer_clustered_ec_test.jsonl" 62 | python -m data.convert_proteinfer \ 63 | --alsologtostderr \ 64 | --input=${CLUSTERED_EC_TEST_TFR} \ 65 | --output=${CLUSTERED_EC_TEST_JSONL} \ 66 | --labels=ec 67 | ``` 68 | 69 | Model predictions for ProtEx on all test splits are available at `gs://protex/predictions`. 
Specifically, the clustered EC predictions are here: 70 | 71 | ``` 72 | PREDS_PROTEX=gs://protex/predictions/proteinfer-clustered-ec-test-protex.jsonl 73 | ``` 74 | 75 | We can then reproduce the max micro-averaged F1 metrics reported for this split with the following script: 76 | 77 | ```shell 78 | python -m eval.eval_micro_f1 \ 79 | --alsologtostderr \ 80 | --dataset=${CLUSTERED_EC_TEST_JSONL} \ 81 | --predictions=${PREDS_PROTEX} 82 | ``` 83 | 84 | We also released BLAST predictions, so the above script can also be used with the following `--predictions` argument to reproduce the reported BLAST results: 85 | 86 | ``` 87 | PREDS_BLAST=gs://protex/predictions/proteinfer-clustered-ec-test-protex.jsonl 88 | ``` 89 | 90 | #### Reproducing BLAST 91 | 92 | We also released code to reproduce the BLAST predictions. For this we need to also convert the ProteInfer training set: 93 | 94 | ```shell 95 | CLUSTERED_EC_TRAIN_TFR="gs://brain-genomics-public/research/proteins/proteinfer/datasets/swissprot/clustered/train.tfrecord" 96 | CLUSTERED_EC_TRAIN_JSONL="${DATA_DIR}/proteinfer_clustered_ec_train.jsonl" 97 | python -m data.convert_proteinfer \ 98 | --alsologtostderr \ 99 | --input=${CLUSTERED_EC_TRAIN_TFR} \ 100 | --output=${CLUSTERED_EC_TRAIN_JSONL} \ 101 | --labels=ec 102 | ``` 103 | 104 | We then need to convert both train and test splits to `.fasta` format: 105 | 106 | ```shell 107 | CLUSTERED_EC_TRAIN_FASTA="${DATA_DIR}/proteinfer_clustered_ec_train.fasta" 108 | python -m blast.convert_to_fasta \ 109 | --alsologtostderr \ 110 | --input=${CLUSTERED_EC_TRAIN_JSONL} \ 111 | --output=${CLUSTERED_EC_TRAIN_FASTA} 112 | 113 | CLUSTERED_EC_TEST_FASTA="${DATA_DIR}/proteinfer_clustered_ec_test.fasta" 114 | python -m blast.convert_to_fasta \ 115 | --alsologtostderr \ 116 | --input=${CLUSTERED_EC_TEST_JSONL} \ 117 | --output=${CLUSTERED_EC_TEST_FASTA} 118 | ``` 119 | 120 | Note that if `DATA_DIR` refers to a GCP bucket rather than a local directory, the files may need to be copied 
locally so that they can be read by the BLAST command line tool before proceeding to the next step. We will assume `BLAST_DIR` is set to the location of the BLAST binaries, 121 | e.g. `BLAST_DIR=".../ncbi-blast-2.14.1+/bin"`. 122 | 123 | We can then run BLAST. 124 | 125 | ```shell 126 | BLAST_TSV="${DATA_DIR}/blast_proteinfer_clustered_ec_test.tsv" 127 | ${BLAST_DIR}/makeblastdb -in ${CLUSTERED_EC_TRAIN_FASTA} -dbtype prot 128 | ${BLAST_DIR}/blastp -query ${CLUSTERED_EC_TEST_FASTA} -db ${CLUSTERED_EC_TRAIN_FASTA} -outfmt 6 -max_hsps 1 -num_threads 16 -max_target_seqs 1 -out ${BLAST_TSV} 129 | ``` 130 | 131 | Finally, we can convert the tsv file generated by BLAST to the standard predictions format we are using: 132 | 133 | ```shell 134 | BLAST_JSONL=${DATA_DIR}/blast_proteinfer_clustered_ec_test.jsonl 135 | python -m blast.convert_blast \ 136 | --alsologtostderr \ 137 | --input=${BLAST_TSV} \ 138 | --database_records=${CLUSTERED_EC_TRAIN_FASTA} \ 139 | --output=${BLAST_JSONL} 140 | ``` 141 | 142 | ## Citing this work 143 | 144 | You can cite the preprint of our work as follows: 145 | 146 | ```latex 147 | @article{shaw2024protex, 148 | title={ProtEx: A Retrieval-Augmented Approach for Protein Function Prediction}, 149 | author={Shaw, Peter and Gurram, Bhaskar and Belanger, David and Gane, Andreea and Bileschi, Maxwell L and Colwell, Lucy J and Toutanova, Kristina and Parikh, Ankur P}, 150 | journal={bioRxiv}, 151 | URL = {https://www.biorxiv.org/content/early/2024/06/02/2024.05.30.596539}, 152 | year={2024}, 153 | } 154 | ``` 155 | 156 | ## License and disclaimer 157 | 158 | Copyright 2024 DeepMind Technologies Limited 159 | 160 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 161 | you may not use this file except in compliance with the Apache 2.0 license. 
162 | You may obtain a copy of the Apache 2.0 license at: 163 | https://www.apache.org/licenses/LICENSE-2.0 164 | 165 | All other materials are licensed under the Creative Commons Attribution 4.0 166 | International License (CC-BY). You may obtain a copy of the CC-BY license at: 167 | https://creativecommons.org/licenses/by/4.0/legalcode 168 | 169 | Unless required by applicable law or agreed to in writing, all software and 170 | materials distributed here under the Apache 2.0 or CC-BY licenses are 171 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 172 | either express or implied. See the licenses for the specific language governing 173 | permissions and limitations under those licenses. 174 | 175 | This is not an official Google product. 176 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 
23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. 
You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. 
(Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | --------------------------------------------------------------------------------