├── mammal
│   ├── __init__.py
│   ├── examples
│   │   ├── molnet
│   │   │   ├── __init__.py
│   │   │   └── molnet_infer.py
│   │   ├── carcinogenicity
│   │   │   ├── __init__.py
│   │   │   ├── config.yaml
│   │   │   ├── main_infer.py
│   │   │   ├── pl_data_module.py
│   │   │   └── task.py
│   │   ├── scrna_cell_type
│   │   │   ├── __init__.py
│   │   │   ├── cell_type_mapping.csv
│   │   │   ├── data
│   │   │   │   ├── process_h5ad_data.py
│   │   │   │   └── Zheng68k_to_anndata.py
│   │   │   ├── anndata_op.py
│   │   │   ├── config.yaml
│   │   │   ├── scRNA_infer.py
│   │   │   ├── pl_data_module.py
│   │   │   └── task.py
│   │   ├── dti_bindingdb_kd
│   │   │   ├── __init__.py
│   │   │   ├── config.yaml
│   │   │   ├── main_infer.py
│   │   │   ├── pl_data_module.py
│   │   │   └── task.py
│   │   ├── protein_solubility
│   │   │   ├── __init__.py
│   │   │   ├── config.yaml
│   │   │   ├── main_infer.py
│   │   │   ├── pl_data_module.py
│   │   │   └── task.py
│   │   ├── tests
│   │   │   ├── test_molnet.py
│   │   │   ├── test_tcr_epitope_binding_inference.py
│   │   │   ├── test_simple_inference.py
│   │   │   ├── test_protein_solubility_prediction.py
│   │   │   ├── test_drug_carcinogenicity_classification.py
│   │   │   ├── test_dti_bindingdb_kd.py
│   │   │   └── test_main_finetune.py
│   │   └── tcr_epitope_binding
│   │       └── main_infer.py
│   ├── lora.py
│   ├── lr_schedulers.py
│   ├── keys.py
│   ├── main_finetune.py
│   ├── losses.py
│   ├── task.py
│   └── metrics.py
├── mammal.png
├── mammal_preview_2.pdf
├── mammal_mcp
│   ├── .env.example
│   ├── pyproject.toml
│   ├── util.py
│   ├── .gitignore
│   ├── dependencies.py
│   └── README.md
├── CLONE.md
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── ---question.md
│   │   ├── ---bug-report.md
│   │   └── ---feature-request.md
│   └── workflows
│       ├── github-actions.yml
│       ├── python-publish.yml
│       └── clone.yml
├── CONTRIBUTING.md
├── .pre-commit-config.yaml
├── .gitignore
├── pyproject.toml
├── tutorials
│   └── begginer_inference.ipynb
└── README.md
/mammal/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mammal/examples/molnet/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mammal/examples/carcinogenicity/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mammal/examples/scrna_cell_type/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mammal/examples/dti_bindingdb_kd/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mammal/examples/protein_solubility/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mammal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BiomedSciAI/biomed-multi-alignment/HEAD/mammal.png
--------------------------------------------------------------------------------
/mammal_preview_2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BiomedSciAI/biomed-multi-alignment/HEAD/mammal_preview_2.pdf
--------------------------------------------------------------------------------
/mammal_mcp/.env.example:
--------------------------------------------------------------------------------
1 | PROTEIN_PROTEIN_INTERACTION=false
2 | PROTEIN_SOLUBILITY=false
3 | TCR_EPITOPE_BINDING=false
4 | DRUG_TARGET_BINDING=false
5 | DRUG_TARGET_BINDING_FASTA=true
6 |
7 | STREAMABLE_HTTP=false
8 | SSE=false
9 | PORT=8001
10 |
--------------------------------------------------------------------------------
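
A minimal sketch of how these flags could be loaded with `pydantic-settings` (a declared dependency of `mammal_mcp`). The `McpSettings` class and its field defaults are illustrative assumptions, not the server's actual configuration code:

```python
# Hypothetical settings loader for the flags in .env.example above.
from pydantic_settings import BaseSettings, SettingsConfigDict


class McpSettings(BaseSettings):
    # Reads key=value pairs from .env; env-name matching is case-insensitive by default.
    model_config = SettingsConfigDict(env_file=".env")

    protein_protein_interaction: bool = False
    protein_solubility: bool = False
    tcr_epitope_binding: bool = False
    drug_target_binding: bool = False
    drug_target_binding_fasta: bool = True

    streamable_http: bool = False
    sse: bool = False
    port: int = 8001


settings = McpSettings()
print(settings.drug_target_binding_fasta, settings.port)  # True 8001
```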
/mammal/examples/tests/test_molnet.py:
--------------------------------------------------------------------------------
1 | import socket
2 |
3 | import pytest
4 |
5 | from mammal.examples.molnet.molnet_infer import load_model, task_infer
6 |
7 |
8 | @pytest.mark.skipif(
9 | "ccc" not in socket.gethostname(),
10 |     reason="Inference consumes too much memory for a Travis run.",
11 | )
12 | def test_infer():
13 | smiles_seq = "C(Cl)Cl"
14 | for task_name in ["BBBP", "TOXICITY", "FDA_APPR"]:
15 | task_dict = load_model(task_name=task_name, device="cpu")
16 | result = task_infer(task_dict=task_dict, smiles_seq=smiles_seq)
17 | print(f"The prediction for {smiles_seq=} is {result}")
18 |
--------------------------------------------------------------------------------
/CLONE.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | **Markdown**
4 |
5 | ```markdown
6 | [](https://github.com/MShawon/github-clone-count-badge)
7 |
8 | ```
9 |
10 | **HTML**
11 | ```html
12 |
13 | ```
14 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/---question.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: ❓ Question
3 | description: Use this template to ask a question about MAMMAL's code or its paper details.
4 | title: "[Q]: "
5 | labels:
6 | - "ty:question"
7 |
8 | body:
9 |
10 | - type: markdown
11 | attributes:
12 | value: >
13 |         **Thanks :heart: for taking the time to ask your question!** We kindly ask that you search to see if an
14 |         issue [already exists](https://github.com/BiomedSciAI/biomed-multi-alignment/issues) for your question.
15 |
16 | - type: textarea
17 | attributes:
18 | label: Ask your question
19 | value: |
20 |
21 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to MAMMAL🦍
2 |
3 | First off, thanks for taking the time to contribute!
4 |
5 | ## How do I report a bug or suggest an enhancement?
6 |
7 | - As a first step, please search in the [existing issues](https://github.com/BiomedSciAI/biomed-multi-alignment/issues) to check if your point has already been addressed.
8 | - If that is not the case, go ahead and [create an issue](https://github.com/BiomedSciAI/biomed-multi-alignment/issues/new/choose) of the respective type, providing the details as instructed in the template.
9 |
10 | ## How do I submit a change?
11 |
12 | We welcome contributions via pull requests:
13 |
14 | - Fork (or clone) the repo and create your branch from the up-to-date default branch.
15 | - If you have added code that should be tested, add tests.
16 | - If any documentation updates are needed, make them.
17 | - Ensure the test suite passes and the code lints.
18 | - Submit the pull request.
19 |
--------------------------------------------------------------------------------
/.github/workflows/github-actions.yml:
--------------------------------------------------------------------------------
1 | name: GitHub Actions - MAMMAL
2 |
3 | on:
4 | pull_request:
5 | branches: [ "main" ]
6 |
7 | jobs:
8 | build-linux:
9 | runs-on: ubuntu-latest
10 | strategy:
11 | fail-fast: false
12 | matrix:
13 | python-version: ["3.10", "3.11"]
14 |
15 | steps:
16 | - uses: actions/checkout@v4
17 | - name: Set up Python ${{ matrix.python-version }}
18 | uses: actions/setup-python@v3
19 | with:
20 | python-version: ${{ matrix.python-version }}
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install -q .[examples]
25 | pip install pre-commit
26 |
27 | - name: Pre-Commit Hooks
28 | run: |
29 | pre-commit install
30 | pre-commit run --all-files --show-diff-on-failure
31 | - name: Test with pytest
32 | run: |
33 | python -m pytest --capture=no .
34 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/---bug-report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: 🐞 Bug Report
3 | description: Use this template to report a bug in MAMMAL.
4 | title: "[Bug]: "
5 | labels:
6 | - "a:app"
7 | - "ty:bug"
8 |
9 | body:
10 | - type: markdown
11 | attributes:
12 | value: >
13 | **Thanks :heart: for taking the time to fill out this bug report!** We kindly ask that you search to see if an
14 | issue [already exists](https://github.com/BiomedSciAI/biomed-multi-alignment/issues) for the bug you encountered.
15 |
16 | - type: textarea
17 | attributes:
18 | label: Describe the bug
19 | description: >
20 | A clear and concise description of the issue you're experiencing. If relevant, add screenshots or videos to help
21 |         explain the bug, or any steps to reproduce the bug. If relevant, describe the expected behavior. Please include
22 | the following information:
23 |
24 | - python version
25 |
26 | - torch version
27 |
28 | - os type and version
29 | value: |
30 |
31 |
--------------------------------------------------------------------------------
/mammal/examples/tests/test_tcr_epitope_binding_inference.py:
--------------------------------------------------------------------------------
1 | from mammal.examples.tcr_epitope_binding.main_infer import load_model, task_infer
2 |
3 |
4 | def test_infer() -> None:
5 | """
6 | A test for TCR beta chain and epitope binding example on HF, https://huggingface.co/ibm/biomed.omics.bl.sm.ma-ted-458m
7 | """
8 | # positive 1:
9 | tcr_beta_seq = "NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT"
10 | epitope_seq = "LLQTGIHVRVSQPSL"
11 |
12 | # positive 2:
13 | # tcr_beta_seq = "GAVVSQHPSWVICKSGTSVKIECRSLDFQATTMFWYRQFPKQSLMLMATSNEGSKATYEQGVEKDKFLINHASLTLSTLTVTSAHPEDSSFYICSASEGTSSYEQYFGPGTRLTVT"
14 | # epitope_seq = "FLKEKGGL"
15 |
16 | model_inst, tokenizer_op = load_model(device="cpu")
17 | result = task_infer(
18 | model=model_inst,
19 | tokenizer_op=tokenizer_op,
20 | tcr_beta_seq=tcr_beta_seq,
21 | epitope_seq=epitope_seq,
22 | )
23 | print(f"The prediction for {epitope_seq} and {tcr_beta_seq} is {result}")
24 |
25 |
26 | if __name__ == "__main__":
27 | test_infer()
28 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/mammal/examples/scrna_cell_type/cell_type_mapping.csv:
--------------------------------------------------------------------------------
1 | celltype,cell_type_ontology_term_name,cell_type_ontology_term_id,comments
2 | "CD14+ Monocyte","CD14-positive monocyte",CL:0001054,
3 | "CD19+ B","B cell",CL:0000236,"exact mapping is CL:0001201 but it doesn't exists in cellxgene so a parent celltype was chosen"
4 | "CD34+","CD34-positive, CD38-positive common myeloid progenitor OR CD34-positive, CD38-positive common lymphoid progenitor",CL:0000049,"exact mapping is CL:0000995 but it doesn't exists in cellxgene so a parent celltype was chosen"
5 | "CD4+/CD25 T Reg","CD4-positive, CD25-positive, alpha-beta regulatory T cell",CL:0000792,
6 | "CD4+/CD45RA+/CD25- Naive T","naive thymus-derived CD4-positive, alpha-beta T cell",CL:0000895,
7 | "CD4+/CD45RO+ Memory","CD4-positive, alpha-beta memory T cell, CD45RO-positive",CL:0001204,
8 | "CD4+ T Helper2","T-helper 2 cell",CL:0000546,
9 | "CD56+ NK","CD16-positive, CD56-dim natural killer cell, human",CL:0000939,
10 | "CD8+/CD45RA+ Naive Cytotoxic","Effector Memory Cd8-Positive, Alpha-Beta T Cell, Terminally Differentiated",CL:0001062,"mapped also to parentsm, may need review"
11 | "CD8+ Cytotoxic T","CD8-positive, alpha-beta cytotoxic T cell",CL:0000794,
12 | "Dendritic","dendritic cell",CL:0000451,"maybe we need the child cell type - dendtiric, human"
13 |
--------------------------------------------------------------------------------
/mammal_mcp/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "mammal_mcp"
3 | version = "0.1.0"
4 | description = "Add your description here"
5 | authors = [
6 | { name = "IBM Research" },
7 | { name = "Ash Evans", email = "ash.evans@ibm.com" },
8 | { name = "Jennifer Kelly", email = "jennifer.kelly@ibm.com" },
9 | { name = "Laura-Jayne Gardiner", email = "laura-jayne.gardiner@ibm.com" }
10 | ]
11 | readme = "README.md"
12 | requires-python = ">=3.10, <3.12"
13 | dependencies = [
14 | "fastmcp>=2.6.1",
15 | "fuse-med-ml>=0.4.0",
16 | "mcp[cli]>=1.6.0",
17 | "mypy>=1.17.1",
18 | "pre-commit>=4.3.0",
19 | "pydantic>=2.11.1",
20 | "pydantic-settings>=2.8.1",
21 | "pytdc>=0.4.1",
22 | "pytest>=8.3.5",
23 | "python-dotenv>=1.1.0",
24 | "requests==2.24.0",
25 | "types-requests>=2.31.0.6",
26 | "uv>=0.6.11",
27 | "bio>=1.7.1",
28 | "biomed-multi-alignment",
29 | ]
30 |
31 | [dependency-groups]
32 | dev = [
33 | "types-requests>=2.31.0.6",
34 | ]
35 |
36 | [tool.uv.sources]
37 | biomed-multi-alignment = { git = "https://github.com/BiomedSciAI/biomed-multi-alignment.git", rev = "v0.2.2" }
38 |
39 | [project.optional-dependencies]
40 | dev = [
41 | "types-requests",
42 | "mypy",
43 | "types-PyYAML",
44 | "types-decorator",
45 | "types-simplejson",
46 | "types-tabulate",
47 | ]
48 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/---feature-request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: 🛠️ Feature Request
3 | description: Suggest an idea to help us improve MAMMAL
4 | title: "[Feature]: "
5 | labels:
6 | - "ty:feature"
7 |
8 | body:
9 | - type: markdown
10 | attributes:
11 | value: >
12 | **Thanks :heart: for taking the time to fill out this feature request report!** We kindly ask that you search to
13 | see if an issue [already exists](https://github.com/BiomedSciAI/biomed-multi-alignment/issues) for
14 | your feature.
15 |
16 | We are also happy to accept contributions from our users. For more details see
17 | [here](https://github.com/BiomedSciAI/biomed-multi-alignment/blob/main/CONTRIBUTING.md).
18 |
19 | - type: textarea
20 | attributes:
21 | label: Description
22 | description: |
23 | A clear and concise description of the feature you're interested in.
24 | value: |
25 |
26 | validations:
27 | required: true
28 |
29 | - type: textarea
30 | attributes:
31 | label: Suggested Solution
32 | description: >
33 | Describe the solution you'd like. A clear and concise description of what you want to happen. If you have
34 | considered alternatives, please describe them.
35 | value: |
36 |
37 | validations:
38 | required: false
39 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | exclude: .*\.pdb$
2 |
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v5.0.0
6 | hooks:
7 | - id: check-case-conflict
8 | - id: end-of-file-fixer
9 | - id: mixed-line-ending
10 | - id: trailing-whitespace
11 | - repo: https://github.com/psf/black
12 | rev: 25.1.0
13 | hooks:
14 | - id: black
15 | - repo: https://github.com/PyCQA/flake8
16 | rev: 7.1.2
17 | hooks:
18 | - id: flake8
19 | args:
20 | - "--ignore=E203,E266,E501,F405,F403,W503"
21 | - "--statistics"
22 |
23 | - repo: https://github.com/astral-sh/ruff-pre-commit
24 | # Ruff version.
25 | rev: v0.9.10
26 | hooks:
27 | - id: ruff
28 | args:
29 | - "--fix"
30 | - "--select"
31 | - "UP,PT,I,E"#,F,W,C90,I,N,F405,E402" # Specify the rules to select
32 | - "--line-length"
33 | - "88"
34 | - "--exit-non-zero-on-fix"
35 | - "--ignore"
36 | - "F405,F403,E501,E402,PT018,PT015,E722,E741"
37 | types_or: [ python, pyi] #, jupyter ]
38 | - repo: https://github.com/pre-commit/mirrors-mypy
39 | rev: v1.15.0
40 | hooks:
41 | - id: mypy
42 | additional_dependencies: [types-requests]
43 |
44 | - repo: https://github.com/srstevenson/nb-clean
45 | rev: "4.0.1"
46 | hooks:
47 | - id: nb-clean
48 | args:
49 | - --remove-empty-cells
50 | - --preserve-cell-outputs
51 |
--------------------------------------------------------------------------------
/mammal/examples/scrna_cell_type/data/process_h5ad_data.py:
--------------------------------------------------------------------------------
1 | import anndata
2 | import click
3 |
4 | from mammal.examples.scrna_cell_type.pl_data_module import preprocess_ann_data
5 |
6 |
7 | @click.command()
8 | @click.option(
9 | "--input-h5ad-file",
10 | "-i",
11 | prompt=True,
12 | help="name of input H5AD file",
13 | )
14 | @click.option(
15 | "--output-h5ad-file",
16 | "-o",
17 | prompt=True,
18 | help="name of output H5AD file",
19 | )
20 | @click.option(
21 | "--min-genes",
22 | "-m",
23 | type=click.INT,
24 | help="minimal number of different genes per cell. Used for filtering",
25 | default=200,
26 | )
27 | @click.option(
28 | "--num-bins",
29 | "-b",
30 | type=click.INT,
31 | help="number of expression bins to use",
32 | default=10,
33 | )
34 | def main(
35 | input_h5ad_file: str,
36 | output_h5ad_file: str,
37 | min_genes: int = 200,
38 | num_bins: int = 10,
39 | ):
40 |
41 | anndata_object = anndata.read_h5ad(input_h5ad_file)
42 |     # process the data: filter out cells with shallow reads, normalize read depth, and rescale to a log scale of roughly 0-10 (log_2(1001) ~= 10)
43 | preprocess_ann_data(
44 | anndata_object=anndata_object,
45 | min_genes=min_genes,
46 | num_bins=num_bins,
47 | )
48 | # Save result anndata object to disk
49 | anndata_object.write_h5ad(output_h5ad_file)
50 | print(f"processed AnnData file saved to {output_h5ad_file}")
51 |
52 |
53 | if __name__ == "__main__":
54 | main()
55 |
--------------------------------------------------------------------------------
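
A minimal programmatic sketch of the same preprocessing, assuming `preprocess_ann_data` mutates the AnnData object in place (as the CLI above implies); the synthetic count matrix is illustrative only:

```python
# Equivalent CLI: python process_h5ad_data.py -i raw.h5ad -o preprocessed.h5ad -m 200 -b 10
import anndata
import numpy as np

from mammal.examples.scrna_cell_type.pl_data_module import preprocess_ann_data

rng = np.random.default_rng(0)
# toy data: 100 cells x 500 genes of raw counts
adata = anndata.AnnData(X=rng.poisson(2.0, size=(100, 500)).astype(np.float32))

# filter shallow cells, normalize depth, log-scale and bin expression values
preprocess_ann_data(anndata_object=adata, min_genes=200, num_bins=10)
adata.write_h5ad("preprocessed.h5ad")
```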
/mammal/examples/scrna_cell_type/anndata_op.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from anndata import AnnData
3 | from fuse.data import OpBase
4 | from fuse.utils.ndict import NDict
5 |
6 | from mammal.keys import SAMPLE_ID
7 |
8 |
9 | class OpReadAnnData(OpBase):
10 | """
10 |     Op reading data from an AnnData object.
11 |     Each row will be added as a value to the sample dict.
13 | """
14 |
15 | def __init__(
16 | self,
17 | data: AnnData | None = None,
18 | key_name: str = SAMPLE_ID,
19 | label_column: str = "label",
20 | ):
21 | """
22 | :param data: input AnnData object
23 | :param key_name: name of value in sample_dict which will be used as the key/index
24 | :param label_column: name of the column which contains the label
25 | """
26 | super().__init__()
27 |
28 | self._key_name = key_name
29 | self._data = data
30 | self.label_column = label_column
31 | self.gene_names = np.array(self._data.var_names)
32 |
33 | def __call__(
34 | self, sample_dict: NDict, prefix: str | None = None
35 | ) -> None | dict | list[dict]:
36 | """
37 | See base class
38 |
39 | :param prefix: specify a prefix for the sample dict keys.
40 | For example, with prefix 'data.features' and a df with the columns ['height', 'weight', 'sex'],
41 | the matching keys will be: 'data.features.height', 'data.features.weight', 'data.features.sex'.
42 | """
43 |
44 | key = sample_dict[self._key_name]
45 |
46 | # locate the required item
47 | sample_dict[f"{prefix}.scrna"] = self._data[key, :].X
48 | sample_dict["data.label"] = self._data.obs.iloc[key].get(self.label_column)
49 | sample_dict[f"{prefix}.gene_names"] = self.gene_names
50 |
51 | return sample_dict
52 |
--------------------------------------------------------------------------------
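
A minimal usage sketch for `OpReadAnnData`, assuming a toy AnnData with a "label" column in `.obs` and integer sample ids indexing rows (as the op's slicing and `iloc` imply):

```python
import anndata
import numpy as np
import pandas as pd
from fuse.utils.ndict import NDict

from mammal.examples.scrna_cell_type.anndata_op import OpReadAnnData
from mammal.keys import SAMPLE_ID

adata = anndata.AnnData(
    X=np.eye(3, dtype=np.float32),  # 3 toy cells x 3 genes
    obs=pd.DataFrame({"label": ["B cell", "T cell", "NK cell"]}),
)
op = OpReadAnnData(data=adata, label_column="label")

sample_dict = NDict({SAMPLE_ID: 1})  # read the second row
sample_dict = op(sample_dict, prefix="data.query")
print(sample_dict["data.query.scrna"], sample_dict["data.label"])
```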
/mammal_mcp/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
3 |
4 |
5 | def process_model_output(
6 | tokenizer_op: ModularTokenizerOp,
7 | decoder_output: np.ndarray,
8 | decoder_output_scores: np.ndarray,
9 | ) -> dict | None:
10 |     """
11 |     Extract the predicted binary class and its scores from the decoder output,
12 |     expecting the predicted class token to be <0> or <1>.
13 |     Note: the normalized score is the positive ('<1>') score divided by the sum of the scores for both '<0>' and '<1>'.
14 |     BE CAREFUL as both negative and positive absolute scores can be drastically low while the normalized score is still very high.
15 |     Outputs a dictionary containing:
16 |     dict(
17 |         pred = # predicted class as an int, e.g. 1 (-1 if the predicted token is neither <0> nor <1>)
18 |         not_normalized_scores = # the score for the positive token, e.g. 0.01
19 |         normalized_scores = # (positive_token_score) / (positive_token_score + negative_token_score)
20 |     )
21 |     If no decoder output scores are available, None is returned.
22 | """
23 |
24 | negative_token_id = tokenizer_op.get_token_id("<0>")
25 | positive_token_id = tokenizer_op.get_token_id("<1>")
26 | label_id_to_int = {
27 | negative_token_id: 0,
28 | positive_token_id: 1,
29 | }
30 | classification_position = 1
31 |
32 |     # None is returned when no scores are available to parse
33 |     ans = None
34 |     if decoder_output_scores is not None:
35 |         not_normalized_score = decoder_output_scores[
36 |             classification_position, positive_token_id
37 |         ]
38 |         normalized_score = not_normalized_score / (
39 |             not_normalized_score
40 |             + decoder_output_scores[classification_position, negative_token_id]
41 |             + 1e-10
42 |         )
43 |         ans = dict(
44 |             pred=label_id_to_int.get(int(decoder_output[classification_position]), -1),
45 |             not_normalized_scores=not_normalized_score,
46 |             normalized_scores=normalized_score,
47 |         )
48 |
49 |     return ans
50 |
--------------------------------------------------------------------------------
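
A self-contained sketch of `process_model_output` with dummy decoder outputs shaped like those the inference examples in this repo pass in (token ids, plus per-position vocabulary scores); the `mammal_mcp.util` import path is an assumption:

```python
import numpy as np
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp

from mammal_mcp.util import process_model_output  # import path assumed

tokenizer_op = ModularTokenizerOp.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m")
neg_id = tokenizer_op.get_token_id("<0>")
pos_id = tokenizer_op.get_token_id("<1>")

# dummy outputs: the class token sits at position 1 (classification_position)
decoder_output = np.array([0, pos_id, 0])
decoder_output_scores = np.zeros((3, max(neg_id, pos_id) + 1))
decoder_output_scores[1, pos_id] = 0.02
decoder_output_scores[1, neg_id] = 0.01

ans = process_model_output(
    tokenizer_op=tokenizer_op,
    decoder_output=decoder_output,
    decoder_output_scores=decoder_output_scores,
)
print(ans)  # pred=1, not_normalized ~0.02, normalized ~0.67
```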
/mammal/examples/tests/test_simple_inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
3 |
4 | from mammal.keys import *
5 | from mammal.model import Mammal
6 |
7 |
8 | def test_simple_infer() -> None:
9 | """
10 |     A simple function that tests the proposed inference example on HF, https://huggingface.co/ibm/biomed.omics.bl.sm.ma-ted-458m
11 | """
12 |
13 | # Load Model
14 | model = Mammal.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m")
15 | model.eval()
16 |
17 | # Load Tokenizer
18 | tokenizer_op = ModularTokenizerOp.from_pretrained(
19 | "ibm/biomed.omics.bl.sm.ma-ted-458m"
20 | )
21 |
22 | # Prepare Input Prompt
23 | protein_calmodulin = "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK"
24 | protein_calcineurin = "MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ"
25 |
26 | # Create and load sample
27 | sample_dict = dict()
28 | # Formatting prompt to match pre-training syntax
29 | sample_dict[ENCODER_INPUTS_STR] = (
30 | f"<@TOKENIZER-TYPE=AA>{protein_calmodulin}{protein_calcineurin}"
31 | )
32 |
33 | # Tokenize
34 | tokenizer_op(
35 | sample_dict=sample_dict,
36 | key_in=ENCODER_INPUTS_STR,
37 | key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
38 | key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
39 | )
40 | sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
41 | sample_dict[ENCODER_INPUTS_TOKENS]
42 | )
43 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
44 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK]
45 | )
46 |
47 | # Generate Prediction
48 | batch_dict = model.generate(
49 | [sample_dict],
50 | output_scores=True,
51 | return_dict_in_generate=True,
52 | max_new_tokens=5,
53 | )
54 |
55 | # Get output
56 | generated_output = tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])
57 | print(f"{generated_output=}")
58 |
--------------------------------------------------------------------------------
/mammal/examples/carcinogenicity/config.yaml:
--------------------------------------------------------------------------------
1 | name: carcinogenicity_finetune
2 | root: "."
3 | model_dir: ${root}/${name}
4 | seed: 4224
5 |
6 | task:
7 | _target_: mammal.examples.carcinogenicity.task.CarcinogenicityTask
8 | _partial_: True
9 |   # details about the arguments below can be found in mammal.examples.carcinogenicity.task.CarcinogenicityTask()
10 |
11 | data_module_kwargs:
12 |     # details about the arguments below can be found in mammal.examples.carcinogenicity.pl_data_module.CarcinogenicityDataModule()
13 | batch_size: 15
14 | drug_max_seq_length: 300 # Maximum drug length in the dataset is 292.
15 | encoder_input_max_seq_len: 320 # 20 chars buffer for special tokens.
16 | labels_max_seq_len: 4
17 |
18 |
19 | # tokenizer
20 | tokenizer:
21 | tokenizer_path: ibm/biomed.omics.bl.sm.ma-ted-458m
22 | new_special_tokens:
23 | - ""
24 |
25 | model:
26 | mammal_kwargs: null # arguments for Mammal.__init__()
27 |   pretrained_kwargs: # arguments for Mammal.from_pretrained(), which is triggered only if pretrained_model_name_or_path is not None
28 | pretrained_model_name_or_path: ibm/biomed.omics.bl.sm.ma-ted-458m
29 |
30 | # lightning module
31 | module:
32 | opt_callable:
33 | _target_: torch.optim.AdamW
34 | _partial_: true
35 | # arguments for torch.optim.AdamW()
36 | lr: 0.00001
37 |
38 | lr_sch_callable:
39 | _target_: mammal.lr_schedulers.cosine_annealing_with_warmup_lr_scheduler
40 | _partial_: True
41 | T_max: 10000
42 | num_warmup_steps: 300
43 | eta_min_factor: 0.1
44 |
45 | model_dir: ${model_dir}
46 | best_epoch_source:
47 |     monitor: validation.metrics.carcinogenicity_acc # possible options are the keys under validation.metrics
48 |
49 | mode: max
50 |
51 | # train
52 | trainer:
53 | # arguments for pytorch_lightning.Trainer()
54 | max_epochs: 100
55 | default_root_dir: ${model_dir}
56 | accelerator: "auto"
57 | devices: 1
58 | num_nodes: 1
59 | num_sanity_val_steps: 0
60 |
61 |
62 | # experiment tracker
63 | track_clearml: # arguments for fuse.dl.lightning.pl_funcs.start_clearml_logger
64 | project_name: "mammal/opensource"
65 | task_name: ${name}
66 | tags: "mammal"
67 | reuse_last_task_id: True
68 | continue_last_task: False
69 | offline_mode: False
70 |
71 | evaluate: false # if true, runs lightning's validate on the test dataloader
72 |
73 | hydra:
74 | run:
75 | dir: ${model_dir}
76 | job:
77 | chdir: False
78 |
--------------------------------------------------------------------------------
/mammal/examples/scrna_cell_type/config.yaml:
--------------------------------------------------------------------------------
1 | name: cell_type_finetune
2 | root: "."
3 | model_dir: ${root}/${name}
4 | seed: 2024
5 |
6 | task:
7 | _target_: mammal.examples.scrna_cell_type.task.CellTypeTask
8 | _partial_: True
9 | data_module_kwargs:
10 | data_path: "data/Zheng_68k_preprocessed.h5ad" # this should be absolute or relative to the directory with the example code
11 | # this is the name of the observation the model will try to predict
12 | label_name: "cell-type"
13 | batch_size: 20
14 | # tokenizer_op is provided later, dynamically
15 | train_dl_kwargs: # Dataloader constructor parameters
16 | num_workers: 8
17 | valid_dl_kwargs: # Dataloader constructor parameters
18 | num_workers: 8
19 | # data_preprocessing is provided later, dynamically
20 | input_max_seq_length: 500
21 | encoder_input_max_seq_len: 512
22 | labels_max_seq_len: 20
23 |
24 | # tokenizer
25 | tokenizer:
26 | tokenizer_path: ibm-research/biomed.omics.bl.sm.ma-ted-458m
27 |
28 | model:
29 |   pretrained_kwargs: # arguments for Mammal.from_pretrained(), which is triggered only if pretrained_model_name_or_path is not None
30 | pretrained_model_name_or_path: ibm-research/biomed.omics.bl.sm.ma-ted-458m
31 | # config_overrides:
32 | # use_lora: True
33 |
34 | # lightning module
35 | module:
36 | opt_callable:
37 | _target_: torch.optim.AdamW
38 |     _partial_: true # the model parameters are supplied later, at call time
39 | lr: 0.00001
40 |
41 | lr_sch_callable:
42 | _target_: mammal.lr_schedulers.cosine_annealing_with_warmup_lr_scheduler
43 | _partial_: True
44 | T_max: 10000
45 | num_warmup_steps: 300
46 | eta_min_factor: 0.1
47 |
48 | model_dir: ${model_dir}
49 | best_epoch_source:
50 | monitor: validation.metrics.cell_type_acc # see metrics.py:classification_metrics
51 | mode: max
52 |
53 | # train
54 | trainer:
55 | # arguments for pytorch_lightning.Trainer()
56 | max_epochs: 100
57 | default_root_dir: ${model_dir}
58 | num_sanity_val_steps: 0
59 | # val_check_interval: 0.1
60 |
61 |
62 | # experiment tracker
63 | track_clearml: # arguments for fuse.dl.lightning.pl_funcs.start_clearml_logger
64 | project_name: "mammal/opensource"
65 | task_name: ${name}
66 | tags: "mammal"
67 | reuse_last_task_id: True
68 | continue_last_task: False
69 | offline_mode: False
70 |
71 | evaluate: false # if true, runs lightning's validate on the test dataloader
72 |
73 | hydra:
74 | run:
75 | dir: ${model_dir}
76 | job:
77 | chdir: False
78 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | t.py
131 | sol_test
132 |
133 | .vscode/settings.json
134 | .vscode/launch.json
135 | example_solubility_data
136 |
--------------------------------------------------------------------------------
/mammal/examples/carcinogenicity/main_infer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import click
4 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
5 |
6 | from mammal.examples.carcinogenicity.task import CarcinogenicityTask
7 | from mammal.keys import CLS_PRED, SCORES
8 | from mammal.model import Mammal
9 |
10 |
11 | @click.command()
12 | @click.argument("finetune_output_dir")
13 | @click.argument("drug_seq")
14 | @click.option(
15 | "--device", default="cpu", help="Specify the device to use (default: 'cpu')."
16 | )
17 | def main(finetune_output_dir: str, drug_seq: str, device: str) -> None:
18 | click.echo(f"Using device: {device}")
19 | infer(finetune_output_dir=finetune_output_dir, drug_seq=drug_seq, device=device)
20 |
21 |
22 | def infer(finetune_output_dir: str, drug_seq: str, device: str) -> dict:
23 | """
24 | :param finetune_output_dir: model_dir argument in fine-tuning
25 | :param drug_seq: smiles sequence of a drug
26 | """
27 | # Load tokenizer from the checkpoint dir.
28 | # NOTE It's important to load the tokenizer from the fine-tuning phase, since we've introduced a new token to it.
29 | tokenizer_op = ModularTokenizerOp.from_pretrained(
30 | os.path.join(finetune_output_dir, "tokenizer")
31 | )
32 |
33 | # Load model from the best checkpoint.
34 | # NOTE The total order of the checkpoints is induced by the monitored metric (see config)
35 | nn_model = Mammal.from_pretrained(
36 | pretrained_model_name_or_path=os.path.join(
37 | finetune_output_dir, "best_epoch.ckpt"
38 | )
39 | )
40 | nn_model.eval()
41 | nn_model.to(device=device)
42 |
43 | # Format the input drug sequence value into a prompt that fits MAMMAL's training paradigm.
44 | sample_dict = {"drug_seq": drug_seq}
45 | sample_dict = CarcinogenicityTask.data_preprocessing(
46 | sample_dict=sample_dict,
47 | sequence_key="drug_seq",
48 | tokenizer_op=tokenizer_op,
49 | device=nn_model.device,
50 | )
51 |
52 | # running in generate mode
53 | batch_dict = nn_model.generate(
54 | [sample_dict],
55 | output_scores=True,
56 | return_dict_in_generate=True,
57 | max_new_tokens=5,
58 | )
59 |
60 | # Post-process the model's output
61 | ans = CarcinogenicityTask.process_model_output(
62 | tokenizer_op=tokenizer_op,
63 | decoder_output=batch_dict[CLS_PRED][0],
64 | decoder_output_scores=batch_dict[SCORES][0],
65 | )
66 |
67 | # Print prediction
68 | print(ans)
69 | return ans
70 |
71 |
72 | if __name__ == "__main__":
73 | main()
74 |
--------------------------------------------------------------------------------
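
The command can also be driven programmatically through click's testing helper; a minimal sketch (the checkpoint directory below is a placeholder for a real fine-tuning output dir):

```python
from click.testing import CliRunner

from mammal.examples.carcinogenicity.main_infer import main

runner = CliRunner()
# positional args: FINETUNE_OUTPUT_DIR DRUG_SEQ (plus the optional --device)
result = runner.invoke(main, ["./carcinogenicity_finetune", "CC(CCl)OC(C)CCl", "--device", "cpu"])
print(result.output)
```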
/mammal/examples/protein_solubility/config.yaml:
--------------------------------------------------------------------------------
1 | name: mammal_solubility_finetune
2 | root: "."
3 | model_dir: ${root}/${name}
4 | seed: 1234
5 |
6 | task:
7 | _target_: mammal.examples.protein_solubility.task.ProteinSolubilityTask
8 | _partial_: True
9 |   # details about the arguments below can be found in mammal.examples.protein_solubility.task.ProteinSolubilityTask()
10 | name: solubility_prediction
11 | seed: ${seed}
12 | data_module_kwargs:
13 |     # details about the arguments below can be found in mammal.examples.protein_solubility.pl_data_module.ProteinSolubilityDataModule()
14 | data_path: ./example_solubility_data
15 |
16 | train_dl_kwargs: # Dataloader constructor parameters
17 | num_workers: 8
18 | valid_dl_kwargs: # Dataloader constructor parameters
19 | num_workers: 8
20 |
21 | batch_size: 6
22 | protein_max_seq_length: 1250
23 | encoder_input_max_seq_len: 1260
24 | labels_max_seq_len: 4
25 |
26 | # tokenizer
27 | tokenizer:
28 | tokenizer_path: ibm/biomed.omics.bl.sm.ma-ted-458m
29 |
30 | model:
31 | mammal_kwargs: null # arguments for Mammal.__init__()
32 |   pretrained_kwargs: # arguments for Mammal.from_pretrained(), which is triggered only if pretrained_model_name_or_path is not None
33 | pretrained_model_name_or_path: ibm/biomed.omics.bl.sm.ma-ted-458m
34 |
35 | # lightning module
36 | module:
37 | opt_callable:
38 | _target_: torch.optim.AdamW
39 | _partial_: true
40 | # arguments for torch.optim.AdamW()
41 | lr: 0.00001
42 |
43 | lr_sch_callable:
44 | _target_: mammal.lr_schedulers.cosine_annealing_with_warmup_lr_scheduler
45 | _partial_: True
46 | # arguments for mammal.lr_schedulers.cosine_annealing_with_warmup_lr_scheduler()
47 | T_max: 20000
48 | eta_min_factor: 0.1
49 |
50 | model_dir: ${model_dir}
51 | best_epoch_source:
52 | # arguments for pytorch_lightning.callbacks.ModelCheckpoint()
53 |     monitor: validation.metrics.solubility_prediction_acc # possible options are the keys under validation.metrics
54 |
55 | mode: max
56 |
57 | # train
58 | trainer:
59 | # arguments for pytorch_lightning.Trainer()
60 | max_epochs: 1000
61 | default_root_dir: ${model_dir}
62 | accelerator: "auto"
63 | devices: 1
64 | num_nodes: 1
65 | strategy: "ddp_find_unused_parameters_true"
66 | use_distributed_sampler: False # Must be set when using a batch sampler
67 | num_sanity_val_steps: 0
68 | # limit_train_batches: 128
69 | # limit_val_batches: 128
70 |
71 | # experiment tracker
72 | track_clearml: # arguments for fuse.dl.lightning.pl_funcs.start_clearml_logger
73 | project_name: "mammal/opensource"
74 | task_name: ${name}
75 | tags: "mammal"
76 | reuse_last_task_id: True
77 | continue_last_task: False
78 | offline_mode: False
79 |
80 | evaluate: false # if true, runs lightning's validate on the test dataloader
81 |
82 | hydra:
83 | run:
84 | dir: ${model_dir}
85 | job:
86 | chdir: False
87 |
--------------------------------------------------------------------------------
/mammal/examples/protein_solubility/main_infer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import click
4 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
5 |
6 | from mammal.examples.protein_solubility.task import ProteinSolubilityTask
7 | from mammal.keys import CLS_PRED, SCORES
8 | from mammal.model import Mammal
9 |
10 |
11 | @click.command()
12 | @click.option(
13 | "--finetune_output_dir",
14 | default=None,
15 | help="Specify the model dir (default: None to download from huggingface).",
16 | )
17 | @click.argument(
18 | "protein_seq",
19 | default="NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC",
20 | )
21 | @click.option(
22 | "--device", default="cpu", help="Specify the device to use (default: 'cpu')."
23 | )
24 | def main(finetune_output_dir: str, protein_seq: str, device: str):
25 | protein_solubility_infer(
26 | finetune_output_dir=finetune_output_dir, protein_seq=protein_seq, device=device
27 | )
28 |
29 |
30 | def protein_solubility_infer(
31 | finetune_output_dir: str | None, protein_seq: str, device: str
32 | ):
33 | """
34 | :param finetune_output_dir: model_dir argument in finetuning or None for downloading from huggingface
35 | :param protein_seq: amino acid sequence of a protein
36 | """
37 | if finetune_output_dir is not None:
38 | # load tokenizer and model from finetune_output_dir
39 | tokenizer_op = ModularTokenizerOp.from_pretrained(
40 | os.path.join(finetune_output_dir, "tokenizer")
41 | )
42 | nn_model = Mammal.from_pretrained(
43 | pretrained_model_name_or_path=os.path.join(
44 | finetune_output_dir, "best_epoch.ckpt"
45 | )
46 | )
47 | else:
48 | # load tokenizer and model from huggingface
49 | tokenizer_op = ModularTokenizerOp.from_pretrained(
50 | "ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility"
51 | )
52 | nn_model = Mammal.from_pretrained(
53 | "ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility"
54 | )
55 | nn_model.eval()
56 | nn_model.to(device=device)
57 |
58 | # convert to MAMMAL style
59 | sample_dict = {"protein_seq": protein_seq}
60 | sample_dict = ProteinSolubilityTask.data_preprocessing(
61 | sample_dict=sample_dict,
62 | protein_sequence_key="protein_seq",
63 | tokenizer_op=tokenizer_op,
64 | device=nn_model.device,
65 | )
66 |
67 | # running in generate mode
68 | batch_dict = nn_model.generate(
69 | [sample_dict],
70 | output_scores=True,
71 | return_dict_in_generate=True,
72 | max_new_tokens=5,
73 | )
74 |
75 | # Post-process the model's output
76 | ans = ProteinSolubilityTask.process_model_output(
77 | tokenizer_op=tokenizer_op,
78 | decoder_output=batch_dict[CLS_PRED][0],
79 | decoder_output_scores=batch_dict[SCORES][0],
80 | )
81 |
82 | # Print prediction
83 | print(ans)
84 | return ans
85 |
86 |
87 | if __name__ == "__main__":
88 | main()
89 |
--------------------------------------------------------------------------------
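
Because `--finetune_output_dir` defaults to None, the function can also be called directly to pull the published fine-tuned checkpoint from Hugging Face:

```python
from mammal.examples.protein_solubility.main_infer import protein_solubility_infer

# None -> loads ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility from the hub
ans = protein_solubility_infer(
    finetune_output_dir=None,
    protein_seq="NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC",
    device="cpu",
)
print(ans)
```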
/mammal/examples/dti_bindingdb_kd/config.yaml:
--------------------------------------------------------------------------------
1 | name: mammal_tdc_dti_bindingdb_kd
2 | root: "."
3 | model_dir: ${root}/${name}
4 | seed: 1234
5 |
6 |
7 | task:
8 | _target_: mammal.examples.dti_bindingdb_kd.task.DtiBindingdbKdTask
9 | _partial_: True
10 |   # details about the arguments below can be found in mammal.examples.dti_bindingdb_kd.task.DtiBindingdbKdTask()
11 | name: dti_bindingdb_kd
12 | seed: ${seed}
13 | norm_y_mean: 5.79384684128215
14 | norm_y_std: 1.33808027428196
15 |
16 | data_module_kwargs:
17 |     # details about the arguments below can be found in mammal.examples.dti_bindingdb_kd.pl_data_module.DtiBindingdbKdDataModule()
18 | load_datasets_kwargs:
19 | split_type: "cold_split"
20 | split_column: ["Drug", "Target"]
21 | train_dl_kwargs: # Dataloader constructor parameters
22 | num_workers: 8
23 | valid_dl_kwargs: # Dataloader constructor parameters
24 | num_workers: 8
25 |
26 | batch_size: 8 # over a100_80g
27 | target_max_seq_length: 1250
28 | drug_max_seq_length: 256
29 | encoder_input_max_seq_len: 1560
30 |
31 |
32 | # tokenizer
33 | tokenizer:
34 | tokenizer_path: ibm/biomed.omics.bl.sm.ma-ted-458m
35 |
36 | model:
37 | mammal_kwargs: null # arguments for Mammal.__init__()
38 |   pretrained_kwargs: # arguments for Mammal.from_pretrained(), which is triggered only if pretrained_model_name_or_path is not None
39 | pretrained_model_name_or_path: ibm/biomed.omics.bl.sm.ma-ted-458m
40 |
41 | # lightning module
42 | module:
43 | opt_callable:
44 | _target_: torch.optim.AdamW
45 | _partial_: true
46 | # arguments for torch.optim.AdamW()
47 | lr: 0.00001
48 |
49 | lr_sch_callable:
50 | _target_: mammal.lr_schedulers.cosine_annealing_with_warmup_lr_scheduler
51 | _partial_: True
52 | # arguments for mammal.lr_schedulers.cosine_annealing_with_warmup_lr_scheduler()
53 | T_max: 100000
54 | eta_min_factor: 0.1
55 |
56 | model_dir: ${model_dir}
57 | best_epoch_source:
58 | # arguments for pytorch_lightning.callbacks.ModelCheckpoint()
59 |     monitor: validation.losses.dti_bindingdb_kd_scalars_mse # possible options are the keys under validation.losses and validation.metrics
60 |
61 | mode: min
62 |
63 | # train
64 | trainer:
65 | # arguments for pytorch_lightning.Trainer()
66 | max_epochs: 1000
67 | default_root_dir: ${model_dir}
68 | accelerator: "auto"
69 | devices: 1
70 | num_nodes: 1
71 | strategy: "ddp_find_unused_parameters_true"
72 | use_distributed_sampler: False # Must be set when using a batch sampler
73 | num_sanity_val_steps: 0
74 | # limit_train_batches: 128
75 | # limit_val_batches: 128
76 |
77 | # experiment tracker
78 | track_clearml: # arguments for fuse.dl.lightning.pl_funcs.start_clearml_logger
79 | project_name: "mammal/opensource"
80 | task_name: ${name}
81 | tags: "mammal"
82 | reuse_last_task_id: True
83 | continue_last_task: False
84 | offline_mode: False
85 |
86 | evaluate: false # if true, runs lightning's validate on the test dataloader
87 |
88 | hydra:
89 | run:
90 | dir: ${model_dir}
91 | job:
92 | chdir: False
93 |
--------------------------------------------------------------------------------
/mammal/examples/tests/test_protein_solubility_prediction.py:
--------------------------------------------------------------------------------
1 | import socket
2 | from pathlib import Path
3 |
4 | import hydra
5 | import pytest
6 | from hydra.core.global_hydra import GlobalHydra
7 |
8 | from mammal.examples.protein_solubility.main_infer import protein_solubility_infer
9 | from mammal.main_finetune import main as main_finetune
10 |
11 | TEST_CONFIG_DIRPATH = str(Path(__file__).parents[0] / "../protein_solubility")
12 | TEST_CONFIG_FILENAME = "config.yaml"
13 |
14 |
15 | @pytest.fixture(autouse=True, scope="session")
16 | def _clean_hydra() -> None:
17 | GlobalHydra.instance().clear()
18 |
19 |
20 | @pytest.fixture(scope="session")
21 | def model_dir(tmp_path_factory: pytest.TempPathFactory):
22 | if "ccc" not in socket.gethostname():
23 |         pytest.skip("Full tests require resources")
24 |
25 | model_dir_path = tmp_path_factory.mktemp("test_protein_solubility") / "test"
26 | return model_dir_path
27 |
28 |
29 | def test_finetune(model_dir: str):
30 | print(model_dir)
31 | OVERRIDES = [
32 | "track_clearml=null", # Travis cannot connect to ClearML at the moment. We might be able to fix it with a dedicated user + config credentials.
33 | "trainer.max_epochs=2", # Small number for a faster run.
34 | "+trainer.limit_train_batches=3", # Small number for a faster run.
35 | "+trainer.limit_val_batches=2", # Small number for a faster run.
36 | f"model_dir={model_dir}",
37 | "task.data_module_kwargs.train_dl_kwargs.num_workers=0", # Using parallelization cause co
38 | "task.data_module_kwargs.valid_dl_kwargs.num_workers=0", # Using parallelization cause co
39 | "root=.",
40 | "name=sol_test",
41 | ]
42 | with hydra.initialize_config_dir(TEST_CONFIG_DIRPATH, version_base="1.1"):
43 | _cfg = hydra.compose(TEST_CONFIG_FILENAME, overrides=OVERRIDES)
44 | cfg = hydra.utils.instantiate(_cfg)
45 | main_finetune(cfg)
46 |
47 |
48 | def test_evaluate(model_dir: str):
49 | OVERRIDES = [
50 | "track_clearml=null", # Travis cannot connect to ClearML at the moment. We might be able to fix it with a dedicated user + config credentials.
51 | "trainer.max_epochs=1", # Small number for a faster run.
52 | "+trainer.limit_test_batches=10", # Small number for a faster run.
53 | f"model_dir={model_dir}",
54 | "task.data_module_kwargs.train_dl_kwargs.num_workers=0", # Using parallelization cause co
55 | "task.data_module_kwargs.valid_dl_kwargs.num_workers=0", # Using parallelization cause co
56 | "root=.",
57 | "name=sol_test",
58 | "evaluate=True",
59 | f"model.pretrained_kwargs.pretrained_model_name_or_path={model_dir}/best_epoch.ckpt",
60 | ]
61 | with hydra.initialize_config_dir(TEST_CONFIG_DIRPATH, version_base="1.1"):
62 | _cfg = hydra.compose(TEST_CONFIG_FILENAME, overrides=OVERRIDES)
63 | cfg = hydra.utils.instantiate(_cfg)
64 | main_finetune(cfg)
65 |
66 |
67 | def test_infer(model_dir: str):
68 | protein_solubility_infer(
69 | finetune_output_dir=model_dir,
70 | protein_seq="NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC",
71 | device="cpu",
72 | )
73 |
--------------------------------------------------------------------------------
/mammal/examples/tests/test_drug_carcinogenicity_classification.py:
--------------------------------------------------------------------------------
1 | import socket
2 | from pathlib import Path
3 |
4 | import hydra
5 | import pytest
6 | from hydra.core.global_hydra import GlobalHydra
7 |
8 | from mammal.examples.carcinogenicity.main_infer import infer
9 | from mammal.main_finetune import main as main_finetune
10 | from mammal.model import Mammal
11 |
12 | TEST_CONFIG_DIRPATH = str(Path(__file__).parents[0] / "../carcinogenicity")
13 | TEST_CONFIG_FILENAME = "config.yaml"
14 |
15 |
16 | @pytest.fixture(autouse=True, scope="session")
17 | def _clean_hydra() -> None:
18 | GlobalHydra.instance().clear()
19 |
20 |
21 | @pytest.fixture(scope="session")
22 | def tmp_model_dir(tmp_path_factory):
23 | model_dir_path = tmp_path_factory.mktemp("test_carcinogenicity") / "test"
24 | return model_dir_path
25 |
26 |
27 | @pytest.fixture(scope="session")
28 | def finetuned_model_dir(tmp_model_dir: str):
29 | if "ccc" not in socket.gethostname():
30 |         pytest.skip("Full tests require resources")
31 | model_dir = tmp_model_dir
32 | print(f"\n{model_dir=}")
33 | OVERRIDES = [
34 | "track_clearml=null", # Travis cannot connect to ClearML at the moment. We might be able to fix it with a dedicated user + config credentials.
35 | "trainer.max_epochs=2", # Small number for a faster run.
36 | "+trainer.limit_train_batches=3", # Small number for a faster run.
37 | "+trainer.limit_val_batches=2", # Small number for a faster run.
38 | f"model_dir={model_dir}",
39 | "root=.",
40 | "name=carcinogenicity_test",
41 | ]
42 | with hydra.initialize_config_dir(TEST_CONFIG_DIRPATH, version_base="1.1"):
43 | _cfg = hydra.compose(TEST_CONFIG_FILENAME, overrides=OVERRIDES)
44 | main_finetune(_cfg)
45 | return model_dir
46 |
47 |
48 | def test_finetune(finetuned_model_dir: Path):
50 |     # the actual work is done in the fixture; here we just read the finetuned model.
50 | model = Mammal.from_pretrained(
51 | pretrained_model_name_or_path=str(finetuned_model_dir / "best_epoch.ckpt")
52 | )
53 | assert model is not None
54 |
55 |
56 | def test_evaluate(finetuned_model_dir: str):
57 | OVERRIDES = [
58 | "track_clearml=null", # Travis cannot connect to ClearML at the moment. We might be able to fix it with a dedicated user + config credentials.
59 | "trainer.max_epochs=1", # Small number for a faster run.
60 | "+trainer.limit_test_batches=10", # Small number for a faster run.
61 | f"model_dir={finetuned_model_dir}",
62 | "root=.",
63 | "name=carcinogenicity_test",
64 | "evaluate=True",
65 | f"model.pretrained_kwargs.pretrained_model_name_or_path={finetuned_model_dir}/best_epoch.ckpt",
66 | ]
67 | with hydra.initialize_config_dir(TEST_CONFIG_DIRPATH, version_base="1.1"):
68 | _cfg = hydra.compose(TEST_CONFIG_FILENAME, overrides=OVERRIDES)
69 |         # main_finetune does "evaluate" if evaluate=True, as is set above.
70 | main_finetune(_cfg)
71 |
72 |
73 | def test_infer(finetuned_model_dir: str):
74 | infer(
75 | finetune_output_dir=finetuned_model_dir,
76 | drug_seq="CC(CCl)OC(C)CCl",
77 | device="cpu",
78 | )
79 |
--------------------------------------------------------------------------------
/mammal/lora.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from peft import LoraConfig, TaskType, get_peft_model
3 | from peft.utils import PeftType
4 | from transformers import PreTrainedModel
5 |
6 |
7 | def get_lora_model(
8 | model: PreTrainedModel,
9 | peft_type: str | PeftType | None = None,
10 | auto_mapping: dict | None = None,
11 | base_model_name_or_path: str | None = None,
12 | revision: str | None = None,
13 | task_type: str | TaskType | None = None,
14 | inference_mode: bool = False,
15 | r: int = 8,
16 | target_modules: list[str] | str | None = None,
17 | lora_alpha: int = 8,
18 | lora_dropout: float = 0,
19 | fan_in_fan_out: bool = False,
20 | bias: str = "none",
21 | modules_to_save: list[str] | None = None,
22 | init_lora_weights: bool = True,
23 | layers_to_transform: list[int] | None = None,
24 | layers_pattern: str | None = None,
25 | ) -> torch.nn.Module:
26 | """
27 | Freeze model params and make lora config. Then convert model to lora.
28 |
29 | Args:
30 | model (`PreTrainedModel`): The model to convert to Lora.
31 | r (`int`): Lora attention dimension.
32 | target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to.
33 | lora_alpha (`int`): The alpha parameter for Lora scaling.
34 | lora_dropout (`float`): The dropout probability for Lora layers.
35 | fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out).
36 |             For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.
37 | bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'
38 | modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable
39 | and saved in the final checkpoint.
40 | layers_to_transform (`Union[List[int],int]`):
41 | The layer indexes to transform, if this argument is specified, it will apply the LoRA transformations on
42 | the layer indexes that are specified in this list. If a single integer is passed, it will apply the LoRA
43 | transformations on the layer at this index.
44 | layers_pattern (`str`):
45 | The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer
46 | pattern is not in the common layers pattern.
47 | """
48 |
49 | # build lora config
50 | config = LoraConfig(
51 | peft_type=peft_type,
52 | auto_mapping=auto_mapping,
53 | base_model_name_or_path=base_model_name_or_path,
54 | revision=revision,
55 | task_type=task_type,
56 | inference_mode=inference_mode,
57 | r=r,
58 | target_modules=target_modules,
59 | lora_alpha=lora_alpha,
60 | lora_dropout=lora_dropout,
61 | fan_in_fan_out=fan_in_fan_out,
62 | bias=bias,
63 | modules_to_save=modules_to_save,
64 | init_lora_weights=init_lora_weights,
65 | layers_to_transform=layers_to_transform,
66 | layers_pattern=layers_pattern,
67 | )
68 |
69 | model = get_peft_model(model, config)
70 | model.print_trainable_parameters()
71 | return model
72 |
--------------------------------------------------------------------------------
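
A minimal sketch of wrapping a pretrained MAMMAL model with `get_lora_model`; the `target_modules` names are an assumption (T5-style attention projections) and must match the module names of the actually loaded model:

```python
from mammal.lora import get_lora_model
from mammal.model import Mammal

model = Mammal.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m")

lora_model = get_lora_model(
    model,
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=["q", "v"],  # assumed names; inspect model.named_modules() to confirm
)
# get_lora_model prints the trainable-parameter summary itself.
```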
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # configuration approach followed:
2 | # - whenever possible, prefer pyproject.toml
3 | # - for configurations insufficiently supported by pyproject.toml, use setup.cfg instead
4 | # - setup.py discouraged; minimal stub included only for compatibility with legacy tools
5 |
6 |
7 | [build-system]
8 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"]
9 | build-backend = "setuptools.build_meta"
10 |
11 | [project]
12 | name = "biomed-multi-alignment"
13 | description = "MAMMAL (Molecular Aligned Multi-Modal Architecture and Language), a flexible, multi-domain architecture with an adaptable task prompt syntax."
14 | authors = [
15 | {name="IBM Research"},
16 | {name="Moshe Raboh", email="moshiko.raboh@ibm.com"}
17 | ]
18 | version = "0.2.1"
19 | readme = "README.md"
20 | license = {file = "LICENSE.txt"}
21 | # due to how PEP 440 defines version matching, prefer [incl, excl) definitions like below:
22 | requires-python = ">=3.10, <3.13"
23 | dependencies = [
24 | "fuse-med-ml==0.4.0",
25 | "tensorflow>=2.17",
26 | "peft",
27 | "tabulate",
28 | "clearml",
29 | "hydra-core",
30 | "pytest",
31 | ]
32 |
33 | [project.optional-dependencies]
34 | examples = [
35 | "PyTDC",
36 | "anndata",
37 | "click",
38 | ]
39 |
40 | [project.urls]
41 | repository = "https://github.com/BiomedSciAI/biomed-multi-alignment"
42 |
43 | [tool.setuptools.packages]
44 | find = {}
45 |
46 | [tool.ruff]
47 | target-version = "py310"
48 | extend-include = ["*.ipynb"]
49 |
50 | # Activate all the rules that are pyupgrade-related
51 | lint.select = [
52 | # "UP", # pyupgrade
53 | "D", # pydocstyle
54 | "PT", # pytest style checking
55 | "C4", # comprehensions style checking
56 | "PD", # pandas style checking
57 | "F", # pyflakes: is-literal
58 | "W605", # pycodestyle: invalid-escape-sequence
59 | "I", # isort
60 | ]
61 | # On top of the Google convention, disable `D417`, which requires
62 | # documentation for every function parameter.
63 | lint.ignore = [
64 | "D100", # pydocstyle: Missing module docstring
65 | "D101", # pydocstyle: Missing module-level docstring
66 | "D102", # pydocstyle: Missing docstring in public module
67 | "D103", # pydocstyle: Missing class docstring
68 | "D105", # pydocstyle: Missing docstring in magic method
69 | "D107", # pydocstyle: Missing parameter descriptions in the docstring
70 | "D203", # pydocstyle: 1 blank line required before class docstring
71 | "D205", # pydocstyle: 1 blank line required between summary line and description
72 | "D212", # pydocstyle: Multi-line docstring summary should start at the first line
73 | "D401", # pydocstyle: First line should be in imperative mood
74 | "D417", # pydocstyle: Missing argument descriptions in the docstring
75 | "F841", # flake8: unused variable
76 | "PD011", # pandas do not use .values (false positives causing bugs in torch code)
77 | "PD015", # Use .merge method instead of pd.merge function. They have equivalent functionality.
78 | "PT011", #TODO remove
79 | "UP035", # TODO types. remove
80 | ]
81 | [tool.ruff.lint.per-file-ignores]
82 | "__init__.py" = ["I001"]
83 |
84 |
85 | [tool.coverage.report]
86 |
87 | exclude_lines = ["pragma: no cover", "abc.abstractmethod", "@abstract"]
88 |
89 | #[tool.coverage.run]
90 | #omit = ["gene_benchmark/tests/*"]
91 |
92 | [tool.mypy]
93 | disable_error_code = [
94 | "index",
95 | "override",
96 | "arg-type",
97 | "union-attr"
98 | ]
99 |
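
Per the metadata above, the core package installs as `pip install biomed-multi-alignment`, while `pip install 'biomed-multi-alignment[examples]'` additionally pulls in the dependencies used by the example tasks (PyTDC, anndata, click).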
--------------------------------------------------------------------------------
/mammal/lr_schedulers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim import Optimizer
3 | from torch.optim.lr_scheduler import ConstantLR, LinearLR, MultiStepLR, SequentialLR
4 | from transformers import get_inverse_sqrt_schedule
5 |
6 |
7 | def inverse_sqrt_with_warmup_lr_scheduler(
8 | optimizer: Optimizer,
9 | num_warmup_steps: int = 10000,
10 | timescale: int = 1,
11 | last_epoch: int = -1,
12 | ) -> SequentialLR:
13 | """
14 |     T5 learning rate scheduler with the original hyperparameters (which do not necessarily fit our use case)
15 |
16 | '''
17 | During pre-training, we use an “inverse square root” learning rate schedule: 1/sqrt(max(n, k))
18 | where n is the current training iteration and k is the number of warm-up steps (set to 10^4 in all of our experiments).
19 |     This sets a constant learning rate of 0.01 for the first 10^4 steps, then exponentially decays the learning rate until pre-training is over.
20 | '''
21 | """
22 | constant_sch = ConstantLR(optimizer, factor=1.0, total_iters=num_warmup_steps)
23 | inverse_sqrt_schedule = get_inverse_sqrt_schedule(
24 | optimizer=optimizer,
25 | num_warmup_steps=0,
26 | timescale=timescale,
27 | last_epoch=last_epoch,
28 | )
29 | return SequentialLR(
30 | optimizer,
31 | schedulers=[constant_sch, inverse_sqrt_schedule],
32 | milestones=[num_warmup_steps],
33 | )
34 |
35 |
36 | def cosine_annealing_with_warmup_lr_scheduler(
37 | optimizer: Optimizer,
38 | *,
39 | num_warmup_steps: int = 2000,
40 | start_factor: float = 0.3333333333333333,
41 | T_max: int = 40000,
42 | eta_min_factor: float = 0.1,
43 | ) -> SequentialLR:
44 | """
45 |     Cosine annealing with warmup, followed by a constant learning rate.
46 |     num_warmup_steps: the warmup period - the number of steps until the learning rate reaches its maximum
47 |     start_factor: the factor the learning rate starts from during the warmup period
48 |     T_max: the number of steps until the cosine schedule reaches its minimum
49 |     eta_min_factor: the minimum learning rate factor of the cosine scheduler
50 | """
51 | linear_sch = LinearLR(
52 | optimizer, start_factor=start_factor, total_iters=num_warmup_steps
53 | )
54 | assert (
55 | len(optimizer.param_groups) == 1
56 | ), f"this learning rate scheduler support single params group, got {optimizer.param_groups=}"
57 |
58 | initial_lr = [group["initial_lr"] for group in optimizer.param_groups][0]
59 | eta_min = eta_min_factor * initial_lr
60 | cosine_lr_sch = torch.optim.lr_scheduler.CosineAnnealingLR(
61 | optimizer, T_max=T_max, eta_min=eta_min
62 | )
63 | multi_step_lr_sch = MultiStepLR(
64 | optimizer=optimizer, milestones=[0], gamma=eta_min_factor
65 | )
66 | return SequentialLR(
67 | optimizer,
68 | schedulers=[linear_sch, cosine_lr_sch, multi_step_lr_sch],
69 | milestones=[num_warmup_steps, T_max],
70 | )
71 |
72 |
73 | def multistep_with_warmup_lr_scheduler(
74 | optimizer: Optimizer,
75 | *,
76 | num_warmup_steps: int = 2000,
77 | start_factor: float = 0.3333333333333333,
78 | milestones: list[int],
79 | gamma: float = 0.1,
80 | ) -> MultiStepLR:
81 | """
82 |     Multi-step LR with warmup - used in fine-tuning.
83 | """
84 | linear_sch = LinearLR(
85 | optimizer, start_factor=start_factor, total_iters=num_warmup_steps
86 | )
87 | multi_step_lr_sch = MultiStepLR(
88 | optimizer=optimizer,
89 | milestones=[m - num_warmup_steps for m in milestones],
90 | gamma=gamma,
91 | )
92 | return SequentialLR(
93 | optimizer,
94 | schedulers=[linear_sch, multi_step_lr_sch],
95 | milestones=[num_warmup_steps],
96 | )
97 |
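
A minimal usage sketch for the schedulers above; the toy model, optimizer and step counts are illustrative, not part of this file:

    import torch

    from mammal.lr_schedulers import cosine_annealing_with_warmup_lr_scheduler

    # toy model and optimizer (a single param group, as the scheduler asserts)
    model = torch.nn.Linear(16, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # warm up linearly for 2k steps, cosine-decay until step 40k,
    # then hold at eta_min_factor * the initial learning rate
    scheduler = cosine_annealing_with_warmup_lr_scheduler(
        optimizer,
        num_warmup_steps=2_000,
        T_max=40_000,
        eta_min_factor=0.1,
    )

    for _ in range(100):  # training-loop sketch: step the scheduler per step, not per epoch
        optimizer.step()
        scheduler.step()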
--------------------------------------------------------------------------------
/mammal/keys.py:
--------------------------------------------------------------------------------
1 | """
2 | Lists the keys available for / expected from each task in this generic T5 multitask implementation
3 | """
4 |
5 | # DATA
6 | SAMPLE_ID = "data.sample_id"
7 |
8 | # expected outputs for each task - data pipeline
9 | ENCODER_INPUTS_TOKENS = "data.encoder_input_token_ids" # encoder input token ids
10 | ENCODER_INPUTS_STR = "data.query.encoder_input" # the original string representation of encoder input - used for debug
11 | ENCODER_INPUTS_ATTENTION_MASK = "data.encoder_input_attention_mask" # attention mask of the tokenized encoder input (output of the tokenizer)
12 | DECODER_INPUTS_STR = "data.query.decoder_input" # the original string representation of decoder input - used for debug
13 | DECODER_INPUTS_TOKENS = "data.decoder_input_token_ids" # decoder input token ids (decoder start token followed by labels token ids)
14 | DECODER_INPUTS_ATTENTION_MASK = "data.decoder_input_attention_mask" # attention mask of the tokenized decoder input (output of the tokenizer)
15 | LABELS_TOKENS = "data.labels_token_ids" # labels token ids
16 | LABELS_ATTENTION_MASK = "data.labels_attention_mask" # attention mask of the tokenized labels (output of the tokenizer)
17 | LABELS_STR = (
18 | "data.query.labels" # the original string representation of labels - used for debug
19 | )
20 |
21 | # adding custom embeddings
22 | ENCODER_INPUT_ADD_EMBEDDINGS = "data.encoder_input_add_embeddings" # optional, can be used to add custom embeddings (in addition to token_ids)
23 |
24 |
25 | # the list of keys each task (data_module) must add to sample_dict for training in encoder only mode
26 | DATA_KEYS_ENCODER = [
27 | ENCODER_INPUTS_TOKENS,
28 | ENCODER_INPUTS_ATTENTION_MASK,
29 | ENCODER_INPUTS_STR,
30 | LABELS_TOKENS,
31 | LABELS_ATTENTION_MASK,
32 | LABELS_STR,
33 | SAMPLE_ID,
34 | ]
35 |
36 | # the list of keys each task (data_module) must add to sample_dict for training in encoder-decoder mode
37 | DATA_KEYS = DATA_KEYS_ENCODER + [DECODER_INPUTS_TOKENS, DECODER_INPUTS_ATTENTION_MASK]
38 |
39 | # MODEL
40 | # expected model outputs
41 | LOGITS = "model.out.logits"
42 | SCORES = "model.out.scores" # model output after softmax
43 | CLS_PRED = "model.out.cls_pred" # argmax() over the logits - the teacher-forcing prediction
44 | CE_LOSS = "model.out.loss" # cross-entropy loss calculated in t5 model
45 | ENCODER_LAST_HIDDEN_STATE = "model.out.encoder_last_hidden_state" # the encoder's last hidden state
46 |
47 | ########################################
48 | #### related to scalars inputs/outputs
49 | ########################################
50 |
51 |
52 | # logits of the scalars output prediction head - a single scalar is predicted per input element
53 | # active only in encoder-only mode!
54 | SCALARS_PREDICTION_HEAD_LOGITS = "model.out.scalars_prediction_logits"
55 |
56 | ENCODER_INPUTS_SCALARS = "data.encoder_input.scalars"
57 | # a float tensor with the values of scalars. A default value is used for elements that are not scalars. Its length is the number of input elements
58 | ENCODER_INPUTS_SCALARS_VALUES = ENCODER_INPUTS_SCALARS + ".values"
59 | # a boolean tensor marking which input elements are scalars (True per scalar element). Its length is the number of input elements
60 | ENCODER_INPUTS_SCALARS_VALID_MASK = ENCODER_INPUTS_SCALARS + ".valid_mask"
61 |
62 |
63 | LABELS_SCALARS = "data.labels.scalars"
64 | # a float tensor with the values of scalars. A default value is used for elements that are not scalars. Its length is the number of input elements
65 | LABELS_SCALARS_VALUES = LABELS_SCALARS + ".values"
66 | # a boolean tensor marking which label elements are scalars (True per scalar element). Its length is the number of input elements
67 | LABELS_SCALARS_VALID_MASK = LABELS_SCALARS + ".valid_mask"
68 |
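
An illustrative sketch of the values/valid-mask convention for scalar inputs documented above; the tensors are made up for the example:

    import torch

    from mammal.keys import (
        ENCODER_INPUTS_SCALARS_VALID_MASK,
        ENCODER_INPUTS_SCALARS_VALUES,
    )

    sample_dict = {}
    # five input elements; only the third carries a scalar value (0.7),
    # the others hold a default value and are masked out
    sample_dict[ENCODER_INPUTS_SCALARS_VALUES] = torch.tensor([0.0, 0.0, 0.7, 0.0, 0.0])
    sample_dict[ENCODER_INPUTS_SCALARS_VALID_MASK] = torch.tensor(
        [False, False, True, False, False]
    )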
--------------------------------------------------------------------------------
/tutorials/begginer_inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Inference using MAMMAL"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "Install `biomed-multi-alignment` package. One can also clone and install it in editable model."
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "!pip install biomed-multi-alignment"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {},
29 | "source": [
30 | "Run simplest inference script"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "import torch\n",
40 | "from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp\n",
41 | "from mammal.model import Mammal\n",
42 | "from mammal.keys import *"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "# Check if CUDA is available\n",
52 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
53 | "print(f\"Using device: {device}\")\n",
54 | "\n",
55 | "# Load Model and set it to evaluation mode\n",
56 | "model = Mammal.from_pretrained(\"ibm/biomed.omics.bl.sm.ma-ted-458m\")\n",
57 | "model.eval()\n",
58 | "model.to(device=device)\n",
59 | "\n",
60 | "\n",
61 | "# Load Tokenizer\n",
62 | "tokenizer_op = ModularTokenizerOp.from_pretrained(\"ibm/biomed.omics.bl.sm.ma-ted-458m\")\n",
63 | "\n",
64 | "# Prepare Input Prompt\n",
65 | "protein_calmodulin = \"MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK\"\n",
66 | "protein_calcineurin = \"MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ\"\n",
67 | "\n",
68 | "# Create and load sample\n",
69 | "sample_dict = dict()\n",
70 | "# Formatting prompt to match pre-training syntax\n",
71 | "sample_dict[ENCODER_INPUTS_STR] = f\"<@TOKENIZER-TYPE=AA>{protein_calmodulin}{protein_calcineurin}\"\n",
72 | "\n",
73 | "# Tokenize\n",
74 | "tokenizer_op(\n",
75 | " sample_dict=sample_dict,\n",
76 | " key_in=ENCODER_INPUTS_STR,\n",
77 | " key_out_tokens_ids=ENCODER_INPUTS_TOKENS,\n",
78 | " key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,\n",
79 | ")\n",
80 | "sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(sample_dict[ENCODER_INPUTS_TOKENS]).to(device=device)\n",
81 | "sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(sample_dict[ENCODER_INPUTS_ATTENTION_MASK]).to(device=device)\n",
82 | "\n",
83 | "# Generate Prediction\n",
84 | "batch_dict = model.generate(\n",
85 | " [sample_dict],\n",
86 | " output_scores=True,\n",
87 | " return_dict_in_generate=True,\n",
88 | " max_new_tokens=5,\n",
89 | ")\n",
90 | "\n",
91 | "# Get output\n",
92 | "generated_output = tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])\n",
93 | "print(f\"{generated_output=}\")"
94 | ]
95 | }
96 | ],
97 | "metadata": {
98 | "language_info": {
99 | "name": "python"
100 | }
101 | },
102 | "nbformat": 4,
103 | "nbformat_minor": 2
104 | }
105 |
--------------------------------------------------------------------------------
/mammal/examples/tests/test_dti_bindingdb_kd.py:
--------------------------------------------------------------------------------
1 | import socket
2 | from pathlib import Path
3 |
4 | import hydra
5 | import pytest
6 | from hydra.core.global_hydra import GlobalHydra
7 |
8 | from mammal.examples.dti_bindingdb_kd.main_infer import dti_bindingdb_kd_infer
9 | from mammal.main_finetune import main as main_finetune
10 |
11 | TEST_CONFIG_DIRPATH = str(Path(__file__).parents[0] / "../dti_bindingdb_kd")
12 | TEST_CONFIG_FILENAME = "config.yaml"
13 |
14 |
15 | @pytest.fixture(autouse=True, scope="session")
16 | def _clean_hydra() -> None:
17 | GlobalHydra.instance().clear()
18 |
19 |
20 | @pytest.fixture(scope="session")
21 | def model_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
22 | if "ccc" not in socket.gethostname():
23 |         pytest.skip("Full tests require resources")
24 |
25 | model_dir_path = tmp_path_factory.mktemp("test_dti_bindingdb_kd") / "test"
26 | return model_dir_path
27 |
28 |
29 | @pytest.mark.xfail(reason="tokenizer only available on the CCC for now")
30 | def test_finetune(model_dir: str) -> None:
31 | print(model_dir)
32 | OVERRIDES = [
33 | "track_clearml=null", # Travis cannot connect to ClearML at the moment. We might be able to fix it with a dedicated user + config credentials.
34 | "trainer.accelerator=auto",
35 | "trainer.max_epochs=2", # Small number for a faster run.
36 | "+trainer.limit_train_batches=3", # Small number for a faster run. This is not applicable for validation mb per epoch! (WHY?)
37 | "+trainer.limit_val_batches=2", # Small number for a faster run. This is not applicable for validation mb per epoch! (WHY?)
38 | "task.data_module_kwargs.batch_size=1",
39 | f"model_dir={model_dir}",
40 | "task.data_module_kwargs.train_dl_kwargs.num_workers=0", # Using parallelization cause co
41 | "task.data_module_kwargs.valid_dl_kwargs.num_workers=0", # Using parallelization cause co
42 | "root=.",
43 | "name=dti_bindingdb_kd_test",
44 | ]
45 | with hydra.initialize_config_dir(TEST_CONFIG_DIRPATH, version_base="1.1"):
46 | _cfg = hydra.compose(TEST_CONFIG_FILENAME, overrides=OVERRIDES)
47 | cfg = hydra.utils.instantiate(_cfg)
48 | main_finetune(cfg)
49 |
50 |
51 | @pytest.mark.xfail(reason="tokenizer only available on the CCC for now")
52 | def test_evaluate(model_dir: str):
53 | OVERRIDES = [
54 | "track_clearml=null", # Travis cannot connect to ClearML at the moment. We might be able to fix it with a dedicated user + config credentials.
55 | # "train.trainer.accelerator=cpu", # Travis doesn't have GPU.
56 | "trainer.accelerator=auto",
57 | "trainer.max_epochs=1",
58 | "task.data_module_kwargs.batch_size=1",
59 | "+trainer.limit_test_batches=10", # Small number for a faster run. This is not applicable for validation mb per epoch! (WHY?)
60 | f"model_dir={model_dir}",
61 | "task.data_module_kwargs.train_dl_kwargs.num_workers=0", # Using parallelization cause co
62 | "task.data_module_kwargs.valid_dl_kwargs.num_workers=0", # Using parallelization cause co
63 | "root=.",
64 | "name=dti_bindingdb_kd_test",
65 | "evaluate=True",
66 | f"model.pretrained_kwargs.pretrained_model_name_or_path={model_dir}/best_epoch.ckpt",
67 | ]
68 | with hydra.initialize_config_dir(TEST_CONFIG_DIRPATH, version_base="1.1"):
69 | _cfg = hydra.compose(TEST_CONFIG_FILENAME, overrides=OVERRIDES)
70 | cfg = hydra.utils.instantiate(_cfg)
71 | main_finetune(cfg)
72 |
73 |
74 | @pytest.mark.xfail(reason="tokenizer only available on the CCC for now")
75 | def test_infer(model_dir: str):
76 | dti_bindingdb_kd_infer(
77 | finetune_output_dir=model_dir,
78 | target_seq="NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC",
79 | drug_seq="CC(=O)NCCC1=CNc2c1cc(OC)cc2",
80 | norm_y_mean=0.0,
81 | norm_y_std=1.0,
82 | device="cpu",
83 | )
84 |
--------------------------------------------------------------------------------
/.github/workflows/clone.yml:
--------------------------------------------------------------------------------
1 | name: GitHub Clone Count Update Every Day
2 |
3 | on:
4 | schedule:
5 | - cron: "0 */24 * * *"
6 | workflow_dispatch:
7 |
8 | jobs:
9 | build:
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - uses: actions/checkout@v2
14 |
15 | - name: gh login
16 | run: echo "${{ secrets.SECRET_TOKEN }}" | gh auth login --with-token
17 |
18 | - name: parse latest clone count
19 | run: |
20 | curl --user "${{ github.actor }}:${{ secrets.SECRET_TOKEN }}" \
21 | -H "Accept: application/vnd.github.v3+json" \
22 | https://api.github.com/repos/${{ github.repository }}/traffic/clones \
23 | > clone.json
24 |
25 | - name: create gist and download previous count
26 | id: set_id
27 | run: |
28 | if gh secret list | grep -q "GIST_ID"
29 | then
30 | echo "GIST_ID found"
31 | echo "GIST=${{ secrets.GIST_ID }}" >> $GITHUB_OUTPUT
32 | curl https://gist.githubusercontent.com/${{ github.actor }}/${{ secrets.GIST_ID }}/raw/clone.json > clone_before.json
33 | if cat clone_before.json | grep '404: Not Found'; then
34 | echo "GIST_ID not valid anymore. Creating another gist..."
35 | gist_id=$(gh gist create clone.json | awk -F / '{print $NF}')
36 | echo $gist_id | gh secret set GIST_ID
37 | echo "GIST=$gist_id" >> $GITHUB_OUTPUT
38 | cp clone.json clone_before.json
39 | git rm --ignore-unmatch CLONE.md
40 | fi
41 | else
42 | echo "GIST_ID not found. Creating a gist..."
43 | gist_id=$(gh gist create clone.json | awk -F / '{print $NF}')
44 | echo $gist_id | gh secret set GIST_ID
45 | echo "GIST=$gist_id" >> $GITHUB_OUTPUT
46 | cp clone.json clone_before.json
47 | fi
48 |
49 | - name: update clone.json
50 | run: |
51 | curl https://raw.githubusercontent.com/MShawon/github-clone-count-badge/master/main.py > main.py
52 | python3 main.py
53 |
54 | - name: Update gist with latest count
55 | run: |
56 | content=$(sed -e 's/\\/\\\\/g' -e 's/\t/\\t/g' -e 's/\"/\\"/g' -e 's/\r//g' "clone.json" | sed -E ':a;N;$!ba;s/\r{0,1}\n/\\n/g')
57 | echo '{"description": "${{ github.repository }} clone statistics", "files": {"clone.json": {"content": "'"$content"'"}}}' > post_clone.json
58 | curl -s -X PATCH \
59 | --user "${{ github.actor }}:${{ secrets.SECRET_TOKEN }}" \
60 | -H "Content-Type: application/json" \
61 | -d @post_clone.json https://api.github.com/gists/${{ steps.set_id.outputs.GIST }} > /dev/null 2>&1
62 |
63 | if [ ! -f CLONE.md ]; then
64 | shields="https://img.shields.io/badge/dynamic/json?color=success&label=Clone&query=count&url="
65 | url="https://gist.githubusercontent.com/${{ github.actor }}/${{ steps.set_id.outputs.GIST }}/raw/clone.json"
66 | repo="https://github.com/MShawon/github-clone-count-badge"
67 | echo ''> CLONE.md
68 | echo '
69 | **Markdown**
70 |
71 | ```markdown' >> CLONE.md
72 | echo "[]($repo)" >> CLONE.md
73 | echo '
74 | ```
75 |
76 | **HTML**
77 | ```html' >> CLONE.md
78 | echo "
" >> CLONE.md
79 | echo '```' >> CLONE.md
80 |
81 | git add CLONE.md
82 | git config --global user.name "GitHub Action"
83 | git config --global user.email "action@github.com"
84 | git commit -m "create clone count badge"
85 | fi
86 |
87 | - name: Push
88 | uses: ad-m/github-push-action@master
89 | with:
90 | github_token: ${{ secrets.GITHUB_TOKEN }}
91 |
--------------------------------------------------------------------------------
/mammal_mcp/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | .env
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # pdm
107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | # in version control.
111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
112 | .pdm.toml
113 | .pdm-python
114 | .pdm-build/
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # pytype static type analyzer
154 | .pytype/
155 |
156 | # Cython debug symbols
157 | cython_debug/
158 |
159 | # PyCharm
160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162 | # and can be added to the global gitignore or merged into this file. For a more nuclear
163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164 | #.idea/
165 |
--------------------------------------------------------------------------------
/mammal/examples/dti_bindingdb_kd/main_infer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import click
4 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
5 |
6 | from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask
7 | from mammal.model import Mammal
8 |
9 |
10 | @click.command()
11 | @click.option(
12 | "--finetune_output_dir",
13 | default=None,
14 | help="Specify the model dir (default: None to download from huggingface).",
15 | )
16 | @click.argument(
17 | "target_seq",
18 | default="NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC",
19 | )
20 | @click.argument(
21 | "drug_seq",
22 | default="CC(=O)NCCC1=CNc2c1cc(OC)cc2",
23 | )
24 | @click.argument("norm_y_mean", default=5.79384684128215, type=float)
25 | @click.argument("norm_y_std", default=1.33808027428196, type=float)
26 | @click.option(
27 | "--device", default="cpu", help="Specify the device to use (default: 'cpu')."
28 | )
29 | def main(
30 | finetune_output_dir: str | None,
31 | target_seq: str,
32 | drug_seq: str,
33 | norm_y_mean: float,
34 | norm_y_std: float,
35 | device: str,
36 | ):
37 | dti_bindingdb_kd_infer(
38 | finetune_output_dir=finetune_output_dir,
39 | target_seq=target_seq,
40 | drug_seq=drug_seq,
41 | norm_y_mean=norm_y_mean,
42 | norm_y_std=norm_y_std,
43 | device=device,
44 | )
45 |
46 |
47 | def dti_bindingdb_kd_infer(
48 |     finetune_output_dir: str | None,
49 | target_seq: str,
50 | drug_seq: str,
51 | norm_y_mean: float,
52 | norm_y_std: float,
53 | device: str,
54 | ):
55 | """
56 |     :param finetune_output_dir: the model_dir argument used in fine-tuning, or None to download the model from huggingface
57 |     :param target_seq: amino acid sequence of the target protein
58 |     :param drug_seq: SMILES representation of the drug
59 |     :param norm_y_mean: the label mean used for normalization in fine-tuning
60 |     :param norm_y_std: the label std used for normalization in fine-tuning
61 | """
62 | if finetune_output_dir is not None:
63 | # load tokenizer and model from finetune_output_dir
64 | tokenizer_op = ModularTokenizerOp.from_pretrained(
65 | os.path.join(finetune_output_dir, "tokenizer")
66 | )
67 | nn_model = Mammal.from_pretrained(
68 | pretrained_model_name_or_path=os.path.join(
69 | finetune_output_dir, "best_epoch.ckpt"
70 | )
71 | )
72 | else:
73 | # load tokenizer and model from huggingface
74 | tokenizer_op = ModularTokenizerOp.from_pretrained(
75 | "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd"
76 | )
77 | nn_model = Mammal.from_pretrained(
78 | "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd"
79 | )
80 | nn_model.eval()
81 | nn_model.to(device=device)
82 |
83 | # convert to MAMMAL style
84 | sample_dict = {"target_seq": target_seq, "drug_seq": drug_seq}
85 | sample_dict = DtiBindingdbKdTask.data_preprocessing(
86 | sample_dict=sample_dict,
87 | tokenizer_op=tokenizer_op,
88 | target_sequence_key="target_seq",
89 | drug_sequence_key="drug_seq",
90 | norm_y_mean=None,
91 | norm_y_std=None,
92 | device=nn_model.device,
93 | )
94 |
95 | # forward pass - encoder_only mode which supports scalars predictions
96 | batch_dict = nn_model.forward_encoder_only([sample_dict])
97 |
98 | # Post-process the model's output
99 | batch_dict = DtiBindingdbKdTask.process_model_output(
100 | batch_dict,
101 | scalars_preds_processed_key="model.out.dti_bindingdb_kd",
102 | norm_y_mean=norm_y_mean,
103 | norm_y_std=norm_y_std,
104 | )
105 | ans = {
106 | "model.out.dti_bindingdb_kd": float(batch_dict["model.out.dti_bindingdb_kd"][0])
107 | }
108 |
109 | # Print prediction
110 | print(ans)
111 | return ans
112 |
113 |
114 | if __name__ == "__main__":
115 | main()
116 |
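
A programmatic usage sketch mirroring the CLI defaults above (passing None downloads the fine-tuned model from Hugging Face; the normalization values are this file's defaults):

    from mammal.examples.dti_bindingdb_kd.main_infer import dti_bindingdb_kd_infer

    ans = dti_bindingdb_kd_infer(
        finetune_output_dir=None,  # None -> download from huggingface
        target_seq="NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC",
        drug_seq="CC(=O)NCCC1=CNc2c1cc(OC)cc2",
        norm_y_mean=5.79384684128215,
        norm_y_std=1.33808027428196,
        device="cpu",
    )
    print(ans["model.out.dti_bindingdb_kd"])  # the predicted (denormalized) pKd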
--------------------------------------------------------------------------------
/mammal/examples/carcinogenicity/pl_data_module.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 |
3 | import pytorch_lightning as pl
4 | from fuse.data.datasets.dataset_default import DatasetDefault
5 | from fuse.data.ops.ops_read import OpReadDataframe
6 | from fuse.data.pipelines.pipeline_default import PipelineDefault
7 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
8 | from fuse.data.utils.collates import CollateDefault
9 | from tdc.single_pred.tox import Tox
10 | from torch.utils.data.dataloader import DataLoader
11 |
12 | from mammal.keys import * # noqa
13 |
14 |
15 | class CarcinogenicityDataModule(pl.LightningDataModule):
16 | def __init__(
17 | self,
18 | *,
19 | batch_size: int,
20 | tokenizer_op: ModularTokenizerOp,
21 | drug_max_seq_length: int,
22 | encoder_input_max_seq_len: int,
23 | data_preprocessing: Callable,
24 | labels_max_seq_len: int,
25 | ) -> None:
26 | super().__init__()
27 | self.tokenizer_op = tokenizer_op
28 | self.drug_max_seq_length = drug_max_seq_length
29 | self.encoder_input_max_seq_len = encoder_input_max_seq_len
30 | self.labels_max_seq_len = labels_max_seq_len
31 | self.batch_size = batch_size
32 | self.data_preprocessing = data_preprocessing
33 |         self.pad_token_id = self.tokenizer_op.get_token_id("<PAD>")
34 |
35 | def setup(self, stage: str) -> None:
36 | self.ds_dict = load_datasets()
37 |
38 | task_pipeline = [
39 | (
40 | self.data_preprocessing,
41 | dict(
42 | sequence_key="data.drug",
43 | label_key="data.label",
44 | tokenizer_op=self.tokenizer_op,
45 | encoder_input_max_seq_len=self.encoder_input_max_seq_len,
46 | labels_max_seq_len=self.labels_max_seq_len,
47 | ),
48 | ),
49 | ]
50 |
51 | for ds in self.ds_dict.values():
52 | ds.dynamic_pipeline.extend(task_pipeline)
53 |
54 | def train_dataloader(self) -> DataLoader:
55 | train_loader = DataLoader(
56 | dataset=self.ds_dict["train"],
57 | batch_size=self.batch_size,
58 | collate_fn=CollateDefault(),
59 | shuffle=True,
60 | )
61 | return train_loader
62 |
63 | def val_dataloader(self) -> DataLoader:
64 | val_loader = DataLoader(
65 | self.ds_dict["valid"],
66 | batch_size=self.batch_size,
67 | collate_fn=CollateDefault(),
68 | )
69 |
70 | return val_loader
71 |
72 | def test_dataloader(self) -> DataLoader:
73 | test_loader = DataLoader(
74 | self.ds_dict["test"],
75 | batch_size=self.batch_size,
76 | collate_fn=CollateDefault(),
77 | )
78 |
79 | return test_loader
80 |
81 | def predict_dataloader(self) -> DataLoader:
82 | return self.test_dataloader()
83 |
84 |
85 | def load_datasets(split_method: str = "random") -> dict[str, DatasetDefault]:
86 | data = Tox(name="Carcinogens_Lagunin")
87 | split = data.get_split(method=split_method)
88 |
89 | ds_dict = {}
90 | for set_name in ["train", "valid", "test"]:
91 | data_df = split[set_name]
92 | print(f"{set_name} set size is {len(data_df)}")
93 | size = len(data_df)
94 |
95 | dynamic_pipeline = PipelineDefault(
96 | "carcinogenicity",
97 | [
98 | (
99 | OpReadDataframe(
100 | data_df,
101 | key_column=None,
102 | rename_columns={"Drug": "data.drug", "Y": "data.label"},
103 | ),
104 | dict(),
105 | ),
106 | ],
107 | )
108 |
109 | ds = DatasetDefault(sample_ids=size, dynamic_pipeline=dynamic_pipeline)
110 | ds.create()
111 | ds_dict[set_name] = ds
112 |
113 | return ds_dict
114 |
115 |
116 | if __name__ == "__main__":
117 | ds = load_datasets()
118 | print(ds["train"][0])
119 | print(ds["test"][0])
120 |
--------------------------------------------------------------------------------
/mammal_mcp/dependencies.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from contextlib import asynccontextmanager
4 |
5 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
6 |
7 | # bmfm imports
8 | from mammal.model import Mammal
9 |
10 | logging.basicConfig(
11 | level=logging.WARNING,
12 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
13 | handlers=[logging.StreamHandler()],
14 | )
15 |
16 | # Create a logger
17 | logger = logging.getLogger("mammal")
18 | logger.setLevel(logging.DEBUG)
19 |
20 | assets = {}
21 |
22 |
23 | @asynccontextmanager
24 | async def lifespan():
25 | if os.getenv("PROTEIN_PROTEIN_INTERACTION") == "true":
26 | # Load Model
27 | logger.info("downloading: ibm/biomed.omics.bl.sm.ma-ted-458m")
28 | model = Mammal.from_pretrained(
29 | "ibm/biomed.omics.bl.sm.ma-ted-458m", cache_dir="model_cache"
30 | )
31 | assets["model"] = model.eval()
32 | logger.info("completed the download: ibm/biomed.omics.bl.sm.ma-ted-458m")
33 |
34 | logger.info("downloading for tokenizer: ibm/biomed.omics.bl.sm.ma-ted-458m")
35 | assets["tokenizer_op"] = ModularTokenizerOp.from_pretrained(
36 | "ibm/biomed.omics.bl.sm.ma-ted-458m", cache_dir="model_cache"
37 | )
38 |
39 | if os.getenv("PROTEIN_SOLUBILITY") == "true":
40 | logger.info(
41 | "downloading: ibm-research/biomed.omics.bl.sm.ma-ted-458m.protein_solubility"
42 | )
43 | protein_solubility_model = Mammal.from_pretrained(
44 | "ibm-research/biomed.omics.bl.sm.ma-ted-458m.protein_solubility",
45 | cache_dir="model_cache",
46 | )
47 | assets["protein_solubility_model"] = protein_solubility_model.eval()
48 | assets["protein_solubility_tokenizer_op"] = ModularTokenizerOp.from_pretrained(
49 | "ibm-research/biomed.omics.bl.sm.ma-ted-458m.protein_solubility",
50 | cache_dir="model_cache",
51 | )
52 |
53 | if os.getenv("PROTEIN_DRUG_INTERACTION_MODEL") == "true":
54 | logger.info("downloading: ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd")
55 | protein_drug_interaction_model = Mammal.from_pretrained(
56 | "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd",
57 | cache_dir="model_cache",
58 | )
59 | assets["protein_drug_interaction_model"] = protein_drug_interaction_model.eval()
60 | assets["protein_drug_interaction_tokenizer_op"] = (
61 | ModularTokenizerOp.from_pretrained(
62 | "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd",
63 | cache_dir="model_cache",
64 | )
65 | )
66 |
67 | if os.getenv("TCR_EPITOPE_BINDING") == "true":
68 | logger.info(
69 | "downloading: ibm-research/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind"
70 | )
71 |
72 | tcr_epitope_model = Mammal.from_pretrained(
73 | "ibm-research/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind",
74 | cache_dir="model_cache",
75 | )
76 | # set to eval/inference mode
77 | tcr_epitope_model.eval()
79 |         # move the model to the target device (use "cuda" if a GPU is available)
79 | tcr_epitope_model.to(device="cpu")
80 | assets["tcr_epitope_model"] = tcr_epitope_model
81 |
82 | assets["tcr_epitope_model_tokenizer_op"] = ModularTokenizerOp.from_pretrained(
83 | "ibm-research/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind",
84 | cache_dir="model_cache",
85 | )
86 |
87 | if (
88 | os.getenv("DRUG_TARGET_BINDING") == "true"
89 | or os.getenv("DRUG_TARGET_BINDING_FASTA") == "true"
90 | ):
91 | logger.info("downloading: ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd")
92 |
93 | drug_target_model = Mammal.from_pretrained(
94 | "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd",
95 | cache_dir="model_cache",
96 | )
97 | # set to eval/inference mode
98 | drug_target_model.eval()
99 | assets["drug_target_model"] = drug_target_model
100 |
101 |         # download the tokenizer
102 | assets["drug_target_model_tokeniser_op"] = ModularTokenizerOp.from_pretrained(
103 | "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd",
104 | cache_dir="model_cache",
105 | )
106 |
107 | logger.info("Assets loaded")
108 |
109 | yield
110 | # Clean up the assets
111 | assets.clear()
112 |
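
A sketch of driving the lifespan manager above from a plain script; enabling a capability through its environment variable (as in `.env.example`) triggers the corresponding download:

    import asyncio
    import os

    from mammal_mcp.dependencies import assets, lifespan

    os.environ["PROTEIN_PROTEIN_INTERACTION"] = "true"

    async def demo() -> None:
        async with lifespan():
            # the base model and tokenizer are now in the shared assets dict
            model = assets["model"]
            tokenizer_op = assets["tokenizer_op"]
            print(type(model).__name__, type(tokenizer_op).__name__)

    asyncio.run(demo())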
--------------------------------------------------------------------------------
/mammal/examples/molnet/molnet_infer.py:
--------------------------------------------------------------------------------
1 | import click
2 | import numpy as np
3 | import torch
4 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
5 |
6 | from mammal.keys import (
7 | CLS_PRED,
8 | ENCODER_INPUTS_ATTENTION_MASK,
9 | ENCODER_INPUTS_STR,
10 | ENCODER_INPUTS_TOKENS,
11 | SCORES,
12 | )
13 | from mammal.model import Mammal
14 |
15 | TASK_NAMES = ["BBBP", "TOXICITY", "FDA_APPR"]
16 |
17 |
18 | @click.command()
19 | @click.argument("task_name", default="BBBP")
20 | @click.argument(
21 | "smiles_seq",
22 | default="C(Cl)Cl",
23 | )
24 | @click.option(
25 | "--device", default="cpu", help="Specify the device to use (default: 'cpu')."
26 | )
27 | def main(task_name: str, smiles_seq: str, device: str):
28 | task_dict = load_model(task_name=task_name, device=device)
29 | result = task_infer(task_dict=task_dict, smiles_seq=smiles_seq)
30 | print(f"The prediction for {smiles_seq=} is {result}")
31 |
32 |
33 | def load_model(task_name: str, device: str) -> dict:
34 | match task_name:
35 | case "BBBP":
36 | path = "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_bbbp"
37 | case "TOXICITY":
38 | path = "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_tox"
39 | case "FDA_APPR":
40 | path = "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_fda"
41 |         case _:
42 |             raise ValueError(f"The {task_name=} is incorrect. Valid names are {TASK_NAMES}")
43 |
44 | # Load Model and set to evaluation mode
45 | model = Mammal.from_pretrained(path)
46 | model.eval()
47 | model.to(device=device)
48 |
49 | # Load Tokenizer
50 | tokenizer_op = ModularTokenizerOp.from_pretrained(path)
51 |
52 | task_dict = dict(
53 | task_name=task_name,
54 | model=model,
55 | tokenizer_op=tokenizer_op,
56 | )
57 | return task_dict
58 |
59 |
60 | def process_model_output(
61 | tokenizer_op: ModularTokenizerOp,
62 | decoder_output: np.ndarray,
63 | decoder_output_scores: np.ndarray,
64 | ) -> dict:
65 | """
66 | Extract predicted class and scores
67 | """
68 | negative_token_id = tokenizer_op.get_token_id("<0>")
69 | positive_token_id = tokenizer_op.get_token_id("<1>")
70 | label_id_to_int = {
71 | negative_token_id: 0,
72 | positive_token_id: 1,
73 | }
74 | classification_position = 1
75 |
76 |     scores = None
77 |     if decoder_output_scores is not None:
78 |         scores = decoder_output_scores[classification_position, positive_token_id]
79 |     ans = dict(
80 |         pred=label_id_to_int.get(int(decoder_output[classification_position]), -1),
81 |         score=scores.item() if scores is not None else None,
82 |     )
83 | return ans
84 |
85 |
86 | def task_infer(task_dict: dict, smiles_seq: str) -> dict:
87 | task_name = task_dict["task_name"]
88 | model = task_dict["model"]
89 | tokenizer_op = task_dict["tokenizer_op"]
90 |
91 |     if task_name not in TASK_NAMES:
92 |         raise ValueError(f"The {task_name=} is incorrect. Valid names are {TASK_NAMES}")
93 |
94 | sample_dict = create_sample_dict(task_name, smiles_seq, tokenizer_op, model)
95 | # Generate Prediction
96 | batch_dict = get_predictions(model, sample_dict)
97 |
98 | # Post-process the model's output
99 | result = process_model_output(
100 | tokenizer_op=tokenizer_op,
101 | decoder_output=batch_dict[CLS_PRED][0],
102 | decoder_output_scores=batch_dict[SCORES][0],
103 | )
104 | return result
105 |
106 |
107 | def create_sample_dict(task_name, smiles_seq, tokenizer_op, model):
108 |
109 | # Create and load sample
110 | sample_dict = dict()
111 | # Formatting prompt to match pre-training syntax
112 | sample_dict[ENCODER_INPUTS_STR] = (
113 | f"<@TOKENIZER-TYPE=SMILES><{task_name}><@TOKENIZER-TYPE=SMILES@MAX-LEN=2100>{smiles_seq}"
114 | )
115 |
116 | # Tokenize
117 | tokenizer_op(
118 | sample_dict=sample_dict,
119 | key_in=ENCODER_INPUTS_STR,
120 | key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
121 | key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
122 | )
123 | sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
124 | sample_dict[ENCODER_INPUTS_TOKENS], device=model.device
125 | )
126 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
127 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK], device=model.device
128 | )
129 | return sample_dict
130 |
131 |
132 | def get_predictions(model, sample_dict):
133 | return model.generate(
134 | [sample_dict],
135 | output_scores=True,
136 | return_dict_in_generate=True,
137 | max_new_tokens=5,
138 | )
139 |
140 |
141 | if __name__ == "__main__":
142 | main()
143 |
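
The same flow as the CLI above, called from Python; the task name and SMILES string are this file's defaults:

    from mammal.examples.molnet.molnet_infer import load_model, task_infer

    task_dict = load_model(task_name="BBBP", device="cpu")
    result = task_infer(task_dict=task_dict, smiles_seq="C(Cl)Cl")
    print(result)  # dict with "pred" (0/1) and "score" (score of the positive token)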
--------------------------------------------------------------------------------
/mammal/examples/tcr_epitope_binding/main_infer.py:
--------------------------------------------------------------------------------
1 | import click
2 | import numpy as np
3 | import torch
4 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
5 |
6 | from mammal.keys import (
7 | CLS_PRED,
8 | ENCODER_INPUTS_ATTENTION_MASK,
9 | ENCODER_INPUTS_STR,
10 | ENCODER_INPUTS_TOKENS,
11 | SCORES,
12 | )
13 | from mammal.model import Mammal
14 |
15 |
16 | @click.command()
17 | @click.argument(
18 | "tcr_beta_seq",
19 | default="GAVVSQHPSWVICKSGTSVKIECRSLDFQATTMFWYRQFPKQSLMLMATSNEGSKATYEQGVEKDKFLINHASLTLSTLTVTSAHPEDSSFYICSASEGTSSYEQYFGPGTRLTVT", # alternative binder 1
20 | # Alternative binder 2: NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT
21 | )
22 | @click.argument(
23 | "epitope_seq",
24 | default="FLKEKGGL", # alternative binder 1
25 | # Alternative binder 2: LLQTGIHVRVSQPSL
26 | )
27 | @click.option(
28 | "--device", default="cpu", help="Specify the device to use (default: 'cpu')."
29 | )
30 | def main(tcr_beta_seq: str, epitope_seq: str, device: str):
31 | model, tokenizer_op = load_model(device=device)
32 | result = task_infer(
33 | model=model,
34 | tokenizer_op=tokenizer_op,
35 | tcr_beta_seq=tcr_beta_seq,
36 | epitope_seq=epitope_seq,
37 | )
38 | print(f"The prediction for {epitope_seq} and {tcr_beta_seq} is {result}")
39 |
40 |
41 | def load_model(
42 | device: str,
43 | model_path: str = "ibm-research/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind", # change to "ibm/biomed.omics.bl.sm.ma-ted-458m" to try on the base model
44 | tokenizer_path: str = "ibm-research/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind",
45 | ) -> tuple["Mammal", "ModularTokenizerOp"]:
46 |
47 | # Load Model and set to evaluation mode
48 | model = Mammal.from_pretrained(
49 | pretrained_model_name_or_path=model_path, allow_config_mismatch=True
50 | )
51 | model.eval()
52 | model.to(device=device)
53 |
54 | # Load Tokenizer
55 | tokenizer_op = ModularTokenizerOp.from_pretrained(tokenizer_path)
56 |
57 | return model, tokenizer_op
58 |
59 |
60 | def process_model_output(
61 | tokenizer_op: ModularTokenizerOp,
62 | decoder_output: np.ndarray,
63 | decoder_output_scores: np.ndarray,
64 | ) -> dict:
65 | """
66 | Extract predicted class and scores
67 | """
68 | negative_token_id = tokenizer_op.get_token_id("<0>")
69 | positive_token_id = tokenizer_op.get_token_id("<1>")
70 | label_id_to_int = {
71 | negative_token_id: 0,
72 | positive_token_id: 1,
73 | }
74 | classification_position = 1
75 |
76 |     scores = None
77 |     if decoder_output_scores is not None:
78 |         scores = decoder_output_scores[classification_position, positive_token_id]
79 |     ans = dict(
80 |         pred=label_id_to_int.get(int(decoder_output[classification_position]), -1),
81 |         score=scores.item() if scores is not None else None,
82 |     )
83 | return ans
84 |
85 |
86 | def task_infer(
87 | model: "Mammal",
88 | tokenizer_op: ModularTokenizerOp,
89 | tcr_beta_seq: str,
90 | epitope_seq: str,
91 | ) -> dict:
92 | treat_inputs_as_general_proteins = False
93 |
94 | # Create and load sample
95 | sample_dict = dict()
96 | # Formatting prompt to match pre-training syntax
97 |
98 | if treat_inputs_as_general_proteins:
99 | # Treat inputs as general proteins:
100 | sample_dict[ENCODER_INPUTS_STR] = (
101 | f"<@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=AA>{tcr_beta_seq}<@TOKENIZER-TYPE=AA>{epitope_seq}"
102 | )
103 | else:
104 | # Treat inputs as TCR beta chain and epitope
105 | sample_dict[ENCODER_INPUTS_STR] = (
106 | f"<@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=AA>{tcr_beta_seq}<@TOKENIZER-TYPE=AA>{epitope_seq}"
107 | )
108 |
109 | # Tokenize
110 | tokenizer_op(
111 | sample_dict=sample_dict,
112 | key_in=ENCODER_INPUTS_STR,
113 | key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
114 | key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
115 | )
116 | sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
117 | sample_dict[ENCODER_INPUTS_TOKENS], device=model.device
118 | )
119 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
120 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK], device=model.device
121 | )
122 |
123 | # Generate Prediction
124 | batch_dict = model.generate(
125 | [sample_dict],
126 | output_scores=True,
127 | return_dict_in_generate=True,
128 | max_new_tokens=5,
129 | )
130 |
131 | # Post-process the model's output
132 | result = process_model_output(
133 | tokenizer_op=tokenizer_op,
134 | decoder_output=batch_dict[CLS_PRED][0],
135 | decoder_output_scores=batch_dict[SCORES][0],
136 | )
137 | return result
138 |
139 |
140 | if __name__ == "__main__":
141 | main()
142 |
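
The same flow as the CLI above, called from Python; the sequences are this file's default binder pair:

    from mammal.examples.tcr_epitope_binding.main_infer import load_model, task_infer

    model, tokenizer_op = load_model(device="cpu")
    result = task_infer(
        model=model,
        tokenizer_op=tokenizer_op,
        tcr_beta_seq="GAVVSQHPSWVICKSGTSVKIECRSLDFQATTMFWYRQFPKQSLMLMATSNEGSKATYEQGVEKDKFLINHASLTLSTLTVTSAHPEDSSFYICSASEGTSSYEQYFGPGTRLTVT",
        epitope_seq="FLKEKGGL",
    )
    print(result)  # dict with "pred" (0 = no binding, 1 = binding) and "score"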
--------------------------------------------------------------------------------
/mammal/examples/dti_bindingdb_kd/pl_data_module.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Callable
2 |
3 | import pytorch_lightning as pl
4 | from fuse.data.datasets.dataset_default import DatasetDefault
5 | from fuse.data.ops.ops_read import OpReadDataframe
6 | from fuse.data.pipelines.pipeline_default import PipelineDefault
7 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
8 | from fuse.data.tokenizers.modular_tokenizer.special_tokens import special_wrap_input
9 | from fuse.data.utils.collates import CollateDefault
10 | from tdc.multi_pred.dti import DTI
11 | from torch.utils.data.dataloader import DataLoader
12 |
13 | from mammal.keys import * # noqa
14 |
15 |
16 | class DtiBindingdbKdDataModule(pl.LightningDataModule):
17 | def __init__(
18 | self,
19 | *,
20 | batch_size: int,
21 | tokenizer_op: ModularTokenizerOp,
22 | train_dl_kwargs: dict,
23 | valid_dl_kwargs: dict,
24 | seed: int,
25 | data_preprocessing: Callable,
26 | target_max_seq_length: int,
27 | drug_max_seq_length: int,
28 | encoder_input_max_seq_len: int,
29 | load_datasets_kwargs: dict,
30 | ) -> None:
31 | super().__init__()
32 | self.tokenizer_op = tokenizer_op
33 | self.target_max_seq_length = target_max_seq_length
34 | self.drug_max_seq_length = drug_max_seq_length
35 | self.encoder_input_max_seq_len = encoder_input_max_seq_len
36 | self.batch_size = batch_size
37 | self.train_dl_kwargs = train_dl_kwargs
38 | self.valid_dl_kwargs = valid_dl_kwargs
39 | self.seed = seed
40 | self.data_preprocessing = data_preprocessing
41 | self.load_datasets_kwargs = load_datasets_kwargs
42 |
43 | self.pad_token_id = self.tokenizer_op.get_token_id(special_wrap_input("PAD"))
44 |
45 | def setup(self, stage: str) -> None:
46 | self.ds_dict = load_datasets(**self.load_datasets_kwargs)
47 |
48 | task_pipeline = [
49 | (
50 | # Prepare the input string(s) in modular tokenizer input format
51 | self.data_preprocessing,
52 | dict(
53 | target_sequence_key="Target",
54 | drug_sequence_key="Drug",
55 | ground_truth_key="Y",
56 | tokenizer_op=self.tokenizer_op,
57 | target_max_seq_length=self.target_max_seq_length,
58 | drug_max_seq_length=self.drug_max_seq_length,
59 | encoder_input_max_seq_len=self.encoder_input_max_seq_len,
60 | ),
61 | ),
62 | ]
63 |
64 | for ds in self.ds_dict.values():
65 | ds.dynamic_pipeline.extend(task_pipeline)
66 |
67 | def train_dataloader(self) -> DataLoader:
68 | train_loader = DataLoader(
69 | dataset=self.ds_dict["train"],
70 | batch_size=self.batch_size,
71 | collate_fn=CollateDefault(add_to_batch_dict={"forward_mode": "encoder"}),
72 | shuffle=True,
73 | **self.train_dl_kwargs,
74 | )
75 | return train_loader
76 |
77 | def val_dataloader(self) -> DataLoader:
78 | val_loader = DataLoader(
79 | self.ds_dict["valid"],
80 | batch_size=self.batch_size,
81 | collate_fn=CollateDefault(add_to_batch_dict={"forward_mode": "encoder"}),
82 | **self.valid_dl_kwargs,
83 | )
84 |
85 | return val_loader
86 |
87 | def test_dataloader(self) -> DataLoader:
88 | test_loader = DataLoader(
89 | self.ds_dict["test"],
90 | batch_size=self.batch_size,
91 | collate_fn=CollateDefault(add_to_batch_dict={"forward_mode": "encoder"}),
92 | **self.valid_dl_kwargs,
93 | )
94 |
95 | return test_loader
96 |
97 | def predict_dataloader(self) -> DataLoader:
98 | return self.test_dataloader()
99 |
100 |
101 | def load_datasets(
102 | split_type: str = "cold_split", split_column: list[str] | str = ["Drug", "Target"]
103 | ) -> dict[str, DatasetDefault]:
104 | """
105 |     Automatically downloads the data (using tdc) and creates dataset iterators for "train", "valid" and "test".
106 |     :return: a dictionary that maps the fold names "train", "valid" and "test" to dataset iterators
107 | """
108 |
109 | data = DTI(name="BindingDB_Kd")
110 | data.harmonize_affinities(mode="max_affinity")
111 | data.convert_to_log(form="binding")
112 | split = data.get_split(method=split_type, column_name=split_column)
113 | ds_dict = {}
114 | for set_name in ["train", "valid", "test"]:
115 | set_df = split[set_name]
116 | print(f"{set_name} set size is {len(set_df)}")
117 | print(f"{set_name=} {set_df.Y.mean()=} {set_df.Y.std()=}")
118 | dynamic_pipeline = PipelineDefault(
119 | "dti",
120 | [
121 | (
122 | OpReadDataframe(
123 | set_df,
124 | key_column=None,
125 | columns_to_extract=["Target", "Drug", "Y"],
126 | ),
127 | dict(),
128 | ),
129 | ],
130 | )
131 |
132 | ds = DatasetDefault(sample_ids=len(set_df), dynamic_pipeline=dynamic_pipeline)
133 | ds.create()
134 | ds_dict[set_name] = ds
135 |
136 | return ds_dict
137 |
138 |
139 | if __name__ == "__main__":
140 | load_datasets()
141 |
--------------------------------------------------------------------------------
/mammal/examples/tests/test_main_finetune.py:
--------------------------------------------------------------------------------
1 | import socket
2 | from pathlib import Path
3 |
4 | import hydra
5 | import pytest
6 | import pytorch_lightning as pl
7 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
8 | from fuse.utils.multiprocessing.helpers import num_available_cores
9 | from omegaconf import OmegaConf
10 |
11 | from mammal.main_finetune import *
12 |
13 | # pylint: disable=W0621
14 |
15 | """_summary_
16 |
17 | Testing examples
18 | """
19 |
20 | TEST_CONFIG_DIRPATH = str(Path(__file__).parents[0] / "../protein_solubility")
21 | TEST_CONFIG_FILENAME = "config.yaml"
22 |
23 | STATIC_OVERRIDES = [
24 | "track_clearml=null", # Travis cannot connect to ClearML at the moment. We might be able to fix it with a dedicated user + config credentials.
25 | "trainer.max_epochs=1", # Small number for a faster run.
26 | "+trainer.limit_train_batches=2", # Small number for a faster run.
27 | "+trainer.limit_val_batches=3", # Small number for a faster run.
28 | "+trainer.enable_checkpointing=False", # Do not checkpoint - saves memory
29 | "model_dir=null",
30 | ]
31 |
32 | OVERRIDES = [
33 | "task.data_module_kwargs.train_dl_kwargs.num_workers=0", # Using parallelization cause co
34 | "task.data_module_kwargs.valid_dl_kwargs.num_workers=0", # Using parallelization cause co
35 | "root=.",
36 | "name=sol_test",
37 | "+tokenizer.new_special_tokens=['','special_token2',yet_another_special_token]",
38 | ] + STATIC_OVERRIDES
39 |
40 |
41 | @pytest.fixture(scope="session")
42 | def cfg_dict():
43 | with hydra.initialize_config_dir(TEST_CONFIG_DIRPATH, version_base="1.1"):
44 | _cfg = hydra.compose(TEST_CONFIG_FILENAME, overrides=OVERRIDES)
45 | yield _cfg
46 |
47 |
48 | @pytest.fixture(scope="session")
49 | def cfg_obj(cfg_dict):
50 | OmegaConf.register_new_resolver("num_cores_auto", num_available_cores, replace=True)
51 | cfg_obj = hydra.utils.instantiate(cfg_dict)
52 | return cfg_obj
53 |
54 |
55 | @pytest.fixture(scope="session")
56 | def cfg(cfg_obj):
57 | return cfg_obj
58 |
59 |
60 | @pytest.fixture(scope="session")
61 | def clearml_logger():
62 | return None
63 |
64 |
65 | def test_context(cfg_dict):
66 | assert cfg_dict
67 |
68 |
69 | def test_context_obj(cfg):
70 | assert cfg
71 |
72 |
73 | def seed(seed_value: int) -> int:
74 | pl.seed_everything(seed_value, workers=True)
75 |
76 | return seed_value
77 |
78 |
79 | def test_seed():
80 | original_seed_value = 12345
81 | seed_value = seed(original_seed_value)
82 | assert seed_value == original_seed_value
83 |
84 |
85 | @pytest.fixture(scope="session")
86 | def tokenizer_op(cfg_dict):
87 | # return ModularTokenizerOp.from_pretrained(cfg_dict.tokenizer.tokenizer_path)
88 |     # The tokenizer is loaded with the extra special tokens, so we can check that they are available
89 | return load_and_update_tokenizer_op(cfg_dict)
90 |
91 |
92 | def test_tokenizer(tokenizer_op):
93 | assert isinstance(tokenizer_op, ModularTokenizerOp)
94 | special_tokens = [
95 | "",
96 | "",
97 | "",
98 | "",
99 | "",
100 | "",
101 | "",
102 | "",
103 | "special_token2",
104 | "yet_another_special_token",
105 | ]
106 | for token in special_tokens:
107 | tokenizer_op.get_token_id(token) # throws assert if fails
108 | never_seen_before = "never_seen_before"
109 | with pytest.raises(AssertionError):
110 | tokenizer_op.get_token_id(never_seen_before)
111 | num_new = tokenizer_op.add_new_special_tokens(
112 | [
113 | never_seen_before,
114 | ]
115 | )
116 | assert num_new == 1
117 | tokenizer_op.get_token_id(never_seen_before)
118 |
119 |
120 | @pytest.fixture(scope="session")
121 | def current_train_session_metadata():
122 | return {}
123 |
124 |
125 | @pytest.fixture(scope="session")
126 | def test_task(cfg_obj, clearml_logger, tokenizer_op):
127 | _task_list = cfg_obj.task(
128 | tokenizer_op=tokenizer_op,
129 | logger=clearml_logger,
130 | )
131 | return _task_list
132 |
133 |
134 | @pytest.fixture(scope="session")
135 | def pl_data_module(test_task):
136 | """get lightning data module"""
137 | return test_task.data_module()
138 |
139 |
140 | @pytest.fixture(scope="session")
141 | def pl_module(test_task, cfg):
142 | """get lightning module"""
143 | model = Mammal.from_pretrained(
144 | cfg.model.pretrained_kwargs.pretrained_model_name_or_path
145 | )
146 | _pl_module = module(
147 | model=model,
148 | task=test_task,
149 | **OmegaConf.to_container(cfg.module, resolve=True),
150 | )
151 | return _pl_module
152 |
153 |
154 | @pytest.mark.skipif(
155 | "ccc" not in socket.gethostname(),
156 | reason="Train consumes too much memory for a Travis run.",
157 | )
158 | def test_evaluate(cfg, pl_data_module, pl_module):
159 | pl_trainer = pl.Trainer(**cfg.trainer)
160 |
161 | pl_data_module.setup("test")
162 | out = pl_trainer.validate(
163 | model=pl_module,
164 | dataloaders=pl_data_module.test_dataloader(),
165 | )
166 | print(out)
167 |
168 |
169 | @pytest.mark.skipif(
170 | "ccc" not in socket.gethostname(),
171 | reason="Train consumes too much memory for a Travis run.",
172 | )
173 | def test_train(cfg, pl_data_module, pl_module):
174 | pl_trainer = pl.Trainer(**cfg.trainer)
175 | pl_trainer.fit(model=pl_module, datamodule=pl_data_module)
176 |
--------------------------------------------------------------------------------
/mammal/main_finetune.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections.abc import Callable
3 | from functools import partial
4 |
5 | import hydra
6 | import pytorch_lightning as pl
7 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
8 | from fuse.dl.lightning.pl_module import LightningModuleDefault
9 | from fuse.utils import NDict
10 | from omegaconf import DictConfig, OmegaConf
11 |
12 | from mammal.model import Mammal
13 | from mammal.task import MammalTask
14 |
15 |
16 | def save_in_model_dir(
17 | model_dir: str, model: Mammal, tokenizer_op: ModularTokenizerOp
18 | ) -> None:
19 | """
20 |     Save the model configuration and tokenizer in model_dir before starting the fine-tuning session
21 | :param model_dir: location to store the files
22 | :param model: the model to save the configuration for
23 | :param tokenizer_op: the tokenizer to save
24 | """
25 |
26 | if model_dir is None:
27 | return
28 |
29 | os.makedirs(model_dir, exist_ok=True)
30 |
31 | # save model config
32 | model._save_pretrained(model_dir, save_config_only=True)
33 |
34 | # tokenizer
35 | tokenizer_op.save_pretrained(os.path.join(model_dir, "tokenizer"))
36 |
37 |
38 | def configure_optimizers(
39 | module: LightningModuleDefault, opt_callable: Callable, lr_sch_callable: Callable
40 | ) -> dict:
41 | """
42 |     A callback used by the lightning module to set the learning rate scheduler and the optimizer
43 | :param module: the lightning module
44 | :param opt_callable: a callable that creates an optimizer given the model parameters
45 | :param lr_sch_callable: a callable that creates a learning rate scheduler given the optimizer
46 |
47 | """
48 | opt = opt_callable(module.trainer.model.parameters())
49 | lr_sch = lr_sch_callable(opt)
50 | return {
51 | "optimizer": opt,
52 | "lr_scheduler": {"scheduler": lr_sch, "interval": "step"},
53 | }
54 |
55 |
56 | def module(
57 | model: Mammal,
58 | task: MammalTask,
59 | opt_callable: Callable,
60 | lr_sch_callable: Callable,
61 | **kwargs,
62 | ) -> pl.LightningModule:
63 | """
64 | Create lightning module
65 | :param task: the task to finetune for
66 | :param opt_callable: a callable that creates an optimizer given the model parameters
67 | :param lr_sch_callable: a callable that creates a learning rate scheduler given the optimizer
68 | :param kwargs: additional LightningModuleDefault arguments
69 | """
70 | optimizers_and_lr_schs_callable = partial(
71 | configure_optimizers, opt_callable=opt_callable, lr_sch_callable=lr_sch_callable
72 | )
73 | return LightningModuleDefault(
74 | model=model,
75 | losses=task.losses(),
76 | validation_metrics=task.validation_metrics(),
77 | train_metrics=task.train_metrics(),
78 | optimizers_and_lr_schs=optimizers_and_lr_schs_callable,
79 | **kwargs,
80 | )
81 |
82 |
83 | @hydra.main(version_base="1.2", config_path=None, config_name=None)
84 | def main(cfg: DictConfig):
85 | cfg = hydra.utils.instantiate(cfg)
86 |
87 | # print configuration
88 | NDict(OmegaConf.to_container(cfg, resolve=True)).print_tree(True)
89 |
90 | # connect to clearml - if configured
91 | if "track_clearml" in cfg and cfg["track_clearml"] is not None:
92 | try:
93 | from fuse.dl.lightning.pl_funcs import start_clearml_logger
94 |
95 | clearml_task = start_clearml_logger(**cfg.track_clearml)
96 | except Exception as e:
97 | print("Tracking using clearml failed: continue without tracking")
98 | print(e)
99 | clearml_task = None
100 |
101 | if clearml_task is not None:
102 | clearml_logger = clearml_task.get_logger()
103 | else:
104 | clearml_logger = None # will be None in dist training and rank != 0
105 | else:
106 | clearml_logger = None
107 |
108 | # seed
109 | pl.seed_everything(seed=cfg.seed, workers=True)
110 |
111 | # tokenizer
112 | tokenizer_op = load_and_update_tokenizer_op(cfg)
113 |
114 | # model
115 | model = Mammal.from_pretrained(**cfg.model.pretrained_kwargs)
116 | print(model)
117 |
118 | # initialize task
119 | task: MammalTask = cfg.task(tokenizer_op=tokenizer_op, logger=clearml_logger)
120 |
121 | # lightning data module
122 | pl_data_module = task.data_module()
123 |
124 | # lightning module
125 | pl_module = module(
126 | task=task,
127 | model=model,
128 | **OmegaConf.to_container(cfg.module, resolve=True),
129 | )
130 |
131 | # create lightning trainer.
132 | pl_trainer = pl.Trainer(**cfg.trainer)
133 |
134 | if cfg.evaluate:
135 | pl_data_module.setup("test")
136 | out = pl_trainer.test(
137 | model=pl_module,
138 | dataloaders=pl_data_module.test_dataloader(),
139 | )
140 | print(out)
141 | else:
142 | # save model_config and tokenizer in output_dir
143 | save_in_model_dir(cfg.module.model_dir, model, tokenizer_op)
144 | pl_trainer.fit(model=pl_module, datamodule=pl_data_module)
145 |
146 |
147 | def load_and_update_tokenizer_op(cfg):
148 | tokenizer_op = ModularTokenizerOp.from_pretrained(cfg.tokenizer.tokenizer_path)
149 |
150 | if "new_special_tokens" in cfg.tokenizer and len(cfg.tokenizer.new_special_tokens):
151 | num_new_tokens_added = tokenizer_op.add_new_special_tokens(
152 | cfg.tokenizer.new_special_tokens
153 | )
154 | if num_new_tokens_added:
155 | # TODO: write better message
156 | print(
157 | 10 * "****",
158 | f" Added { num_new_tokens_added} special tokens to the tokenizer ",
159 | 10 * "****",
160 | )
161 |
162 | return tokenizer_op
163 |
164 |
165 | if __name__ == "__main__":
166 | main()
167 |
--------------------------------------------------------------------------------
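`main` resolves the optimizer and scheduler lazily: `configure_optimizers` is bound with `partial` and only invoked by Lightning once the (possibly wrapped) model parameters exist. A minimal sketch of the two factory callables the Hydra config is expected to provide (the optimizer class and hyperparameter values here are illustrative assumptions, not values from the shipped configs):

```python
from functools import partial

import torch

# opt_callable: given the model parameters, returns an optimizer
opt_callable = partial(torch.optim.AdamW, lr=1e-4, weight_decay=0.01)

# lr_sch_callable: given that optimizer, returns a scheduler (stepped per batch,
# matching the "interval": "step" setting in configure_optimizers)
lr_sch_callable = partial(
    torch.optim.lr_scheduler.LinearLR, start_factor=1.0, end_factor=0.1
)

# these two callables are what `module(model=..., task=..., opt_callable=...,
# lr_sch_callable=...)` forwards into LightningModuleDefault
```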
/mammal/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from mammal.keys import *
4 |
5 |
6 | class LossHead(torch.nn.Module):
7 | """
8 | Cross Entropy Loss
9 | """
10 |
11 | def __init__(
12 | self,
13 | *, # prevent positional args
14 | loss_type: str = "ce",
15 | loss_weight: float = 1.0,
16 | sample_token_weights_key: str | None = None,
17 | ignore_index: int | None = -100,
18 | pred_key: str = LOGITS,
19 | labels_key: str = LABELS_TOKENS,
20 | verify_no_weight_at_ignore: bool = True,
21 | ) -> None:
22 | """
23 | :param verify_no_weight_at_ignore: verifies that weights at positions where the loss is ignored (via ignore_index) equal 0, otherwise the normalization of loss by sum of weights is skewed.
24 | """
25 |
26 | super().__init__()
27 | self.loss_type = loss_type
28 | self.loss_weight = loss_weight
29 | self.sample_token_weights_key = sample_token_weights_key
30 | self.ignore_index = ignore_index
31 | self.labels_key = labels_key
32 | self.verify_no_weight_at_ignore = verify_no_weight_at_ignore
33 | self.pred_key = pred_key
34 |
35 | if self.loss_type == "ce":
36 | self.loss_function = torch.nn.CrossEntropyLoss(
37 | ignore_index=self.ignore_index, reduction="none"
38 | )
39 | else:
 40 |             raise NotImplementedError(self.loss_type)
41 |
42 | def forward(self, batch_dict: dict) -> torch.Tensor:
43 | preds = batch_dict[self.pred_key]
44 | targets = batch_dict[self.labels_key]
45 |
 46 |         if (
 47 |             self.sample_token_weights_key
 48 |             and self.sample_token_weights_key in batch_dict
 49 |         ):
 50 |             sample_token_weights = batch_dict[self.sample_token_weights_key]
 51 |         else:
 52 |             # weights key not configured, or missing from this batch
 53 |             sample_token_weights = None
54 |
 55 |         # concatenate the tokens across all samples; the loss is computed per token
56 | n_classes = preds.shape[-1] if len(preds.shape) > 2 else 1
57 | preds = preds.reshape(-1, n_classes).squeeze(dim=1)
58 | targets = targets.reshape(-1)
59 | if sample_token_weights is not None:
60 | weights = sample_token_weights.reshape(-1)
61 |
62 | losses = self.loss_function(preds, targets)
63 |
64 | if sample_token_weights is None:
65 | losses = losses[targets != self.ignore_index]
66 | loss = losses.mean()
67 | else:
 68 |             # zero out loss values (including NaNs) at positions
 69 |             # whose weight is zero
 70 |             losses[weights == 0.0] = 0.0
71 | if self.verify_no_weight_at_ignore:
72 | assert (
73 | weights[targets == self.ignore_index].abs().sum() == 0
74 | ), "You are using none-zero weights at ignore index positions when calculating the loss - this is most likely skewing your loss evaluation.\nTo turn off this assertion pass 'verify_no_weight_at_ignore'=False to 'mammal.losses.LossHead"
75 | loss = (losses * weights).sum() / weights.sum()
76 |
77 | return loss * self.loss_weight
78 |
79 |
80 | class RMSELoss(torch.nn.Module):
81 | def __init__(self, reduction: str = "mean", eps: float = 1e-6) -> None:
82 | super().__init__()
83 | self.mse = torch.nn.MSELoss(reduction=reduction)
84 | self.eps = eps
85 |
86 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
87 | loss = (self.mse(inputs, targets) + self.eps) ** 0.5
88 | return loss
89 |
90 |
91 | class ScalarsPredictionsLoss(torch.nn.Module):
92 | """
 93 |     Scalar prediction loss: one of MSE, RMSE, or MAE, selected via loss_type.
94 | """
95 |
96 | def __init__(
97 | self,
98 | *, # prevent positional args
99 | loss_type: str,
100 | loss_weight: float | None = None,
101 | pred_key: str = SCALARS_PREDICTION_HEAD_LOGITS,
102 | labels_scalars_values_key: str = LABELS_SCALARS_VALUES,
103 | labels_scalars_valid_mask_key: str = LABELS_SCALARS_VALID_MASK,
104 | ) -> None:
105 |
106 | super().__init__()
107 | self.loss_type = loss_type
108 | self.loss_weight = 1.0 if loss_weight is None else loss_weight
109 | self.pred_key = pred_key
110 | self.labels_scalars_values_key = labels_scalars_values_key
111 | self.labels_scalars_valid_mask_key = labels_scalars_valid_mask_key
112 |
113 | if self.loss_type == "mse":
114 | self.loss_function = torch.nn.MSELoss(reduction="none")
115 | elif self.loss_type == "mae":
116 | self.loss_function = torch.nn.L1Loss(reduction="none")
117 | elif self.loss_type == "rmse":
118 | self.loss_function = RMSELoss(reduction="none")
119 | else:
 120 |             raise NotImplementedError(self.loss_type)
121 |
122 | def forward(self, batch_dict: dict) -> torch.Tensor:
123 | if (
124 | (not _legit(batch_dict, self.pred_key))
125 | or (not _legit(batch_dict, self.labels_scalars_values_key))
126 | or (not _legit(batch_dict, self.labels_scalars_valid_mask_key))
127 | ):
128 | return 0.0
129 |
130 | preds = batch_dict[self.pred_key]
131 | targets = batch_dict[self.labels_scalars_values_key]
132 | valid = batch_dict[self.labels_scalars_valid_mask_key]
133 |
134 | assert (
135 | preds.shape == targets.shape
136 | ), f"preds shape ({preds.shape}) is expected to be the same as targets shape ({targets.shape})"
137 | assert (
138 | targets.shape == valid.shape
139 | ), f"targets shape ({valid.shape}) is expected to be the same as targets shape ({valid.shape})"
140 |
141 | # take only the valid elements
142 | preds = preds[valid.bool()]
143 | targets = targets[valid.bool()]
144 |
145 | curr_loss = self.loss_function(preds, targets)
146 |
147 | return curr_loss.mean() * self.loss_weight
148 |
149 |
150 | def _legit(batch_dict: dict, key: str) -> bool:
151 | if key not in batch_dict:
152 | return False
153 | if batch_dict[key] is None:
154 | return False
155 | if len(batch_dict[key]) == 0:
156 | return False
157 | return True
158 |
--------------------------------------------------------------------------------
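`LossHead` flattens the batch to token level and averages the cross-entropy over non-ignored positions. A minimal self-contained sketch on random logits (the shapes are illustrative; the key constants come from `mammal.keys`):

```python
import torch

from mammal.keys import LABELS_TOKENS, LOGITS
from mammal.losses import LossHead

loss_head = LossHead(loss_type="ce", ignore_index=-100)

# toy batch: 1 sample, 4 decoder positions, vocabulary of 8 tokens
batch_dict = {
    LOGITS: torch.randn(1, 4, 8),
    LABELS_TOKENS: torch.tensor([[3, 5, -100, -100]]),  # last two positions are padding
}

loss = loss_head(batch_dict)  # mean cross-entropy over the two non-ignored tokens
print(loss)
```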
/mammal/examples/scrna_cell_type/scRNA_infer.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from pprint import pprint
3 |
4 | import anndata
5 | import click
6 | import numpy as np
7 | import torch
8 | from anndata_op import OpReadAnnData
9 | from fuse.data.datasets.dataset_default import DatasetDefault
10 | from fuse.data.pipelines.pipeline_default import PipelineDefault
11 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
12 | from fuse.data.utils.collates import CollateDefault
13 | from pl_data_module import CellTypeDataModule
14 | from task import CellTypeTask
15 | from torch.utils.data.dataset import Dataset
16 |
17 | from mammal.keys import (
18 | CLS_PRED,
19 | SCORES,
20 | )
21 | from mammal.model import Mammal
22 |
23 |
24 | @click.command()
25 | @click.argument("task_name", default="cell_type")
26 | @click.option(
27 | "--model-path",
28 | "-m",
29 | help="Specify the model dir.",
30 | )
31 | @click.option(
32 | "--tokenizer_path",
33 | default=None,
34 | type=str,
35 | help="Specify the tokenizer path.",
36 | )
37 | @click.option(
38 | "--h5ad-file-path",
39 | "-i",
40 | type=str,
41 | help="Specify the A5HD (AnnData) input file.",
42 | default="data/Zheng_68k_processed.h5ad",
43 | )
44 | @click.option("--sample_id", "-s", type=int, default=0)
45 | @click.option(
46 | "--verbose",
47 | "-v",
48 | count=True,
49 | default=1,
50 | )
51 | @click.option(
52 | "--test-h5ad-file",
53 | "-T",
54 | is_flag=True,
55 | default=False,
56 | )
57 | @click.option(
58 | "--device", default="cpu", help="Specify the device to use (default: 'cpu')."
59 | )
60 | def main(
61 | task_name: str,
62 | h5ad_file_path: str,
63 | sample_id: int,
64 | model_path: str,
65 | tokenizer_path: str,
66 | verbose: int,
67 | test_h5ad_file: bool,
68 | device: str,
69 | ):
70 | try:
71 | anndata_object = anndata.read_h5ad(h5ad_file_path)
72 | except Exception as e:
73 | raise ValueError(
74 | f"Failed to read {h5ad_file_path} as a serialized AnnData "
75 | ) from e
76 |
77 | tokenizer_op, nn_model = get_tokenizer_and_model(tokenizer_path, model_path, device)
78 |
79 | # convert to MAMMAL style
80 |
81 | dynamic_pipeline = PipelineDefault(
82 | "cell_type",
83 | [
84 | (OpReadAnnData(data=anndata_object), {"prefix": "scrna"}),
85 | ],
86 | )
87 |
88 | data_source = DatasetDefault(
89 | sample_ids=anndata_object.shape[0], dynamic_pipeline=dynamic_pipeline
90 | )
91 | data_source.create()
92 |
93 | sample_dict = create_sample_dict(task_name, data_source, sample_id, tokenizer_op)
94 | if test_h5ad_file or verbose > 2:
95 | print("/n/n sample dict:/n")
96 | pprint(dict(sample_dict))
97 | if test_h5ad_file and "data.label" not in sample_dict:
98 | raise ValueError(
99 | "sample_dict['data.label'] is missing - data can not be used for training"
100 | )
101 | unique_values, counts = np.unique(
102 | sample_dict["scrna.scrna"].data, return_counts=True
103 | )
104 | n_values = len(unique_values)
105 | if n_values > 11:
106 | print("-" * 60)
107 | print(
108 | f"The data has {n_values} different expression bins which is typical for data that did not pass though the preprocessing."
109 | )
110 | for value, count in zip(unique_values, counts):
111 | print(f"Value {value} appears {count} times")
112 | print("\n" * 4)
113 |
114 | batch_dict = CollateDefault(skip_keys=CellTypeDataModule.skip_keys)([sample_dict])
115 |
116 | if test_h5ad_file or verbose > 1:
117 | key_to_print = ["data.query.encoder_input", "data.encoder_input_token_ids"]
118 | for key in key_to_print:
119 | print(f"{key}: {batch_dict[key]}")
120 |
121 | if test_h5ad_file or verbose > 2:
122 | n_zero = torch.sum(batch_dict["data.encoder_input_token_ids"] == 0).item()
123 | total_length = torch.sum(batch_dict["data.encoder_input_attention_mask"]).item()
124 | print(
125 | f"{n_zero} unknown from {total_length} tokens ({round((n_zero*100)/total_length,2)}%)"
126 | )
127 |
128 | # run the model
129 | batch_dict = get_predictions(nn_model, batch_dict=batch_dict)
130 | ans = process_model_output(tokenizer_op, batch_dict)
131 | ans = {
132 | k: v.detach().numpy() if isinstance(v, torch.Tensor) else v
133 | for k, v in ans.items()
134 | }
135 | if test_h5ad_file or verbose:
136 | print(ans)
137 |
138 |
139 | def process_model_output(tokenizer_op, batch_dict):
140 | return CellTypeTask.process_model_output(
141 | tokenizer_op=tokenizer_op,
142 | decoder_output=batch_dict[CLS_PRED][0],
143 | decoder_output_scores=batch_dict[SCORES][0],
144 | )
145 |
146 |
147 | def get_tokenizer_and_model(tokenizer_path, model_path, device):
148 | if tokenizer_path is None:
149 | tokenizer_path = Path(model_path)
150 | tokenizer_op = ModularTokenizerOp.from_pretrained(tokenizer_path)
151 | nn_model = Mammal.from_pretrained(
152 | pretrained_model_name_or_path=model_path,
153 | )
154 | nn_model.eval()
155 | nn_model.to(device=device)
156 |
157 | return tokenizer_op, nn_model
158 |
159 |
160 | def create_sample_dict(
161 | task_name: str, data_source: Dataset, sample_id: int, tokenizer_op
162 | ):
163 | sequence_key = "scrna.scrna"
164 | # Create and load sample
165 | sample_dict = data_source[sample_id]
166 |
167 | sample_dict = CellTypeTask.data_preprocessing(
168 | sample_dict,
169 | sequence_key=sequence_key,
170 | input_max_seq_length=1260,
171 | encoder_input_max_seq_len=1260,
172 | labels_max_seq_len=4,
173 | tokenizer_op=tokenizer_op,
174 | )
175 |
176 | return sample_dict
177 |
178 |
179 | def get_predictions(model, batch_dict):
180 | return model.generate(
181 | batch_dict,
182 | output_scores=True,
183 | return_dict_in_generate=True,
184 | max_new_tokens=5,
185 | )
186 |
187 |
188 | def print_result(batch_dict, scalars_preds_processed_key):
189 |
190 | value = batch_dict[scalars_preds_processed_key]
191 | ans = {"scalar_result": value}
192 |
193 | # Print prediction
194 |
195 | batch_dict["scalar_result"] = value
196 | print(f"estimated value: {ans}")
197 |
198 |
199 | def process_sample(tokenizer_op, nn_model, sample_dict):
200 | # running in generate mode
201 | batch_dict = nn_model.generate(
202 | [sample_dict],
203 | output_scores=True,
204 | return_dict_in_generate=True,
205 | max_new_tokens=5,
206 | )
207 |
208 | return batch_dict
209 |
210 |
211 | if __name__ == "__main__":
212 | main()
213 |
--------------------------------------------------------------------------------
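The script above is a Click CLI. A typical single-sample inference run, assuming a finetuned cell-type model directory and the preprocessed Zheng68k file (both paths are placeholders), might look like:

```sh
python scRNA_infer.py cell_type \
    -m /path/to/cell_type_model_dir \
    -i data/Zheng_68k_processed.h5ad \
    -s 0 \
    -v
```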
/mammal/examples/protein_solubility/pl_data_module.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from collections.abc import Callable
4 |
5 | import pandas as pd
6 | import pytorch_lightning as pl
7 | import wget
8 | from fuse.data.datasets.dataset_default import DatasetDefault
9 | from fuse.data.ops.ops_read import OpReadDataframe
10 | from fuse.data.pipelines.pipeline_default import PipelineDefault
11 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
12 | from fuse.data.utils.collates import CollateDefault
13 | from torch.utils.data.dataloader import DataLoader
14 |
15 | from mammal.keys import * # noqa
16 |
17 |
18 | class ProteinSolubilityDataModule(pl.LightningDataModule):
19 | def __init__(
20 | self,
21 | *,
22 | data_path: str,
23 | batch_size: int,
24 | tokenizer_op: ModularTokenizerOp,
25 | train_dl_kwargs: dict,
26 | valid_dl_kwargs: dict,
27 | seed: int,
28 | data_preprocessing: Callable,
29 | protein_max_seq_length: int,
30 | encoder_input_max_seq_len: int,
31 | labels_max_seq_len: int,
32 | ) -> None:
33 | """_summary_
34 | Args:
35 | data_path (str): path to the raw data, if not exist, will download the data to the given path.
36 | batch_size (int): batch size
37 | tokenizer_op (ModularTokenizerOp): tokenizer op
 38 |             encoder_input_max_seq_len: max tokenizer sequence length for the encoder inputs,
39 | labels_max_seq_len: max tokenizer sequence length for the labels,
40 | train_dl_kwargs (dict): train dataloader constructor parameters
41 | valid_dl_kwargs (dict): validation dataloader constructor parameters
42 | seed (int): random seed
43 | """
44 | super().__init__()
45 | self.data_path = data_path
46 | self.tokenizer_op = tokenizer_op
47 | self.protein_max_seq_length = protein_max_seq_length
48 | self.encoder_input_max_seq_len = encoder_input_max_seq_len
49 | self.labels_max_seq_len = labels_max_seq_len
50 | self.batch_size = batch_size
51 | self.train_dl_kwargs = train_dl_kwargs
52 | self.valid_dl_kwargs = valid_dl_kwargs
53 | self.seed = seed
54 | self.data_preprocessing = data_preprocessing
55 |
 56 |         self.pad_token_id = self.tokenizer_op.get_token_id("<PAD>")
57 |
58 | def setup(self, stage: str) -> None:
59 | self.ds_dict = load_datasets(self.data_path)
60 |
61 | task_pipeline = [
62 | (
63 | # Prepare the input string(s) in modular tokenizer input format
64 | self.data_preprocessing,
65 | dict(
66 | protein_sequence_key="data.protein",
67 | solubility_label_key="data.label",
68 | tokenizer_op=self.tokenizer_op,
69 | protein_max_seq_length=self.protein_max_seq_length,
70 | encoder_input_max_seq_len=self.encoder_input_max_seq_len,
71 | labels_max_seq_len=self.labels_max_seq_len,
72 | ),
73 | ),
74 | ]
75 |
76 | for ds in self.ds_dict.values():
77 | ds.dynamic_pipeline.extend(task_pipeline)
78 |
79 | def train_dataloader(self) -> DataLoader:
80 | train_loader = DataLoader(
81 | dataset=self.ds_dict["train"],
82 | batch_size=self.batch_size,
83 | collate_fn=CollateDefault(),
84 | shuffle=True,
85 | **self.train_dl_kwargs,
86 | )
87 | return train_loader
88 |
89 | def val_dataloader(self) -> DataLoader:
90 | val_loader = DataLoader(
91 | self.ds_dict["val"],
92 | batch_size=self.batch_size,
93 | collate_fn=CollateDefault(),
94 | **self.valid_dl_kwargs,
95 | )
96 |
97 | return val_loader
98 |
99 | def test_dataloader(self) -> DataLoader:
100 | test_loader = DataLoader(
101 | self.ds_dict["test"],
102 | batch_size=self.batch_size,
103 | collate_fn=CollateDefault(),
104 | **self.valid_dl_kwargs,
105 | )
106 |
107 | return test_loader
108 |
109 | def predict_dataloader(self) -> DataLoader:
110 | return self.test_dataloader()
111 |
112 |
113 | _SOLUBILITY_URL = "https://zenodo.org/api/records/1162886/files-archive"
114 |
115 |
116 | def load_datasets(data_path: str) -> dict[str, DatasetDefault]:
117 | """
118 | Automatically downloads the data and create dataset iterator for "train", "val" and "test".
119 | paper: https://academic.oup.com/bioinformatics/article/34/15/2605/4938490
120 | Data retrieved from: https://zenodo.org/records/1162886
121 | The benchmark requires classifying protein sequences into binary labels - Soluble or Insoluble (1 or 0).
122 | :param data_path: path to a directory to store the raw data
123 | :return: dictionary that maps fold name "train", "val" and "test" to a dataset iterator
124 | """
125 |
126 | if not os.path.exists(data_path):
127 | os.makedirs(data_path)
128 |
129 | raw_data_path = os.path.join(data_path, "sameerkhurana10-DSOL_rv0.2-20562ad/data")
130 | if not os.path.exists(raw_data_path):
131 | wget.download(_SOLUBILITY_URL, data_path)
132 | file_path = os.path.join(data_path, "1162886.zip")
133 | shutil.unpack_archive(file_path, extract_dir=data_path)
134 | inner_file_path = os.path.join(
135 | data_path, "sameerkhurana10", "DSOL_rv0.2-v0.3.zip"
136 | )
137 | shutil.unpack_archive(inner_file_path, extract_dir=data_path)
138 | assert os.path.exists(
139 | raw_data_path
140 | ), f"Error: download complete but {raw_data_path} doesn't exist"
141 |
142 | # read files
143 | df_dict = {}
144 | for set_name in ["train", "val", "test"]:
145 | input_df = pd.read_csv(
146 | os.path.join(raw_data_path, f"{set_name}_src"), names=["data.protein"]
147 | )
148 | labels_df = pd.read_csv(
149 | os.path.join(raw_data_path, f"{set_name}_tgt"), names=["data.label"]
150 | )
151 | df_dict[set_name] = (input_df, labels_df)
152 |
153 | ds_dict = {}
154 | for set_name in ["train", "val", "test"]:
155 | input_df, labels_df = df_dict[set_name]
156 | size = len(labels_df)
157 | print(f"{set_name} set size is {size}")
158 | dynamic_pipeline = PipelineDefault(
159 | "solubility",
160 | [
161 | (OpReadDataframe(input_df, key_column=None), dict()),
162 | (OpReadDataframe(labels_df, key_column=None), dict()),
163 | ],
164 | )
165 |
166 | ds = DatasetDefault(sample_ids=size, dynamic_pipeline=dynamic_pipeline)
167 | ds.create()
168 | ds_dict[set_name] = ds
169 |
170 | return ds_dict
171 |
172 |
173 | if __name__ == "__main__":
174 | load_datasets("data")
175 |
--------------------------------------------------------------------------------
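`load_datasets` is usable on its own; on first call it downloads the DSOL archive into `data_path` and builds the three fold iterators. A quick sketch of inspecting one raw sample before the task pipeline is attached (output values depend on the downloaded data):

```python
from mammal.examples.protein_solubility.pl_data_module import load_datasets

ds_dict = load_datasets("data")  # downloads and unpacks the Zenodo archive on first use
sample = ds_dict["train"][0]
print(sample["data.protein"][:40], sample["data.label"])  # sequence prefix and 0/1 label
```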
/mammal/examples/carcinogenicity/task.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import numpy as np
4 | import pytorch_lightning as pl
5 | import torch
6 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
7 |
8 | from mammal.examples.carcinogenicity.pl_data_module import CarcinogenicityDataModule
9 | from mammal.keys import (
10 | CLS_PRED,
11 | DECODER_INPUTS_ATTENTION_MASK,
12 | DECODER_INPUTS_STR,
13 | DECODER_INPUTS_TOKENS,
14 | ENCODER_INPUTS_ATTENTION_MASK,
15 | ENCODER_INPUTS_STR,
16 | ENCODER_INPUTS_TOKENS,
17 | LABELS_ATTENTION_MASK,
18 | LABELS_STR,
19 | LABELS_TOKENS,
20 | SCORES,
21 | )
22 | from mammal.metrics import classification_metrics
23 | from mammal.task import (
24 | MammalTask,
25 | MetricBase,
26 | )
27 |
28 |
29 | class CarcinogenicityTask(MammalTask):
30 | def __init__(
31 | self,
32 | *,
33 | tokenizer_op: ModularTokenizerOp,
34 | data_module_kwargs: dict,
35 | logger: Any | None = None,
36 | ) -> None:
37 | super().__init__(
38 | name="carcinogenicity",
39 | logger=logger,
40 | tokenizer_op=tokenizer_op,
41 | )
42 | self._data_module_kwargs = data_module_kwargs
43 |
44 | self.preds_key = CLS_PRED
45 | self.scores_key = SCORES
46 | self.labels_key = LABELS_TOKENS
47 |
48 | def data_module(self) -> pl.LightningDataModule:
49 | return CarcinogenicityDataModule(
50 | tokenizer_op=self._tokenizer_op,
51 | data_preprocessing=self.data_preprocessing,
52 | **self._data_module_kwargs,
53 | )
54 |
55 | def train_metrics(self) -> dict[str, MetricBase]:
56 | metrics = super().train_metrics()
57 | metrics.update(
58 | classification_metrics(
59 | self.name(),
60 | class_position=1,
61 | tokenizer_op=self._tokenizer_op,
62 | class_tokens=["<0>", "<1>"],
63 | )
64 | )
65 |
66 | return metrics
67 |
68 | def validation_metrics(self) -> dict[str, MetricBase]:
69 | validation_metrics = super().validation_metrics()
70 | validation_metrics.update(
71 | classification_metrics(
72 | self.name(),
73 | class_position=1,
74 | tokenizer_op=self._tokenizer_op,
75 | class_tokens=["<0>", "<1>"],
76 | )
77 | )
78 | return validation_metrics
79 |
80 | @staticmethod
81 | def data_preprocessing(
82 | sample_dict: dict,
83 | *,
84 | sequence_key: str,
 85 |         label_key: str | None = None,
86 | drug_max_seq_length: int = 1250,
87 | encoder_input_max_seq_len: int | None = 1260,
88 | labels_max_seq_len: int | None = 4,
89 | tokenizer_op: ModularTokenizerOp,
90 | device: str | torch.device = "cpu",
91 | ) -> dict:
92 | drug_sequence = sample_dict[sequence_key]
93 | label = sample_dict.get(label_key, None)
94 |
95 | sample_dict[ENCODER_INPUTS_STR] = (
96 | f"<@TOKENIZER-TYPE=SMILES><@TOKENIZER-TYPE=SMILES@MAX-LEN={drug_max_seq_length}>{drug_sequence}"
97 | )
98 | tokenizer_op(
99 | sample_dict=sample_dict,
100 | key_in=ENCODER_INPUTS_STR,
101 | key_out_tokens_ids=ENCODER_INPUTS_TOKENS,
102 | key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,
103 | max_seq_len=encoder_input_max_seq_len,
104 | )
105 | sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(
106 | sample_dict[ENCODER_INPUTS_TOKENS], device=device
107 | )
108 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(
109 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK], device=device
110 | )
111 |
112 | if label is not None:
 113 |             pad_id = tokenizer_op.get_token_id("<PAD>")
114 | ignore_token_value = -100
115 | sample_dict[LABELS_STR] = (
 116 |                 f"<@TOKENIZER-TYPE=SMILES><SENTINEL_ID_0><{label}><EOS>"
117 | )
118 | tokenizer_op(
119 | sample_dict=sample_dict,
120 | key_in=LABELS_STR,
121 | key_out_tokens_ids=LABELS_TOKENS,
122 | key_out_attention_mask=LABELS_ATTENTION_MASK,
123 | max_seq_len=labels_max_seq_len,
124 | )
125 | sample_dict[LABELS_TOKENS] = torch.tensor(
126 | sample_dict[LABELS_TOKENS], device=device
127 | )
128 | sample_dict[LABELS_ATTENTION_MASK] = torch.tensor(
129 | sample_dict[LABELS_ATTENTION_MASK], device=device
130 | )
 131 |             # replace pad_id with -100 so padded label positions are ignored by the loss
 132 |             sample_dict[LABELS_TOKENS][
 133 |                 sample_dict[LABELS_TOKENS] == pad_id
 134 |             ] = ignore_token_value
137 |
138 | sample_dict[DECODER_INPUTS_STR] = (
 139 |                 f"<@TOKENIZER-TYPE=SMILES><DECODER_START><SENTINEL_ID_0><{label}>"
140 | )
141 | tokenizer_op(
142 | sample_dict=sample_dict,
143 | key_in=DECODER_INPUTS_STR,
144 | key_out_tokens_ids=DECODER_INPUTS_TOKENS,
145 | key_out_attention_mask=DECODER_INPUTS_ATTENTION_MASK,
146 | max_seq_len=labels_max_seq_len,
147 | )
148 | sample_dict[DECODER_INPUTS_TOKENS] = torch.tensor(
149 | sample_dict[DECODER_INPUTS_TOKENS], device=device
150 | )
151 | sample_dict[DECODER_INPUTS_ATTENTION_MASK] = torch.tensor(
152 | sample_dict[DECODER_INPUTS_ATTENTION_MASK], device=device
153 | )
154 |
155 | return sample_dict
156 |
157 | @staticmethod
158 | def process_model_output(
159 | tokenizer_op: ModularTokenizerOp,
160 | decoder_output: np.ndarray,
161 | decoder_output_scores: np.ndarray,
162 | ) -> dict:
163 | negative_token_id = tokenizer_op.get_token_id("<0>")
164 | positive_token_id = tokenizer_op.get_token_id("<1>")
165 | label_id_to_int = {
166 | negative_token_id: 0,
167 | positive_token_id: 1,
168 | }
169 | classification_position = 1
170 |
 171 |         not_normalized_score = normalized_score = None  # defaults when no scores are given
 172 |         if decoder_output_scores is not None:
172 | not_normalized_score = decoder_output_scores[
173 | classification_position, positive_token_id
174 | ]
175 | normalized_score = not_normalized_score / (
176 | not_normalized_score
177 | + decoder_output_scores[classification_position, negative_token_id]
178 | + 1e-10
179 | )
180 |
181 | ans = dict(
182 | pred=label_id_to_int.get(int(decoder_output[classification_position]), -1),
183 | not_normalized_scores=not_normalized_score,
184 | normalized_scores=normalized_score,
185 | )
186 |
187 | return ans
188 |
--------------------------------------------------------------------------------
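Outside the Lightning pipeline, `data_preprocessing` can be driven directly to turn a raw SMILES string into the tensors the encoder expects. A minimal sketch (the tokenizer path and the aspirin SMILES are illustrative placeholders):

```python
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp

from mammal.examples.carcinogenicity.task import CarcinogenicityTask
from mammal.keys import ENCODER_INPUTS_TOKENS

tokenizer_op = ModularTokenizerOp.from_pretrained("/path/to/model_dir/tokenizer")  # placeholder

sample_dict = {"smiles": "CC(=O)Oc1ccccc1C(=O)O"}  # aspirin, illustrative input
sample_dict = CarcinogenicityTask.data_preprocessing(
    sample_dict,
    sequence_key="smiles",
    tokenizer_op=tokenizer_op,
)
print(sample_dict[ENCODER_INPUTS_TOKENS].shape)  # tokenized, padded encoder input
```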
/mammal/examples/scrna_cell_type/data/Zheng68k_to_anndata.py:
--------------------------------------------------------------------------------
1 | # Script to pack the Zheng68k data downloaded from 10x Genomics into an AnnData/h5ad file.
2 |
3 | # This example follows Zheng et al. for the identification of white blood cell types from single-cell RNA expression data.
4 |
5 |
6 | # ## Outline of process
7 | # The finetune process requires the input data to be in scRNA-seq AnnData format (saved as an h5ad file) with cell types
8 | # as labels. If the data is not packed as AnnData, as is the case when downloading from the
9 | # 10x Genomics site as explained below, it first needs to be packed into one and saved to disk.
10 | # Cell types, if present, should be stored in the `adata.obs['cell_type']` observation.
11 |
12 | # This script assumes that it is run from the data directory,
13 | # which is typically `bmfm-mammal-release/mammal/examples/scrna_cell_type/data`
14 |
15 | import os
16 | import subprocess
17 | from pathlib import Path
18 |
19 | import anndata
20 | import click
21 | import pandas as pd
22 | from scipy.io import mmread
23 |
24 | ### Obtaining the source data:
25 | # The main data is available online, for example on the [10x Genomics](https://www.10xgenomics.com/) site.
26 | # The labels are based on the data in [LINK](https://www.10xgenomics.com/datasets/fresh-68-k-pbm-cs-donor-a-1-standard-1-1-0)
27 | # From this site download the file `fresh_68k_pbmc_donor_a_filtered_gene_bc_matrices.tar.gz` and place it in the data directory.
28 |
29 |
30 | DEFAULT_LABELS_FILE = (
31 | "zheng17_bulk_lables.txt" # yes, the original file is named this way.
32 | )
33 | GZIP_FILE_NAME = "fresh_68k_pbmc_donor_a_filtered_gene_bc_matrices.tar.gz"
34 | RAW_H5AD_FILE = "Zheng_68k.h5ad"
35 | RAW_DATA_SUBDIR = Path("filtered_matrices_mex/hg19")
36 |
37 |
38 | @click.command()
39 | @click.option(
40 | "--output-h5ad-file",
41 | "-o",
42 | default=None,
43 | help="name of output H5AD file. default is adding '_preprocessed' to the input file",
44 | )
45 | @click.option(
46 | "--data-dir",
47 | help="dirname for the downloaded and constructed data files",
48 | default=".",
49 | )
50 | @click.option(
51 | "--labels_file",
52 | "-l",
53 |     default=DEFAULT_LABELS_FILE,
54 | )
55 | @click.option(
56 | "--labels_key",
57 | "-k",
58 | default="cell_type",
59 | help="key to use for the cell type labels in the AnnData observations.",
60 | )
61 | @click.option(
62 | "--verbose", "-v", is_flag=True, default=False, help="be verbose (default: off)."
63 | )
64 | def main(
65 | output_h5ad_file: str,
66 | data_dir: os.PathLike,
67 | labels_file: os.PathLike,
68 | labels_key: str,
69 | verbose: bool = False,
70 | ):
71 | # all work is done in the data dir
72 | os.chdir(data_dir)
73 | barcode_file = RAW_DATA_SUBDIR / "barcodes.tsv"
74 | genes_file = RAW_DATA_SUBDIR / "genes.tsv"
75 | matrix_file = RAW_DATA_SUBDIR / "matrix.mtx"
76 |
77 | if not RAW_DATA_SUBDIR.exists():
78 | # check if the file exists
79 | if not os.path.exists(GZIP_FILE_NAME):
80 | print(
81 | f"please download the file {GZIP_FILE_NAME} from https://www.10xgenomics.com/datasets/fresh-68-k-pbm-cs-donor-a-1-standard-1-1-0 into this data directory and then run this script again from that directory"
82 | )
83 | raise FileNotFoundError(
84 | f"Both the {GZIP_FILE_NAME} and the raw data directory {RAW_DATA_SUBDIR} extracted from it not found under the current directory"
85 | )
86 | else:
87 | if verbose:
88 | print(f"extracting files from {GZIP_FILE_NAME}")
89 | subprocess.run(["tar", "xvzf", GZIP_FILE_NAME], check=True)
90 |
91 |     if labels_file is not None:  # None means: do not add labels
92 | if not os.path.exists(labels_file):
93 | if (
94 |                 labels_file == DEFAULT_LABELS_FILE
95 | ): # special case - we can download this file if needed.
96 | labels_file_url = "https://raw.githubusercontent.com/scverse/scanpy_usage/refs/heads/master/170503_zheng17/data/zheng17_bulk_lables.txt"
97 | if verbose:
98 | print(f"Missing cell-type-labels file {labels_file}")
99 | print(f"downloading it from {labels_file_url}")
100 | subprocess.run(["wget", labels_file_url], check=True)
101 | if verbose:
102 | print("downloaded")
103 | else:
104 | raise FileNotFoundError("please supply labels file")
105 | raw_adata = create_anndata_from_csv(
106 | barcode_file,
107 | genes_file,
108 | matrix_file,
109 | labels_file=labels_file,
110 | labels_key=labels_key,
111 | )
112 |
113 | # Save result anndata object to disk
114 | raw_adata.write_h5ad(output_h5ad_file)
115 |
116 | if verbose:
117 | print(f"the raw {output_h5ad_file} is ready")
118 | print(
119 | "to use this h5ad file please filter, normalize and bin the counts. You can use process_h5ad_data.py to do this."
120 | )
121 |
122 | # This is a standard AnnData file, and can be replaced by your own data.
123 |
124 |
125 | def create_anndata_from_csv(
126 | barcode_file,
127 | genes_file,
128 | matrix_file,
129 | labels_file,
130 | labels_key,
131 | ) -> anndata.AnnData:
132 | """Construct an h5ad file (an anndata object dump) from its components.
133 |
134 |
135 | Args:
136 | raw_h5ad_file (os.PathLike): name of file to save constructed AnnData into
137 | barcode_file (os.PathLike): this file holds the mapping from the sample index to the cell identifier
138 | genes_file (os.PathLike): Mapping from feature index to gene name
139 | matrix_file (os.PathLike): The actual data, is (sparse) matrix form.
140 | labels_file (os.PathLike, optional): File containing the cell types for each file.
141 | labels_key (str): name of observation to place labels under in the AnnData object.
142 | verbose (bool, optional): verbose output. Defaults to False.
143 |
144 | Returns:
145 | anndata.AnnData: the generated anndata object
146 | """
147 | mmx = mmread(matrix_file)
148 |
149 | #### Create an AnnData object wrapping the read data
150 |
151 | # Notice that this code transposes the data to the correct direction
152 |
153 | anndata_object = anndata.AnnData(X=mmx.transpose().tocsr())
154 |
155 | # Cell identifiers
156 | observation_names = pd.read_csv(barcode_file, header=None, sep="\t")
157 | # names of genes
158 | genes = pd.read_csv(genes_file, header=None, sep="\t")
159 |
160 | # use the gene names as variable names in the AnnData object
161 | anndata_object.var_names = genes[1]
162 |
163 | # use the cell barcodes as names for the samples
164 | anndata_object.obs_names = observation_names[0]
165 |
166 | if labels_file is not None:
167 |         # cell types (this is actually just one column)
168 | cell_type_labels = pd.read_csv(labels_file, header=None, sep="\t")
169 | # use cell types as labels for the samples
170 | anndata_object.obs[labels_key] = cell_type_labels.squeeze().to_numpy()
171 |
172 | return anndata_object
173 |
174 |
175 | if __name__ == "__main__":
176 | main()
177 |
--------------------------------------------------------------------------------
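A typical run, from the data directory and after placing the 10x archive there; passing `-o` explicitly is advisable, since the script writes to exactly the name given:

```sh
cd mammal/examples/scrna_cell_type/data
python Zheng68k_to_anndata.py -o Zheng_68k.h5ad -v
```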
/mammal_mcp/README.md:
--------------------------------------------------------------------------------
1 | # Mammal MCP server - making mammal tasks accessible to AI Agents.
2 |
3 | A service that provides the `ibm/biomed.omics.bl.sm.ma-ted-458m` model tasks to AI Agents.
4 |
5 | ## Overview
6 |
7 | MAMMAL (`ibm/biomed.omics.bl.sm.ma-ted-458m`) is a biomedical foundation model (BMFM) trained by IBM; details can be found at https://github.com/BiomedSciAI/biomed-multi-alignment.
8 |
9 | This package provides a FastMCP server that exposes entrypoints for AI Agents to run inference on the tasks currently supported by MAMMAL.
10 |
11 | ## Getting started
12 |
13 | Change to the `mammal_mcp` directory:
14 |
15 | ```sh
16 | cd mammal_mcp
17 | ```
18 |
19 | Create the environment file:
20 |
21 | ```sh
22 | cp .env.example .env
23 | ```
24 |
25 | These environment variables control which modalities will be available to the agent.
26 |
27 |
28 | The first time you run the server, the models must be downloaded. In `.env`, set every task you plan to use to `true`, then run:
29 |
30 | ```sh
31 | uv run python -m server
32 | ```
33 |
34 | Then wait for all the models to be downloaded before quitting the server.
35 |
36 | ## Running the server using STDIO (default)
37 |
38 | ### Integration into Claude Desktop
39 |
40 | **If using Claude as your MCP client, DO NOT add any confidential or personal data to the system.**
41 |
42 | One of the easiest ways to experiment with the tools provided by mammal-mcp is to use Claude Desktop.
43 |
44 | For that, update your Claude Desktop config file (located at `~/Library/Application Support/Claude/claude_desktop_config.json`) with the JSON below:
45 |
46 | ```json
47 | {
48 | "mcpServers": {
49 | "mammal": {
50 | "command": "