├── .gitignore
├── CONTRIBUTING.md
├── setup.py
├── blast
│   ├── convert_to_fasta.py
│   └── convert_blast.py
├── common
│   └── jsonl_utils.py
├── eval
│   ├── eval_auc.py
│   ├── eval_utils.py
│   ├── eval_micro_f1.py
│   ├── eval_protein_f1.py
│   ├── eval_pfam.py
│   └── eval_pfam_utils.py
├── data
│   ├── convert_clean.py
│   ├── convert_pfam.py
│   ├── convert_proteinfer.py
│   └── convert_deepfri.py
├── README.md
└── LICENSE
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution,
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Setup for ProtEx."""
17 |
18 | import setuptools
19 |
# Runtime dependencies of the scripts in this package.
REQUIRED_PACKAGES = [
    "absl-py",
    "tensorflow",
    "numpy",
    # Imported by data/convert_pfam.py.
    "pandas",
    "scikit-learn",
]
21 |
22 |
# Package metadata; every package found under the repository root is included.
setuptools.setup(
    name="protex",
    license="Apache 2.0",
    description="Code related to protein function prediction with ProtEx.",
    install_requires=REQUIRED_PACKAGES,
    packages=setuptools.find_packages(),
)
30 |
--------------------------------------------------------------------------------
/blast/convert_to_fasta.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Convert from examples jsonl to fasta format.
17 |
18 | Example of protein in fasta file format:
19 |
20 | >accession="Q0WJ82"
21 | MEKTQSVFIRFIVNGSLVKQILIGLVAGIVLALVST...
22 | """
23 |
24 | from absl import app
25 | from absl import flags
26 | import tensorflow as tf
27 |
28 | from common import jsonl_utils
29 |
30 |
flags.DEFINE_string(name="input", default="", help="Path to examples.")
flags.DEFINE_string(name="output", default="", help="Fasta output file.")

# Parsed flag values; only valid after app.run() has parsed argv.
FLAGS = flags.FLAGS
37 |
38 |
def main(unused_argv):
  """Writes each jsonl example at --input as a two-line FASTA record.

  The header line has the form `>accession="<accession>"` and the second line
  holds the raw sequence.
  """
  records = jsonl_utils.read(FLAGS.input)
  with tf.io.gfile.GFile(FLAGS.output, "w") as fasta_file:
    for record in records:
      header = f'>accession="{record["accession"]}"\n'
      fasta_file.write(header)
      fasta_file.write(f'{record["sequence"]}\n')


if __name__ == "__main__":
  app.run(main)
49 |
--------------------------------------------------------------------------------
/common/jsonl_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utilities for reading and writing jsonl files."""
17 |
18 | import json
19 |
20 | from tensorflow.io import gfile
21 |
22 |
def read(filepath, verbose=True):
  """Reads a jsonl file into a list of decoded JSON values.

  Args:
    filepath: Path readable by `tensorflow.io.gfile`.
    verbose: If True, print progress every 1000 lines and a final summary.

  Returns:
    List with one decoded JSON value per non-blank line, in file order.

  Raises:
    json.JSONDecodeError: If a non-blank line is not valid JSON. The offending
      line is printed before the error is re-raised.
  """
  data = []
  with gfile.GFile(filepath, "r") as jsonl_file:
    for idx, line in enumerate(jsonl_file):
      if verbose and idx % 1000 == 0:
        # Print the index every 1000 lines.
        print("Processing line %s." % idx)
      # Blank lines (e.g. an extra trailing newline) carry no record and would
      # otherwise make json.loads fail.
      if not line.strip():
        continue
      try:
        data.append(json.loads(line))
      except json.JSONDecodeError:
        print("Failed to parse line: `%s`" % line)
        raise
  if verbose:
    print("Loaded %s lines from %s." % (len(data), filepath))
  return data
39 |
40 |
def write(filepath, rows, verbose=True):
  """Writes JSON-serializable rows to a jsonl file, one per line.

  Args:
    filepath: Path writable by `tensorflow.io.gfile`; overwritten if present.
    rows: Iterable of JSON-serializable values. Any iterable (including a
      generator) is accepted, because rows are counted while writing rather
      than with len().
    verbose: If True, print how many lines were written.
  """
  num_rows = 0
  with gfile.GFile(filepath, "w") as jsonl_file:
    for row in rows:
      jsonl_file.write("%s\n" % json.dumps(row))
      num_rows += 1
  if verbose:
    print("Wrote %s lines to %s." % (num_rows, filepath))
48 |
--------------------------------------------------------------------------------
/eval/eval_auc.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Computes and prints weighted AUC metric.
17 |
18 | This metric was used in "Enzyme Function Prediction using Contrastive Learning".
19 | """
20 |
21 | from absl import app
22 | from absl import flags
23 | import sklearn.metrics
24 |
25 | from common import jsonl_utils
26 | from eval import eval_utils
27 |
28 |
flags.DEFINE_string(
    name="dataset", default="", help="Path to jsonl dataset file."
)
flags.DEFINE_string(
    name="predictions", default="", help="Path to jsonl predictions file."
)

# Parsed flag values; only valid after app.run() has parsed argv.
FLAGS = flags.FLAGS
35 |
36 |
def get_test_labels(dataset):
  """Returns the set of every label listed in any row's "labels" field."""
  return {label for row in dataset for label in row["labels"]}
43 |
44 |
def get_auc(true_labels, pred_scores):
  """Returns ROC AUC averaged over label columns, weighted by label support.

  Args:
    true_labels: Binary indicator matrix, one row per example and one column
      per label (as produced by eval_utils.preprocess_preds).
    pred_scores: Score matrix with the same layout as `true_labels`.
  """
  auc = sklearn.metrics.roc_auc_score(
      y_true=true_labels, y_score=pred_scores, average="weighted"
  )
  return auc
50 |
51 |
def main(unused_argv):
  """Prints the weighted AUC of --predictions against --dataset."""
  predictions = jsonl_utils.read(FLAGS.predictions)
  dataset = jsonl_utils.read(FLAGS.dataset)
  # Label columns come from the test set only; labels that appear solely in
  # the predictions are ignored by this metric.
  test_labels = get_test_labels(dataset)
  true_labels, pred_scores = eval_utils.preprocess_preds(
      dataset, predictions, test_labels
  )
  print(f"auc: {get_auc(true_labels, pred_scores)}")


if __name__ == "__main__":
  app.run(main)
66 |
--------------------------------------------------------------------------------
/data/convert_clean.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Converts CLEAN files to common format.
17 |
18 | This dataset was proposed in the paper:
19 | "Enzyme Function Prediction using Contrastive Learning"
20 | https://www.science.org/doi/10.1126/science.adf2465
21 |
22 | The dataset files are available here:
23 | https://github.com/tttianhao/CLEAN/tree/main/app/data
24 | """
25 |
26 | import csv
27 |
28 | from absl import app
29 | from absl import flags
30 | import tensorflow as tf
31 |
32 | from common import jsonl_utils
33 |
34 |
flags.DEFINE_string(name="input", default="", help="Path to input file.")
flags.DEFINE_string(
    name="output", default="", help="Output location for jsonl file."
)

# Parsed flag values; only valid after app.run() has parsed argv.
FLAGS = flags.FLAGS
41 |
42 |
def load_tsv(filepath):
  """Reads a tab-separated file with a header row into a list of dicts.

  Each dict maps the header's column names to that row's string values.
  """
  with tf.io.gfile.GFile(filepath, "r") as tsv_file:
    return list(csv.DictReader(tsv_file, delimiter="\t"))
50 |
51 |
def convert_to_example(row_dict):
  """Converts one CLEAN TSV row into the common example format.

  Args:
    row_dict: Mapping with "Entry", "Sequence" and "EC number" keys, where
      "EC number" holds one or more EC numbers separated by ";".

  Returns:
    Dict with "accession", "sequence" and "labels" keys.
  """
  # Strip whitespace around each EC number and drop empty pieces, so that an
  # empty field or a trailing ";" does not produce a bogus "" label.
  labels = [
      label.strip()
      for label in row_dict["EC number"].split(";")
      if label.strip()
  ]
  return {
      "accession": row_dict["Entry"],
      "sequence": row_dict["Sequence"],
      "labels": labels,
  }
58 |
59 |
def main(unused_argv):
  """Converts the CLEAN TSV at --input into jsonl examples at --output."""
  rows = load_tsv(FLAGS.input)
  print(f"Loaded {len(rows)} rows.")
  examples = list(map(convert_to_example, rows))
  jsonl_utils.write(FLAGS.output, examples)


if __name__ == "__main__":
  app.run(main)
69 |
--------------------------------------------------------------------------------
/eval/eval_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Common utilities for evaluation."""
17 |
18 | import collections
19 |
20 | import numpy as np
21 |
22 |
# Score given to (accession, label) pairs that have no prediction, so they
# rank below any predicted pair scoring above -1e9.
NEG_INF = -1e9


def preprocess_preds(dataset, scores, all_labels):
  """Builds aligned ground-truth and prediction-score matrices.

  Args:
    dataset: Iterable of dicts with "accession" and "labels" keys.
    scores: Iterable of prediction dicts shaped like
      {"inputs": {"accession": ..., "label": ...}, "score": ...}. If an
      (accession, label) pair occurs more than once, the last score wins.
    all_labels: Labels defining the matrix columns, in iteration order.

  Returns:
    Tuple (true_labels, pred_scores) of numpy arrays with one row per dataset
    row and one column per label. true_labels holds 1/0 membership
    indicators; pred_scores holds float(score), or NEG_INF when the pair was
    not predicted.
  """
  # accession -> {label: score}
  score_lookup = collections.defaultdict(dict)
  for prediction in scores:
    value = float(prediction["score"])
    inputs = prediction["inputs"]
    score_lookup[inputs["accession"]][inputs["label"]] = value

  truth_matrix = []
  score_matrix = []
  for example in dataset:
    known_scores = score_lookup.get(example["accession"], {})
    gold = set(example["labels"])
    truth_matrix.append([int(label in gold) for label in all_labels])
    score_matrix.append(
        [known_scores.get(label, NEG_INF) for label in all_labels]
    )
  return np.array(truth_matrix), np.array(score_matrix)
53 |
54 |
def get_all_labels(dataset, predictions):
  """Returns the union of gold labels and predicted labels.

  Args:
    dataset: Iterable of dicts with a "labels" list.
    predictions: Iterable of dicts shaped like {"inputs": {"label": ...}}.
  """
  gold = {label for row in dataset for label in row["labels"]}
  predicted = {prediction["inputs"]["label"] for prediction in predictions}
  return gold | predicted
64 |
--------------------------------------------------------------------------------
/eval/eval_micro_f1.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Computes and prints maximum micro-averaged F1 score.
17 |
18 | This metric was used by "ProteInfer, deep networks for protein functional
19 | inference".
20 | """
21 |
22 | from absl import app
23 | from absl import flags
24 | import numpy as np
25 | import sklearn.metrics
26 |
27 | from common import jsonl_utils
28 | from eval import eval_utils
29 |
30 |
flags.DEFINE_string(
    name="dataset", default="", help="Path to jsonl dataset file."
)
flags.DEFINE_string(
    name="predictions", default="", help="Path to jsonl predictions file."
)

# Parsed flag values; only valid after app.run() has parsed argv.
FLAGS = flags.FLAGS
37 |
38 |
def get_max_f1(true_labels, pred_scores):
  """Returns the best micro-averaged F1 over all score thresholds.

  Every (example, label) cell is pooled into a single binary problem. The
  threshold achieving the best F1 is printed as a side effect.

  Args:
    true_labels: Binary indicator numpy array.
    pred_scores: Score numpy array with the same shape as `true_labels`.
  """
  flat_true = true_labels.ravel()
  flat_scores = pred_scores.ravel()
  precisions, recalls, thresholds = sklearn.metrics.precision_recall_curve(
      flat_true, flat_scores
  )
  # precision_recall_curve appends a final (precision=1, recall=0) point that
  # has no threshold; drop it so the arrays line up with `thresholds`.
  precisions, recalls = precisions[:-1], recalls[:-1]

  total = precisions + recalls
  # Avoid 0/0 where precision and recall are both zero (F1 is 0 there).
  f1_scores = 2 * precisions * recalls / np.where(total == 0, 1, total)
  best = int(np.argmax(f1_scores))
  print(f"max_threshold: {thresholds[best]}")
  return f1_scores[best]
58 |
59 |
def main(unused_argv):
  """Prints the maximum micro-averaged F1 of --predictions on --dataset.

  Label columns are the union of dataset labels and predicted labels.
  """
  predictions = jsonl_utils.read(FLAGS.predictions)
  dataset = jsonl_utils.read(FLAGS.dataset)
  label_set = eval_utils.get_all_labels(dataset, predictions)
  true_labels, pred_scores = eval_utils.preprocess_preds(
      dataset, predictions, label_set
  )
  print(f"max_f1: {get_max_f1(true_labels, pred_scores)}")


if __name__ == "__main__":
  app.run(main)
73 |
--------------------------------------------------------------------------------
/data/convert_pfam.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""Converts the PFAM seed clustered split.
17 |
18 | This data split is from the paper:
19 | "Using deep learning to annotate the protein universe"
20 | https://www.nature.com/articles/s41587-021-01179-w
21 |
22 | The data can be obtained with the following command:
23 |
24 | wget
25 | https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/clustered_split/pfam_clustered_split__train_dev_test.tar.gz
26 |
27 | tar xzf pfam_clustered_split__train_dev_test.tar.gz
28 |
29 | Sample command:
30 | python -m data.convert_pfam \
31 | --input=/pfam_clustered_split/dev/* \
32 | --output=/pfam_dev.jsonl
33 | """
34 |
35 | from absl import app
36 | from absl import flags
37 | import pandas as pd
38 | import tensorflow as tf
39 |
40 | from common import jsonl_utils
41 |
42 |
43 | flags.DEFINE_string("input", "", "Path to PFAM input file pattern.")
44 |
45 | flags.DEFINE_string("output", "", "Output path for json file.")
46 |
47 | FLAGS = flags.FLAGS
48 |
49 |
def read_df(file_pattern):
  """Reads and row-concatenates every CSV file matching `file_pattern`.

  Files are read in sorted path order, so the row order of the result does
  not depend on the order in which the filesystem lists them.

  Args:
    file_pattern: Glob pattern understood by `tf.io.gfile.glob`.

  Returns:
    A DataFrame with a fresh 0..n-1 index.

  Raises:
    ValueError: If no file matches `file_pattern`.
  """
  file_paths = sorted(tf.io.gfile.glob(file_pattern))
  if not file_paths:
    # pd.concat([]) would otherwise fail with an opaque
    # "No objects to concatenate" message.
    raise ValueError(f"No files match pattern: {file_pattern!r}")
  dfs = []
  for file_path in file_paths:
    with tf.io.gfile.GFile(file_path, "r") as f:
      dfs.append(pd.read_csv(f))
  return pd.concat(dfs).reset_index(drop=True)
60 |
61 |
def convert_to_example(accession, sequence, label):
  """Builds a common-format example whose "labels" holds just `label`."""
  return dict(accession=accession, sequence=sequence, labels=[label])
68 |
69 |
def main(unused_argv):
  """Converts the PFAM CSV shards at --input into a jsonl file at --output.

  The "sequence_name", "sequence" and "family_accession" columns become the
  example's accession, sequence and single label respectively.
  """
  data_df = read_df(FLAGS.input)
  print(f"Loaded {data_df.shape[0]} rows.")

  columns = [
      data_df[name].values.tolist()
      for name in ("sequence_name", "sequence", "family_accession")
  ]
  examples = [
      convert_to_example(accession, sequence, label)
      for accession, sequence, label in zip(*columns)
  ]
  jsonl_utils.write(FLAGS.output, examples)


if __name__ == "__main__":
  app.run(main)
87 |
--------------------------------------------------------------------------------
/blast/convert_blast.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | r"""Converts BLAST scores to the same format as T5X inference output.
17 |
18 | Output format should have rows like:
19 | {"inputs": {"accession": "Q0HZU6", "label": "GO:0008150"}, "score": -0.04761}
20 | """
21 |
22 | import collections
23 |
24 | from absl import app
25 | from absl import flags
26 | import tensorflow as tf
27 |
28 | from common import jsonl_utils
29 |
30 |
flags.DEFINE_string(name="input", default="", help="Path to blast output tsv.")
flags.DEFINE_string(
    name="database_records", default="", help="Path to database records."
)
flags.DEFINE_string(
    name="output", default="", help="Path to write generated scores."
)
flags.DEFINE_integer(
    name="topk", default=1, help="Limit to this many neighbors per accession."
)

# Parsed flag values; only valid after app.run() has parsed argv.
FLAGS = flags.FLAGS
40 |
41 |
42 | def _extract_accession(original_string):
43 | return original_string.replace('accession="', "").replace('"', "")
44 |
45 |
def _read_blast_tsv(path):
  """Loads a tabular BLAST report into (query, neighbor, bit score) tuples.

  Assumes the default 12-column tabular layout, where column 11 (0-based) is
  the bit score. Blank lines and lines starting with "#" (the comment lines of
  the commented tabular format) are skipped instead of crashing the parser.

  Args:
    path: Path readable by `tf.io.gfile`.

  Returns:
    List of (query_accession, neighbor_accession, score) tuples in file order.
  """
  rows = []
  with tf.io.gfile.GFile(path, "r") as tsv_file:
    for line in tsv_file:
      line = line.rstrip()
      if not line or line.startswith("#"):
        continue
      cols = line.split("\t")
      query_accession = _extract_accession(cols[0])
      neighbor_accession = _extract_accession(cols[1])
      # Column 11 is the bit score (i.e. alignment score).
      score = float(cols[11])
      rows.append((query_accession, neighbor_accession, score))
  print("Loaded %s rows from %s." % (len(rows), path))
  return rows
60 |
61 |
def _load_accession_to_labels_map(path):
  """Maps each database record's accession to its list of labels.

  The dataset converters in this repository write examples with a "labels"
  list, so that key is read. A record that only has a "label" key is still
  accepted for backward compatibility; a bare string there is wrapped in a
  list so callers can always iterate over the labels.

  Args:
    path: jsonl file of database records.

  Returns:
    Dict mapping accession -> list of labels.

  Raises:
    KeyError: If a record has neither "labels" nor "label".
  """
  accession_to_labels = {}
  for record in jsonl_utils.read(path):
    if "labels" in record:
      labels = record["labels"]
    else:
      labels = record["label"]
      if isinstance(labels, str):
        labels = [labels]
    accession_to_labels[record["accession"]] = labels
  return accession_to_labels
68 |
69 |
def _load_accession_to_neighbors_dict(path):
  """Groups the hits of a BLAST report by query accession.

  Returns:
    defaultdict(list) mapping query accession -> list of
    (database_accession, score) tuples, in report order.
  """
  grouped = collections.defaultdict(list)
  for query, neighbor, bit_score in _read_blast_tsv(path):
    grouped[query].append((neighbor, bit_score))
  return grouped
76 |
77 |
def main(unused_argv):
  """Converts BLAST neighbors into per-(query, label) scores in T5X format.

  For each query:
    1. Repeated hits against the same database sequence (e.g. several HSPs)
       are collapsed to their best bit score.
    2. Neighbors are ranked by bit score and the top --topk are kept.
    3. Each label of a kept neighbor is scored with that neighbor's bit score.
       When several kept neighbors share a label, the label keeps the highest
       score.

  Exactly one row is emitted per (query, label). eval_utils.preprocess_preds
  keeps only the last score seen for a pair, so emitting duplicates would let
  a weaker neighbor overwrite a stronger one.
  """
  accession_to_neighbors = _load_accession_to_neighbors_dict(FLAGS.input)
  accession_to_labels = _load_accession_to_labels_map(FLAGS.database_records)

  rows = []
  for query_accession, neighbors in accession_to_neighbors.items():
    # Best bit score per distinct database accession.
    best_neighbor_score = {}
    for database_accession, score in neighbors:
      if score > best_neighbor_score.get(database_accession, float("-inf")):
        best_neighbor_score[database_accession] = score
    # sorted() is stable, so ties keep their report order.
    ranked = sorted(
        best_neighbor_score.items(), key=lambda item: item[1], reverse=True
    )

    # Best score per label across the kept neighbors.
    label_to_score = {}
    for database_accession, score in ranked[: FLAGS.topk]:
      for label in accession_to_labels[database_accession]:
        if score > label_to_score.get(label, float("-inf")):
          label_to_score[label] = score

    for label, score in label_to_score.items():
      rows.append({
          "inputs": {
              "accession": query_accession,
              "label": label,
          },
          "score": score,
      })
  jsonl_utils.write(FLAGS.output, rows)


if __name__ == "__main__":
  app.run(main)
100 |
--------------------------------------------------------------------------------
/data/convert_proteinfer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Converts ProteInfer file to common jsonl format.
17 |
18 | This dataset was proposed in the paper:
19 | "ProteInfer, deep networks for protein functional inference":
20 | https://google-research.github.io/proteinfer/
21 |
22 | The dataset files are available here:
23 | https://console.cloud.google.com/storage/browser/brain-genomics-public/research/proteins/proteinfer/datasets/.
24 | """
25 |
26 | import random
27 |
28 | from absl import app
29 | from absl import flags
30 | import tensorflow as tf
31 |
32 | from common import jsonl_utils
33 |
34 |
flags.DEFINE_string(name="input", default="", help="Path to ProteInfer file.")
flags.DEFINE_string(
    name="output", default="", help="Output location for jsonl file."
)
flags.DEFINE_enum(
    name="labels",
    default="ec",
    enum_values=["ec", "go"],
    help="Which labels to use.",
)
flags.DEFINE_float(
    name="sample_frac", default=1.0, help="Sample a fraction of the input."
)

# Parsed flag values; only valid after app.run() has parsed argv.
FLAGS = flags.FLAGS
45 |
46 |
def load_tf_examples(path):
  """Parses every tf.train.Example stored in the TFRecord files at `path`.

  Args:
    path: File path or glob pattern understood by `tf.io.gfile.glob`.

  Returns:
    List of tf.train.Example protos in dataset iteration order.
  """
  dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(path))
  return [
      tf.train.Example.FromString(serialized.numpy()) for serialized in dataset
  ]
55 |
56 |
def get_bytes_feature(example: tf.train.Example, key: str) -> bytes:
  """Returns the first bytes value stored under feature `key`."""
  byte_values = example.features.feature[key].bytes_list.value
  return byte_values[0]
59 |
60 |
def get_text_feature(example: tf.train.Example, key: str) -> str:
  """Returns the first bytes value under feature `key`, decoded as UTF-8."""
  raw = get_bytes_feature(example, key)
  return raw.decode("utf-8")
63 |
64 |
def get_repeated_text_feature(example: tf.train.Example, key: str) -> list[str]:
  """Returns every bytes value under feature `key`, each decoded as UTF-8."""
  raw_values = example.features.feature[key].bytes_list.value
  return [raw.decode("utf-8") for raw in raw_values]
70 |
71 |
def filter_labels(labels, label_type=None):
  """Keeps only the labels from the requested vocabulary.

  Args:
    labels: Iterable of label strings such as "GO:0008150" or "EC:1.1.1.1".
    label_type: "go" or "ec". Defaults to the --labels flag, which is read at
      call time so that existing callers keep their behavior.

  Returns:
    Labels starting with "GO:" (for "go") or "EC:" (for "ec"), in input order.

  Raises:
    ValueError: If the label type is neither "go" nor "ec".
  """
  if label_type is None:
    label_type = FLAGS.labels
  prefixes = {"go": "GO:", "ec": "EC:"}
  if label_type not in prefixes:
    raise ValueError("Unknown label type: %s" % label_type)
  prefix = prefixes[label_type]
  return [label for label in labels if label.startswith(prefix)]
79 |
80 |
def load_examples(path):
  """Loads the TFRecord examples at `path` into common-format dicts.

  Each dict has "sequence" (from the "sequence" feature), "accession" (from
  the "id" feature) and "labels" (the "label" feature passed through
  filter_labels()).
  """
  converted = []
  for tf_example in load_tf_examples(path):
    converted.append({
        "sequence": get_text_feature(tf_example, "sequence"),
        "accession": get_text_feature(tf_example, "id"),
        "labels": filter_labels(
            get_repeated_text_feature(tf_example, "label")
        ),
    })
  return converted
97 |
98 |
def main(unused_argv):
  """Converts the ProteInfer TFRecords at --input into jsonl at --output.

  When --sample_frac < 1, a random subset containing that fraction of the
  examples is kept. Input files are processed in sorted order and the shuffle
  uses a fixed seed, so re-running the conversion yields the same sample.
  """
  examples = []
  for file_path in sorted(tf.io.gfile.glob(FLAGS.input)):
    examples.extend(load_examples(file_path))
  if FLAGS.sample_frac < 1.0:
    cutoff = int(FLAGS.sample_frac * len(examples))
    # A dedicated seeded RNG keeps the sample reproducible without touching
    # the global random state.
    random.Random(0).shuffle(examples)
    examples = examples[:cutoff]
  jsonl_utils.write(FLAGS.output, examples)


if __name__ == "__main__":
  app.run(main)
112 |
--------------------------------------------------------------------------------
/data/convert_deepfri.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Converts PDB-based EC data.
17 |
18 | This data is from the paper:
19 | "Structure-based protein function prediction using graph convolutional networks"
20 | https://www.nature.com/articles/s41467-021-23303-9
21 |
22 | The data is available here:
23 | https://github.com/flatironinstitute/DeepFRI/tree/master/preprocessing/data
24 | """
25 |
26 | from absl import app
27 | from absl import flags
28 | import tensorflow as tf
29 |
30 | from common import jsonl_utils
31 |
32 |
flags.DEFINE_string(name="split", default="", help="Path to split file.")
flags.DEFINE_string(name="sequences", default="", help="Path to sequences file.")
flags.DEFINE_string(
    name="annotations", default="", help="Path to annotations file."
)
flags.DEFINE_string(name="output", default="", help="Path to write jsonl.")

# Parsed flag values; only valid after app.run() has parsed argv.
FLAGS = flags.FLAGS
43 |
44 |
def read_txt(filepath):
  """Reads a text file into a list of lines with trailing "\\n" removed."""
  with tf.io.gfile.GFile(filepath, "r") as txt_file:
    rows = [line.rstrip("\n") for line in txt_file]
  print("Loaded %s rows from %s." % (len(rows), filepath))
  return rows
54 |
55 |
def read_fasta(filepath):
  """Yields (description, sequence) pairs from a FASTA file.

  The description is the header line without its leading ">". The sequence
  is the concatenation of the lines that follow, up to the next header, each
  stripped of surrounding spaces, tabs and line endings. Lines before the
  first header are ignored, and a file without any header yields nothing.

  Args:
    filepath: Path readable by `tf.io.gfile`.
  """
  description = None
  sequence_parts = []
  with tf.io.gfile.GFile(filepath, "r") as fp:
    for line in fp:
      line = line.strip(" \t\n\r")
      if line.startswith(">"):
        if description is not None:
          yield description, "".join(sequence_parts)
        description = line[1:]
        sequence_parts = []
      else:
        sequence_parts.append(line)
  # Without this check an empty (or header-less) file would yield a spurious
  # (None, ...) record that breaks callers parsing the description.
  if description is not None:
    yield description, "".join(sequence_parts)
71 |
72 |
def write_examples(accession_to_labels, accession_to_sequence):
  """Writes one jsonl example per accession listed in the --split file.

  Raises:
    KeyError: If a split accession has no labels or no sequence.
  """

  def _to_example(accession):
    labels = accession_to_labels[accession]
    sequence = accession_to_sequence[accession]
    # These ids come from the nrPDB-based files rather than UniProt; they are
    # stored under "accession" so the fields match the other datasets.
    return {"accession": accession, "sequence": sequence, "labels": labels}

  examples = [_to_example(accession) for accession in read_txt(FLAGS.split)]
  jsonl_utils.write(FLAGS.output, examples)
91 |
92 |
def main(unused_argv):
  """Builds a jsonl split from DeepFRI sequence, annotation and split files.

  Raises:
    ValueError: If a FASTA header is not exactly "<accession> nrPDB", or a
      non-blank annotation row does not have exactly two tab-separated fields.
  """
  # Create id to sequence map.
  accession_to_sequence = {}
  for header, sequence in read_fasta(FLAGS.sequences):
    accession, meta = header.split(" ")  # pytype: disable=attribute-error
    if meta != "nrPDB":
      raise ValueError(meta)
    accession_to_sequence[accession] = sequence

  # Load EC annotations.
  annotations_rows = read_txt(FLAGS.annotations)

  accession_to_labels = {}
  # Skip 3 header rows.
  for row in annotations_rows[3:]:
    # Ignore blank rows (e.g. a trailing empty line) instead of failing to
    # unpack them.
    if not row.strip():
      continue
    accession, ec_list = row.split("\t")
    # Drop empty entries so an empty annotation field yields no "" label.
    labels = [label for label in ec_list.split(",") if label]
    accession_to_labels[accession] = labels

  write_examples(accession_to_labels, accession_to_sequence)


if __name__ == "__main__":
  app.run(main)
120 |
--------------------------------------------------------------------------------
/eval/eval_protein_f1.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Computes and prints maximum protein-centric F1 score.
17 |
18 | This is a commonly used metric for evaluating protein function prediction
19 | methods. Details can be found in "A large-scale evaluation of computational
20 | protein function prediction" (https://www.nature.com/articles/nmeth.2340).
21 |
22 | Note that more efficient implementations exist, such as this one that depends
23 | on PyTorch:
24 | https://github.com/DeepGraphLearning/torchdrug/blob/6066fbd82360abb5f270cba1eca560af01b8cc90/torchdrug/metrics/metric.py#L234
25 |
26 | However, our goal was to implement the metric in a way that is easier to
27 | understand and verify and without additional dependencies.
28 | """
29 |
30 | from absl import app
31 | from absl import flags
32 | import numpy as np
33 |
34 | from common import jsonl_utils
35 | from eval import eval_utils
36 |
37 |
FLAGS = flags.FLAGS


flags.DEFINE_string("dataset", "", "Path to jsonl dataset file.")

flags.DEFINE_string("predictions", "", "Path to jsonl predictions file.")

# The two help fragments are concatenated implicitly, so the first must end
# with a space.
flags.DEFINE_integer(
    "precision",
    2,
    "Round scores to this many decimals if >0. "
    "Helps speed up computation by considering fewer thresholds.",
)
51 |
52 |
def get_counts(y_true, y_pred):
  """Returns per-row true positive, false positive and false negative counts.

  Args:
    y_true: 2D binary array of true labels; rows are proteins, columns labels.
    y_pred: 2D binary (or bool) array of predicted labels, same shape.

  Returns:
    Tuple (tp, fp, fn) of 1D arrays with one entry per row.
  """
  y_true_and_pred = y_true * y_pred
  tp = np.sum(y_true_and_pred, axis=1)
  fp = np.sum(y_pred, axis=1) - tp
  fn = np.sum(y_true, axis=1) - tp
  return tp, fp, fn


def get_protein_centric_f1(y_true, y_pred):
  """Computes protein-centric F1 score.

  Precision is averaged only over proteins with at least one predicted label;
  recall is averaged over all proteins. F1 is the harmonic mean of these two
  averages.

  Args:
    y_true: 2D binary array of true labels (proteins x labels).
    y_pred: 2D binary (or bool) array of predicted labels, same shape.

  Returns:
    The F1 score. Returns 0.0 (instead of NaN) when no protein has a
    prediction or when both average precision and average recall are 0.
  """
  tp, fp, fn = get_counts(y_true, y_pred)
  # If there are no predictions for a protein, its precision is undefined and
  # does not count towards the overall average, per the definition of
  # protein-centric F1. Undefined entries are left at 0 and excluded below.
  has_predictions = (tp + fp) > 0
  precision = np.divide(
      tp,
      tp + fp,
      out=np.zeros_like(tp, dtype=np.float64),
      where=has_predictions,
  )
  num_with_predictions = np.sum(has_predictions)
  if num_with_predictions:
    precision_avg = np.sum(precision) / num_with_predictions
  else:
    precision_avg = 0.0
  # NOTE: assumes every protein has at least one true label; otherwise
  # tp + fn == 0 for that protein and recall (hence F1) becomes NaN.
  recall_avg = np.mean(tp / (tp + fn))
  if precision_avg + recall_avg == 0:
    return 0.0
  return 2 * precision_avg * recall_avg / (precision_avg + recall_avg)


def get_thresholds(preds):
  """Returns the distinct values of `preds`, in ascending order."""
  # np.unique flattens its input and already returns sorted values.
  return np.unique(preds)


def get_max_protein_centric_f1(pred_scores, target_labels):
  """Returns the maximum protein-centric F1 over all distinct score thresholds.

  At each threshold, a label is predicted when its score is >= the threshold.

  Args:
    pred_scores: 2D array-like of prediction scores (proteins x labels).
    target_labels: 2D array-like of binary true labels, same shape.

  Returns:
    The maximum F1 across all thresholds.

  Raises:
    ValueError: If `pred_scores` is empty.
  """
  pred_scores = np.array(pred_scores)
  target_labels = np.array(target_labels)
  thresholds = get_thresholds(pred_scores)
  print(f"num thresholds: {len(thresholds)}")
  if not thresholds.size:
    raise ValueError("pred_scores is empty; cannot compute max F1.")
  f1_scores = [
      get_protein_centric_f1(target_labels, pred_scores >= threshold)
      for threshold in thresholds
  ]
  return max(f1_scores)
92 |
93 |
def main(unused_argv):
  """Prints the maximum protein-centric F1 of the predictions on the dataset."""
  predictions = jsonl_utils.read(FLAGS.predictions)
  dataset = jsonl_utils.read(FLAGS.dataset)
  label_vocab = eval_utils.get_all_labels(dataset, predictions)
  true_labels, pred_scores = eval_utils.preprocess_preds(
      dataset, predictions, label_vocab
  )

  # Rounding merges nearby scores, so fewer distinct thresholds are swept.
  decimals = FLAGS.precision
  if decimals > 0:
    pred_scores = np.around(pred_scores, decimals=decimals)

  print(f"max_f1: {get_max_protein_centric_f1(pred_scores, true_labels)}")


if __name__ == "__main__":
  app.run(main)
110 |
--------------------------------------------------------------------------------
/eval/eval_pfam.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Evaluate PFAM predictions.
17 |
18 | Metrics used are:
19 | - Family accuracy.
20 | - Lifted clan accuracy.
21 | - Average per-class family accuracy.
22 |
23 | The dataset should be in the jsonl format indicated by convert_pfam.py where
24 | each line is of the form:
25 | {"accession": "C6WKU9_ACTMD/45-114",
26 | "sequence": "ARNDCEF....",
27 | "labels": ["PF03793.19"]}
28 |
29 | The predictions should be in jsonl format where each line is of the form:
30 | {"accession": "C6WKU9_ACTMD/45-114", "predicted_label": "PF03793.19"}
31 |
32 | The family to clan mapping file can be found here:
33 |
34 | wget
35 | ftp://ftp.ebi.ac.uk/pub/databases/Pfam/releases/Pfam32.0/Pfam-A.clans.tsv.gz
36 | gzip -d Pfam-A.clans.tsv.gz
37 | """
38 |
39 | import collections
40 |
41 | from absl import app
42 | from absl import flags
43 | import numpy as np
44 | import pandas as pd
45 |
46 | from common import jsonl_utils
47 | from eval import eval_pfam_utils
48 |
FLAGS = flags.FLAGS

# Required inputs: both are jsonl files.
flags.DEFINE_string(
    name='dataset', default='', help='Path to jsonl dataset file.'
)
flags.DEFINE_string(
    name='predictions', default='', help='Path to jsonl predictions file.'
)

# Optional: when unset, lifted clan accuracy is not reported.
flags.DEFINE_string(
    name='clan_mapping', default=None, help='Path to clan mapping file.'
)
57 |
58 |
def convert_predictions_to_df(jsonl_data):
  """Builds a DataFrame of predictions from parsed jsonl rows.

  Args:
    jsonl_data: Iterable of dicts with 'accession' and 'predicted_label' keys.

  Returns:
    pd.DataFrame with columns 'accession' and 'predicted_label', in input order.
  """
  columns = {'accession': [], 'predicted_label': []}
  for row in jsonl_data:
    for key, values in columns.items():
      values.append(row[key])
  return pd.DataFrame(columns)
69 |
70 |
def convert_dataset_to_df(jsonl_data):
  """Builds a DataFrame of reference labels from parsed jsonl dataset rows.

  Args:
    jsonl_data: Iterable of dicts with 'accession' and 'labels' keys, where
      'labels' must contain exactly one Pfam family label.

  Returns:
    pd.DataFrame with columns 'accession' and 'true_label', in input order.

  Raises:
    ValueError: If a row does not have exactly one label.
  """
  accessions = []
  labels = []
  for row in jsonl_data:
    accession = row['accession']
    row_labels = row['labels']
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(row_labels) != 1:
      raise ValueError(
          f'Expected exactly one label for {accession!r}, got {row_labels!r}.'
      )
    accessions.append(accession)
    labels.append(row_labels[0])

  return pd.DataFrame({'accession': accessions, 'true_label': labels})
80 |
81 |
def mean_per_class_accuracy(predictions_dataframe):
  """Computes accuracy of predictions, giving equal weight to all classes.

  Kept for backward compatibility. It delegates to
  eval_pfam_utils.mean_per_class_accuracy, the implementation main() uses, so
  the two cannot drift apart.

  Args:
    predictions_dataframe: pandas DataFrame with at least the columns
      'true_label' and 'predicted_label'.

  Returns:
    float. The average of all class-level accuracies.
  """
  return eval_pfam_utils.mean_per_class_accuracy(predictions_dataframe)
102 |
103 |
def main(unused_argv):
  """Evaluates Pfam predictions against a dataset and prints accuracies.

  Prints family accuracy and average per-family accuracy, plus lifted clan
  accuracy when --clan_mapping is set. Only accessions present in both the
  dataset and the predictions are scored. A warning is printed when any
  accessions are left unmatched.
  """
  predictions = jsonl_utils.read(FLAGS.predictions)
  dataset = jsonl_utils.read(FLAGS.dataset)

  prediction_df = convert_predictions_to_df(predictions)
  reference_df = convert_dataset_to_df(dataset)

  merged_df = prediction_df.merge(
      reference_df, on='accession', how='inner', validate='one_to_one'
  )

  # The inner join silently drops unmatched accessions, so the metrics below
  # would be computed over fewer examples than the dataset contains.
  num_unpredicted = len(reference_df) - len(merged_df)
  if num_unpredicted:
    print(
        'Warning: %d of %d dataset examples have no prediction and are not'
        ' scored.' % (num_unpredicted, len(reference_df))
    )
  num_unmatched = len(prediction_df) - len(merged_df)
  if num_unmatched:
    print(
        'Warning: %d predictions match no dataset example and are ignored.'
        % num_unmatched
    )

  family_accuracy = eval_pfam_utils.raw_unweighted_accuracy(merged_df)
  per_class_family_accuracy = eval_pfam_utils.mean_per_class_accuracy(merged_df)

  print('Family accuracy: %.1f' % (family_accuracy * 100))
  print('Avg. Per-Family accuracy: %.1f' % (per_class_family_accuracy * 100))

  if FLAGS.clan_mapping:
    lifted_clan_accuracy = eval_pfam_utils.get_unweighted_lifted_clan_accuracy(
        merged_df, FLAGS.clan_mapping
    )
    print('Lifted Clan accuracy: %.1f' % (lifted_clan_accuracy * 100))


if __name__ == '__main__':
  app.run(main)
130 |
--------------------------------------------------------------------------------
/eval/eval_pfam_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utilities for evaluating pfam predictions."""
17 |
18 | import collections
19 |
20 | import numpy as np
21 | import pandas as pd
22 | import tensorflow as tf
23 |
24 |
def read_pfam_clan_file(path: str) -> pd.DataFrame:
  """Loads a Pfam clan tsv file into a DataFrame.

  Args:
    path: Path to tsv clan file.

  Returns:
    pd.DataFrame with columns family_accession, clan_accession,
    clan_description, family_name, family_description. Missing fields are
    kept as empty strings rather than NaN.
  """
  column_names = [
      'family_accession',
      'clan_accession',
      'clan_description',
      'family_name',
      'family_description',
  ]
  with tf.io.gfile.GFile(path, 'r') as tsv_file:
    # keep_default_na=False: some fields are missing in the file, and they
    # should stay '' instead of becoming NaN.
    dataframe = pd.read_csv(
        tsv_file,
        sep='\t',
        names=column_names,
        keep_default_na=False,
    )
  return dataframe
51 |
52 |
def family_to_clan_mapping(path: str) -> dict[str, str]:
  """Parses a Pfam clan tsv file into a dict from family to clan accession.

  Uses "lifted" clan semantics: a family without a clan is mapped to its own
  accession (e.g. PF12345 -> PF12345).

  Args:
    path: Path to tsv clan file.

  Returns:
    dict from string to string, e.g. {'PF12345': 'CL9999'}.
  """
  dataframe = read_pfam_clan_file(path)
  families = dataframe['family_accession']
  clans = dataframe['clan_accession']

  # Families without a clan have '' in the clan column (read_pfam_clan_file
  # uses keep_default_na=False). Lift them to their own family accession.
  # This is vectorized, and unlike a row-wise apply it is safe on empty frames.
  lifted_clans = clans.where(clans != '', families)

  # After lifting, a clan id can only be empty if the family accession itself
  # is empty; such rows are skipped.
  return {
      family_id: clan_id
      for family_id, clan_id in zip(families.values, lifted_clans.values)
      if clan_id
  }
86 |
87 |
def raw_unweighted_accuracy(
    predictions_df: pd.DataFrame,
    true_label_column: str = 'true_label',
    predicted_label_column: str = 'predicted_label',
) -> float:
  """Returns the fraction of rows whose predicted label equals the true label.

  Every row counts equally, regardless of its class.

  Args:
    predictions_df: pandas DataFrame containing the two label columns.
    true_label_column: Column name of true labels.
    predicted_label_column: Column name of predicted labels.

  Returns:
    Accuracy.
  """
  true_labels = predictions_df[true_label_column]
  predicted_labels = predictions_df[predicted_label_column]
  is_correct = true_labels == predicted_labels
  return is_correct.sum() / len(predictions_df)
110 |
111 |
def mean_per_class_accuracy(
    predictions_df: pd.DataFrame,
    true_label_column: str = 'true_label',
    predicted_label_column: str = 'predicted_label',
) -> float:
  """Returns macro-averaged accuracy, where every true class weighs equally.

  Args:
    predictions_df: pandas DataFrame containing the two label columns.
    true_label_column: Column name of true labels.
    predicted_label_column: Column name of predicted labels.

  Returns:
    Mean over true classes of the fraction of that class's rows whose
    prediction equals the true label.
  """
  predictions_by_class = collections.defaultdict(list)
  label_pairs = zip(
      predictions_df[true_label_column], predictions_df[predicted_label_column]
  )
  for true_label, predicted_label in label_pairs:
    predictions_by_class[true_label].append(predicted_label)

  class_accuracies = [
      np.mean(predicted == np.array(label))
      for label, predicted in predictions_by_class.items()
  ]
  return np.mean(class_accuracies)
140 |
141 |
def get_unweighted_lifted_clan_accuracy(
    predictions_df: pd.DataFrame,
    clan_mapping_path: str,
    true_label_column: str = 'true_label',
    predicted_label_column: str = 'predicted_label',
) -> float:
  """Computes accuracy, where each label is mapped to its (lifted) clan.

  Labels are versioned family accessions such as 'PF03793.19'. The version
  suffix is stripped before lookup, and families absent from the clan file
  map to their own unversioned accession. `predictions_df` is not modified.

  Args:
    predictions_df: pandas DataFrame with at least 2 columns, true_label and
      predicted_label.
    clan_mapping_path: Path to tsv clan file.
    true_label_column: Column name of true labels.
    predicted_label_column: Column name of predicted labels.

  Returns:
    Lifted Clan Accuracy.

  Raises:
    ValueError: If a label is not of the form '<accession>.<version>'.
  """
  family_to_clan = family_to_clan_mapping(clan_mapping_path)

  def to_clan(label):
    parts = label.split('.')
    if len(parts) != 2:
      raise ValueError(
          f'Expected a versioned accession like "PF00001.1", got {label!r}.'
      )
    family = parts[0]
    return family_to_clan.get(family, family)

  # Build a separate frame rather than adding columns to the caller's frame.
  clan_df = pd.DataFrame({
      'true_clan_label': predictions_df[true_label_column].apply(to_clan),
      'predicted_clan_label': predictions_df[predicted_label_column].apply(
          to_clan
      ),
  })

  return raw_unweighted_accuracy(
      clan_df,
      true_label_column='true_clan_label',
      predicted_label_column='predicted_clan_label',
  )
187 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ProtEx
2 |
3 | This repository contains open source code related to the paper [ProtEx: A Retrieval-Augmented Approach for Protein Function Prediction](https://www.biorxiv.org/content/10.1101/2024.05.30.596539v1).
4 |
5 | ## Installation
6 |
7 | Clone the repository:
8 |
9 | ```shell
10 | git clone https://github.com/google-deepmind/protex.git
11 | ```
12 |
13 | It is then recommended to set up a virtual environment. We provide an example
14 | using `conda`:
15 |
16 | ```shell
17 | conda create -n protex python=3.10
18 | conda activate protex
19 | ```
20 |
21 | Then install dependencies specified in `setup.py`:
22 |
23 | ```shell
24 | pip install .
25 | ```
26 |
27 | ## Overview
28 |
29 | The code, along with the released model predictions, supports reproducing the main results from the paper. The code is organized as follows:
30 |
31 | * `blast/` - Contains conversion scripts for reproducing BLAST results.
32 | * `common/` - Some common utility libraries.
33 | * `data/` - Contains conversion scripts for various datasets to a common format.
34 | * `eval/` - Contains tools for computing various evaluation metrics.
35 |
36 | We convert datasets to a common format: newline-delimited JSON (JSONL) files, where each line is a JSON object with the following keys:
37 |
38 | * `sequence` - String of protein sequence.
39 | * `accession` - String for unique identifier, e.g. UniProt accession.
40 | * `labels` - List of strings for labels, e.g. EC numbers.
41 |
42 | ## Usage Examples
43 |
44 | ### ProteInfer
45 |
46 | Here we provide a usage example focused on reproducing the results for the ProteInfer dataset for the clustered EC split. Conversion and evaluation scripts for other datasets
47 | can be found in `data/` and `eval/`, and usage is similar.
48 |
49 | The [original dataset](https://google-research.github.io/proteinfer/) is available on GCP at `gs://brain-genomics-public/research/proteins/proteinfer/datasets/swissprot/`. We can set our input to the path to the EC clustered test split:
50 |
51 | ```shell
52 | CLUSTERED_EC_TEST_TFR="gs://brain-genomics-public/research/proteins/proteinfer/datasets/swissprot/clustered/test.tfrecord"
53 | ```
54 |
55 | We will assume that the variable `DATA_DIR` is set to a readable and writable
56 | directory, such as `DATA_DIR=/tmp/`.
57 |
58 | We can then run the data conversion script:
59 |
60 | ```shell
61 | CLUSTERED_EC_TEST_JSONL="${DATA_DIR}/proteinfer_clustered_ec_test.jsonl"
62 | python -m data.convert_proteinfer \
63 | --alsologtostderr \
64 | --input=${CLUSTERED_EC_TEST_TFR} \
65 | --output=${CLUSTERED_EC_TEST_JSONL} \
66 | --labels=ec
67 | ```
68 |
69 | Model predictions for ProtEx on all test splits are available at `gs://protex/predictions`. Specifically, the clustered EC predictions are here:
70 |
71 | ```
72 | PREDS_PROTEX=gs://protex/predictions/proteinfer-clustered-ec-test-protex.jsonl
73 | ```
74 |
75 | We can then reproduce the max micro-averaged F1 metrics reported for this split with the following script:
76 |
77 | ```shell
78 | python -m eval.eval_micro_f1 \
79 | --alsologtostderr \
80 | --dataset=${CLUSTERED_EC_TEST_JSONL} \
81 | --predictions=${PREDS_PROTEX}
82 | ```
83 |
84 | We also released BLAST predictions, so the above script can also be used with the following `--predictions` argument to reproduce the reported BLAST results:
85 |
86 | ```
87 | PREDS_BLAST=gs://protex/predictions/proteinfer-clustered-ec-test-blast.jsonl
88 | ```
89 |
90 | #### Reproducing BLAST
91 |
92 | We also released code to reproduce the BLAST predictions. For this we need to also convert the ProteInfer training set:
93 |
94 | ```shell
95 | CLUSTERED_EC_TRAIN_TFR="gs://brain-genomics-public/research/proteins/proteinfer/datasets/swissprot/clustered/train.tfrecord"
96 | CLUSTERED_EC_TRAIN_JSONL="${DATA_DIR}/proteinfer_clustered_ec_train.jsonl"
97 | python -m data.convert_proteinfer \
98 | --alsologtostderr \
99 | --input=${CLUSTERED_EC_TRAIN_TFR} \
100 | --output=${CLUSTERED_EC_TRAIN_JSONL} \
101 | --labels=ec
102 | ```
103 |
104 | We then need to convert both train and test splits to `.fasta` format:
105 |
106 | ```shell
107 | CLUSTERED_EC_TRAIN_FASTA="${DATA_DIR}/proteinfer_clustered_ec_train.fasta"
108 | python -m blast.convert_to_fasta \
109 | --alsologtostderr \
110 | --input=${CLUSTERED_EC_TRAIN_JSONL} \
111 | --output=${CLUSTERED_EC_TRAIN_FASTA}
112 |
113 | CLUSTERED_EC_TEST_FASTA="${DATA_DIR}/proteinfer_clustered_ec_test.fasta"
114 | python -m blast.convert_to_fasta \
115 | --alsologtostderr \
116 | --input=${CLUSTERED_EC_TEST_JSONL} \
117 | --output=${CLUSTERED_EC_TEST_FASTA}
118 | ```
119 |
120 | Note that if `DATA_DIR` refers to a GCP bucket rather than a local directory, the files may need to be copied locally so that they can be read by the BLAST command line tool before proceeding to the next step. We will assume `BLAST_DIR` is set to the location of the BLAST binaries,
121 | e.g. `BLAST_DIR=".../ncbi-blast-2.14.1+/bin"`.
122 |
123 | We can then run BLAST.
124 |
125 | ```shell
126 | BLAST_TSV="${DATA_DIR}/blast_proteinfer_clustered_ec_test.tsv"
127 | ${BLAST_DIR}/makeblastdb -in ${CLUSTERED_EC_TRAIN_FASTA} -dbtype prot
128 | ${BLAST_DIR}/blastp -query ${CLUSTERED_EC_TEST_FASTA} -db ${CLUSTERED_EC_TRAIN_FASTA} -outfmt 6 -max_hsps 1 -num_threads 16 -max_target_seqs 1 -out ${BLAST_TSV}
129 | ```
130 |
131 | Finally, we can convert the tsv file generated by BLAST to the standard predictions format we are using:
132 |
133 | ```shell
134 | BLAST_JSONL=${DATA_DIR}/blast_proteinfer_clustered_ec_test.jsonl
135 | python -m blast.convert_blast \
136 | --alsologtostderr \
137 | --input=${BLAST_TSV} \
138 | --database_records=${CLUSTERED_EC_TRAIN_FASTA} \
139 | --output=${BLAST_JSONL}
140 | ```
141 |
142 | ## Citing this work
143 |
144 | You can cite the preprint of our work as follows:
145 |
146 | ```latex
147 | @article{shaw2024protex,
148 | title={ProtEx: A Retrieval-Augmented Approach for Protein Function Prediction},
149 | author={Shaw, Peter and Gurram, Bhaskar and Belanger, David and Gane, Andreea and Bileschi, Maxwell L and Colwell, Lucy J and Toutanova, Kristina and Parikh, Ankur P},
150 | journal={bioRxiv},
151 | URL = {https://www.biorxiv.org/content/early/2024/06/02/2024.05.30.596539},
152 | year={2024},
153 | }
154 | ```
155 |
156 | ## License and disclaimer
157 |
158 | Copyright 2024 DeepMind Technologies Limited
159 |
160 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
161 | you may not use this file except in compliance with the Apache 2.0 license.
162 | You may obtain a copy of the Apache 2.0 license at:
163 | https://www.apache.org/licenses/LICENSE-2.0
164 |
165 | All other materials are licensed under the Creative Commons Attribution 4.0
166 | International License (CC-BY). You may obtain a copy of the CC-BY license at:
167 | https://creativecommons.org/licenses/by/4.0/legalcode
168 |
169 | Unless required by applicable law or agreed to in writing, all software and
170 | materials distributed here under the Apache 2.0 or CC-BY licenses are
171 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
172 | either express or implied. See the licenses for the specific language governing
173 | permissions and limitations under those licenses.
174 |
175 | This is not an official Google product.
176 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------