├── mammal ├── __init__.py ├── examples │ ├── molnet │ │ ├── __init__.py │ │ └── molnet_infer.py │ ├── carcinogenicity │ │ ├── __init__.py │ │ ├── config.yaml │ │ ├── main_infer.py │ │ ├── pl_data_module.py │ │ └── task.py │ ├── scrna_cell_type │ │ ├── __init__.py │ │ ├── cell_type_mapping.csv │ │ ├── data │ │ │ ├── process_h5ad_data.py │ │ │ └── Zheng68k_to_anndata.py │ │ ├── anndata_op.py │ │ ├── config.yaml │ │ ├── scRNA_infer.py │ │ ├── pl_data_module.py │ │ └── task.py │ ├── dti_bindingdb_kd │ │ ├── __init__.py │ │ ├── config.yaml │ │ ├── main_infer.py │ │ ├── pl_data_module.py │ │ └── task.py │ ├── protein_solubility │ │ ├── __init__.py │ │ ├── config.yaml │ │ ├── main_infer.py │ │ ├── pl_data_module.py │ │ └── task.py │ ├── tests │ │ ├── test_molnet.py │ │ ├── test_tcr_epitope_binding_inference.py │ │ ├── test_simple_inference.py │ │ ├── test_protein_solubility_prediction.py │ │ ├── test_drug_carcinogenicity_classification.py │ │ ├── test_dti_bindingdb_kd.py │ │ └── test_main_finetune.py │ └── tcr_epitope_binding │ │ └── main_infer.py ├── lora.py ├── lr_schedulers.py ├── keys.py ├── main_finetune.py ├── losses.py ├── task.py └── metrics.py ├── mammal.png ├── mammal_preview_2.pdf ├── mammal_mcp ├── .env.example ├── pyproject.toml ├── util.py ├── .gitignore ├── dependencies.py └── README.md ├── CLONE.md ├── .github ├── ISSUE_TEMPLATE │ ├── ---question.md │ ├── ---bug-report.md │ └── ---feature-request.md └── workflows │ ├── github-actions.yml │ ├── python-publish.yml │ └── clone.yml ├── CONTRIBUTING.md ├── .pre-commit-config.yaml ├── .gitignore ├── pyproject.toml ├── tutorials └── begginer_inference.ipynb └── README.md /mammal/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mammal/examples/molnet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mammal/examples/carcinogenicity/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mammal/examples/scrna_cell_type/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mammal/examples/dti_bindingdb_kd/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mammal/examples/protein_solubility/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mammal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/biomed-multi-alignment/HEAD/mammal.png -------------------------------------------------------------------------------- /mammal_preview_2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/biomed-multi-alignment/HEAD/mammal_preview_2.pdf -------------------------------------------------------------------------------- /mammal_mcp/.env.example: -------------------------------------------------------------------------------- 1 | 
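# NOTE (assumption inferred from the variable names, not from mammal_mcp documentation):
# each task flag below appears to toggle whether the MCP server exposes the corresponding
# MAMMAL tool, while STREAMABLE_HTTP / SSE / PORT select how the server is served.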
PROTEIN_PROTEIN_INTERACTION=false 2 | PROTEIN_SOLUBILITY=false 3 | TCR_EPITOPE_BINDING=false 4 | DRUG_TARGET_BINDING=false 5 | DRUG_TARGET_BINDING_FASTA=true 6 | 7 | STREAMABLE_HTTP=false 8 | SSE=false 9 | PORT=8001 10 | -------------------------------------------------------------------------------- /mammal/examples/tests/test_molnet.py: -------------------------------------------------------------------------------- 1 | import socket 2 | 3 | import pytest 4 | 5 | from mammal.examples.molnet.molnet_infer import load_model, task_infer 6 | 7 | 8 | @pytest.mark.skipif( 9 | "ccc" not in socket.gethostname(), 10 | reason="Train consumes too much memory for a Travis run.", 11 | ) 12 | def test_infer(): 13 | smiles_seq = "C(Cl)Cl" 14 | for task_name in ["BBBP", "TOXICITY", "FDA_APPR"]: 15 | task_dict = load_model(task_name=task_name, device="cpu") 16 | result = task_infer(task_dict=task_dict, smiles_seq=smiles_seq) 17 | print(f"The prediction for {smiles_seq=} is {result}") 18 | -------------------------------------------------------------------------------- /CLONE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | **Markdown** 4 | 5 | ```markdown 6 | [![GitHub Clones](https://img.shields.io/badge/dynamic/json?color=success&label=Clone&query=count&url=https://gist.githubusercontent.com/mosheraboh/a19913f8cf752e05e84f0d09d997a403/raw/clone.json&logo=github)](https://github.com/MShawon/github-clone-count-badge) 7 | 8 | ``` 9 | 10 | **HTML** 11 | ```html 12 | GitHub Clones 13 | ``` 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/---question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: ❓ Question 3 | description: Use this template to ask a question about MAMMAL's code or its paper details. 4 | title: "[Q]: " 5 | labels: 6 | - "ty:question" 7 | 8 | body: 9 | 10 | - type: markdown 11 | attributes: 12 | value: > 13 | **Thanks :heart: for taking the time to fill out this bug report!** We kindly ask that you search to see if an 14 | issue [already exists](https://github.com/BiomedSciAI/biomed-multi-alignment/issues) for the bug you encountered. 15 | 16 | - type: textarea 17 | attributes: 18 | label: Ask your question 19 | value: | 20 | 21 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MAMMAL🦍 2 | 3 | First off, thanks for taking the time to contribute! 4 | 5 | ## How do I report a bug or suggest an enhancement? 6 | 7 | - As a first step, please search in the [existing issues](https://github.com/BiomedSciAI/biomed-multi-alignment/issues) to check if your point has already been addressed. 8 | - If that is not the case, go ahead and [create an issue](https://github.com/BiomedSciAI/biomed-multi-alignment/issues/new/choose) of the respective type, providing the details as instructed in the template. 9 | 10 | ## How do I submit a change? 11 | 12 | We welcome contributions via pull requests: 13 | 14 | - Fork (or clone) the repo and create your branch from the up-to-date default branch. 15 | - If you have added code that should be tested - add tests. 16 | - If any documentation updates are needed - make them. 17 | - Ensure the test suite passes and the code lints. 18 | - Submit the pull request. 
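
The pull-request CI (`.github/workflows/github-actions.yml`) runs `pre-commit run --all-files --show-diff-on-failure` and `python -m pytest --capture=no .`, so running those two commands locally is a quick way to confirm that the test suite passes and the code lints before you submit.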
19 | -------------------------------------------------------------------------------- /.github/workflows/github-actions.yml: -------------------------------------------------------------------------------- 1 | name: GitHub Actions - MAMMAL 2 | 3 | on: 4 | pull_request: 5 | branches: [ "main" ] 6 | 7 | jobs: 8 | build-linux: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | fail-fast: false 12 | matrix: 13 | python-version: ["3.10", "3.11"] 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install -q .[examples] 25 | pip install pre-commit 26 | 27 | - name: Pre-Commit Hooks 28 | run: | 29 | pre-commit install 30 | pre-commit run --all-files --show-diff-on-failure 31 | - name: Test with pytest 32 | run: | 33 | python -m pytest --capture=no . 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/---bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🐞 Bug Report 3 | description: Use this template to report a bug in MAMMAL. 4 | title: "[Bug]: " 5 | labels: 6 | - "a:app" 7 | - "ty:bug" 8 | 9 | body: 10 | - type: markdown 11 | attributes: 12 | value: > 13 | **Thanks :heart: for taking the time to fill out this bug report!** We kindly ask that you search to see if an 14 | issue [already exists](https://github.com/BiomedSciAI/biomed-multi-alignment/issues) for the bug you encountered. 15 | 16 | - type: textarea 17 | attributes: 18 | label: Describe the bug 19 | description: > 20 | A clear and concise description of the issue you're experiencing. If relevant, add screenshots or videos to help 21 | explain the bug, or any steps to reproduce the bug. If relevant describe the expected behavior. 
Please include 22 | the following information: 23 | 24 | - python version 25 | 26 | - torch version 27 | 28 | - os type and version 29 | value: | 30 | 31 | -------------------------------------------------------------------------------- /mammal/examples/tests/test_tcr_epitope_binding_inference.py: -------------------------------------------------------------------------------- 1 | from mammal.examples.tcr_epitope_binding.main_infer import load_model, task_infer 2 | 3 | 4 | def test_infer() -> None: 5 | """ 6 | A test for TCR beta chain and epitope binding example on HF, https://huggingface.co/ibm/biomed.omics.bl.sm.ma-ted-458m 7 | """ 8 | # positive 1: 9 | tcr_beta_seq = "NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT" 10 | epitope_seq = "LLQTGIHVRVSQPSL" 11 | 12 | # positive 2: 13 | # tcr_beta_seq = "GAVVSQHPSWVICKSGTSVKIECRSLDFQATTMFWYRQFPKQSLMLMATSNEGSKATYEQGVEKDKFLINHASLTLSTLTVTSAHPEDSSFYICSASEGTSSYEQYFGPGTRLTVT" 14 | # epitope_seq = "FLKEKGGL" 15 | 16 | model_inst, tokenizer_op = load_model(device="cpu") 17 | result = task_infer( 18 | model=model_inst, 19 | tokenizer_op=tokenizer_op, 20 | tcr_beta_seq=tcr_beta_seq, 21 | epitope_seq=epitope_seq, 22 | ) 23 | print(f"The prediction for {epitope_seq} and {tcr_beta_seq} is {result}") 24 | 25 | 26 | if __name__ == "__main__": 27 | test_infer() 28 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /mammal/examples/scrna_cell_type/cell_type_mapping.csv: -------------------------------------------------------------------------------- 1 | celltype,cell_type_ontology_term_name,cell_type_ontology_term_id,comments 2 | "CD14+ Monocyte","CD14-positive monocyte",CL:0001054, 3 | "CD19+ B","B cell",CL:0000236,"exact mapping is CL:0001201 but it doesn't exists in cellxgene so a parent celltype was chosen" 4 | "CD34+","CD34-positive, CD38-positive common myeloid progenitor OR CD34-positive, CD38-positive common lymphoid progenitor",CL:0000049,"exact mapping is CL:0000995 but it doesn't exists in cellxgene so a parent celltype was chosen" 5 | "CD4+/CD25 T Reg","CD4-positive, CD25-positive, alpha-beta regulatory T cell",CL:0000792, 6 | "CD4+/CD45RA+/CD25- Naive T","naive thymus-derived CD4-positive, alpha-beta T cell",CL:0000895, 7 | "CD4+/CD45RO+ Memory","CD4-positive, alpha-beta memory T cell, CD45RO-positive",CL:0001204, 8 | "CD4+ T Helper2","T-helper 2 cell",CL:0000546, 9 | "CD56+ NK","CD16-positive, CD56-dim natural killer cell, human",CL:0000939, 10 | "CD8+/CD45RA+ Naive Cytotoxic","Effector Memory Cd8-Positive, Alpha-Beta T Cell, Terminally Differentiated",CL:0001062,"mapped also to parentsm, may need review" 11 | "CD8+ Cytotoxic T","CD8-positive, alpha-beta cytotoxic T cell",CL:0000794, 12 | "Dendritic","dendritic cell",CL:0000451,"maybe we need the child cell type - dendtiric, human" 13 | -------------------------------------------------------------------------------- /mammal_mcp/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "mammal_mcp" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | authors = [ 6 | { name = "IBM Research" }, 7 | { name = "Ash Evans", email = "ash.evans@ibm.com" }, 8 | { name = "Jennifer Kelly", email = "jennifer.kelly@ibm.com" }, 9 | { name = "Laura-Jayne Gardiner", email = "laura-jayne.gardiner@ibm.com" } 10 | ] 11 | readme = "README.md" 12 | requires-python = ">=3.10, <3.12" 13 | dependencies = [ 14 | "fastmcp>=2.6.1", 15 | "fuse-med-ml>=0.4.0", 16 | "mcp[cli]>=1.6.0", 17 | "mypy>=1.17.1", 18 | "pre-commit>=4.3.0", 19 | "pydantic>=2.11.1", 20 | "pydantic-settings>=2.8.1", 21 | "pytdc>=0.4.1", 22 | "pytest>=8.3.5", 23 | "python-dotenv>=1.1.0", 24 | "requests==2.24.0", 25 | "types-requests>=2.31.0.6", 26 | "uv>=0.6.11", 27 | "bio>=1.7.1", 28 | "biomed-multi-alignment", 29 | ] 30 | 31 | [dependency-groups] 32 | dev = [ 33 | "types-requests>=2.31.0.6", 34 | ] 35 | 36 | [tool.uv.sources] 37 | biomed-multi-alignment = { git = "https://github.com/BiomedSciAI/biomed-multi-alignment.git", rev = "v0.2.2" } 38 | 39 | [project.optional-dependencies] 40 | dev = [ 41 | "types-requests", 42 | "mypy", 43 | "types-PyYAML", 44 | "types-decorator", 45 | 
"types-simplejson", 46 | "types-tabulate", 47 | ] 48 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/---feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🛠️ Feature Request 3 | description: Suggest an idea to help us improve MAMMAL 4 | title: "[Feature]: " 5 | labels: 6 | - "ty:feature" 7 | 8 | body: 9 | - type: markdown 10 | attributes: 11 | value: > 12 | **Thanks :heart: for taking the time to fill out this feature request report!** We kindly ask that you search to 13 | see if an issue [already exists](https://github.com/BiomedSciAI/biomed-multi-alignment/issues) for 14 | your feature. 15 | 16 | We are also happy to accept contributions from our users. For more details see 17 | [here](https://github.com/BiomedSciAI/biomed-multi-alignment/blob/main/CONTRIBUTING.md). 18 | 19 | - type: textarea 20 | attributes: 21 | label: Description 22 | description: | 23 | A clear and concise description of the feature you're interested in. 24 | value: | 25 | 26 | validations: 27 | required: true 28 | 29 | - type: textarea 30 | attributes: 31 | label: Suggested Solution 32 | description: > 33 | Describe the solution you'd like. A clear and concise description of what you want to happen. If you have 34 | considered alternatives, please describe them. 35 | value: | 36 | 37 | validations: 38 | required: false 39 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: .*\.pdb$ 2 | 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: check-case-conflict 8 | - id: end-of-file-fixer 9 | - id: mixed-line-ending 10 | - id: trailing-whitespace 11 | - repo: https://github.com/psf/black 12 | rev: 25.1.0 13 | hooks: 14 | - id: black 15 | - repo: https://github.com/PyCQA/flake8 16 | rev: 7.1.2 17 | hooks: 18 | - id: flake8 19 | args: 20 | - "--ignore=E203,E266,E501,F405,F403,W503" 21 | - "--statistics" 22 | 23 | - repo: https://github.com/astral-sh/ruff-pre-commit 24 | # Ruff version. 
25 | rev: v0.9.10 26 | hooks: 27 | - id: ruff 28 | args: 29 | - "--fix" 30 | - "--select" 31 | - "UP,PT,I,E"#,F,W,C90,I,N,F405,E402" # Specify the rules to select 32 | - "--line-length" 33 | - "88" 34 | - "--exit-non-zero-on-fix" 35 | - "--ignore" 36 | - "F405,F403,E501,E402,PT018,PT015,E722,E741" 37 | types_or: [ python, pyi] #, jupyter ] 38 | - repo: https://github.com/pre-commit/mirrors-mypy 39 | rev: v1.15.0 40 | hooks: 41 | - id: mypy 42 | additional_dependencies: [types-requests] 43 | 44 | - repo: https://github.com/srstevenson/nb-clean 45 | rev: "4.0.1" 46 | hooks: 47 | - id: nb-clean 48 | args: 49 | - --remove-empty-cells 50 | - --preserve-cell-outputs 51 | -------------------------------------------------------------------------------- /mammal/examples/scrna_cell_type/data/process_h5ad_data.py: -------------------------------------------------------------------------------- 1 | import anndata 2 | import click 3 | 4 | from mammal.examples.scrna_cell_type.pl_data_module import preprocess_ann_data 5 | 6 | 7 | @click.command() 8 | @click.option( 9 | "--input-h5ad-file", 10 | "-i", 11 | prompt=True, 12 | help="name of input H5AD file", 13 | ) 14 | @click.option( 15 | "--output-h5ad-file", 16 | "-o", 17 | prompt=True, 18 | help="name of output H5AD file", 19 | ) 20 | @click.option( 21 | "--min-genes", 22 | "-m", 23 | type=click.INT, 24 | help="minimal number of different genes per cell. Used for filtering", 25 | default=200, 26 | ) 27 | @click.option( 28 | "--num-bins", 29 | "-b", 30 | type=click.INT, 31 | help="number of expression bins to use", 32 | default=10, 33 | ) 34 | def main( 35 | input_h5ad_file: str, 36 | output_h5ad_file: str, 37 | min_genes: int = 200, 38 | num_bins: int = 10, 39 | ): 40 | 41 | anndata_object = anndata.read_h5ad(input_h5ad_file) 42 | # process the data - filter out cells with shallow reads, normelize depth and change to log scale of about 0-10 (log_2(1001)~=10) 43 | preprocess_ann_data( 44 | anndata_object=anndata_object, 45 | min_genes=min_genes, 46 | num_bins=num_bins, 47 | ) 48 | # Save result anndata object to disk 49 | anndata_object.write_h5ad(output_h5ad_file) 50 | print(f"processed AnnData file saved to {output_h5ad_file}") 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /mammal/examples/scrna_cell_type/anndata_op.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from anndata import AnnData 3 | from fuse.data import OpBase 4 | from fuse.utils.ndict import NDict 5 | 6 | from mammal.keys import SAMPLE_ID 7 | 8 | 9 | class OpReadAnnData(OpBase): 10 | """ 11 | Op reading data from anndata. 12 | Each row will be added as a value to sample dict. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | data: AnnData | None = None, 18 | key_name: str = SAMPLE_ID, 19 | label_column: str = "label", 20 | ): 21 | """ 22 | :param data: input AnnData object 23 | :param key_name: name of value in sample_dict which will be used as the key/index 24 | :param label_column: name of the column which contains the label 25 | """ 26 | super().__init__() 27 | 28 | self._key_name = key_name 29 | self._data = data 30 | self.label_column = label_column 31 | self.gene_names = np.array(self._data.var_names) 32 | 33 | def __call__( 34 | self, sample_dict: NDict, prefix: str | None = None 35 | ) -> None | dict | list[dict]: 36 | """ 37 | See base class 38 | 39 | :param prefix: specify a prefix for the sample dict keys. 
40 | For example, with prefix 'data.features' and a df with the columns ['height', 'weight', 'sex'], 41 | the matching keys will be: 'data.features.height', 'data.features.weight', 'data.features.sex'. 42 | """ 43 | 44 | key = sample_dict[self._key_name] 45 | 46 | # locate the required item 47 | sample_dict[f"{prefix}.scrna"] = self._data[key, :].X 48 | sample_dict["data.label"] = self._data.obs.iloc[key].get(self.label_column) 49 | sample_dict[f"{prefix}.gene_names"] = self.gene_names 50 | 51 | return sample_dict 52 | -------------------------------------------------------------------------------- /mammal_mcp/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 3 | 4 | 5 | def process_model_output( 6 | tokenizer_op: ModularTokenizerOp, 7 | decoder_output: np.ndarray, 8 | decoder_output_scores: np.ndarray, 9 | ) -> dict: 10 | """ 11 | Extract predicted solubility class and scores 12 | expecting decoder output to be <0> or <1> 13 | note - the normalized version will calculate the positive ('<1>') score divided by the sum of the scores for both '<0>' and '<1>' 14 | BE CAREFUL as both negative and positive absolute scores can be drastically low, and normalized score could be very high. 15 | outputs a dictionary containing: 16 | dict( 17 | predicted_token_str = #... e.g. '<1>' 18 | not_normalized_score = #the score for the positive token... e.g. 0.01 19 | normalized_score = #... (positive_token_score) / (positive_token_score+negative_token_score) 20 | ) 21 | if there is any error in parsing the model output, None is returned. 22 | """ 23 | 24 | negative_token_id = tokenizer_op.get_token_id("<0>") 25 | positive_token_id = tokenizer_op.get_token_id("<1>") 26 | label_id_to_int = { 27 | negative_token_id: 0, 28 | positive_token_id: 1, 29 | } 30 | classification_position = 1 31 | 32 | if decoder_output_scores is not None: 33 | not_normalized_score = decoder_output_scores[ 34 | classification_position, positive_token_id 35 | ] 36 | normalized_score = not_normalized_score / ( 37 | not_normalized_score 38 | + decoder_output_scores[classification_position, negative_token_id] 39 | + 1e-10 40 | ) 41 | ans = dict( 42 | pred=label_id_to_int.get(int(decoder_output[classification_position]), -1), 43 | not_normalized_scores=not_normalized_score, 44 | normalized_scores=normalized_score, 45 | ) 46 | 47 | return ans 48 | -------------------------------------------------------------------------------- /mammal/examples/tests/test_simple_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 3 | 4 | from mammal.keys import * 5 | from mammal.model import Mammal 6 | 7 | 8 | def test_simple_infer() -> None: 9 | """ 10 | A simple function that test the proposed inference example on HF, https://huggingface.co/ibm/biomed.omics.bl.sm.ma-ted-458m 11 | """ 12 | 13 | # Load Model 14 | model = Mammal.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m") 15 | model.eval() 16 | 17 | # Load Tokenizer 18 | tokenizer_op = ModularTokenizerOp.from_pretrained( 19 | "ibm/biomed.omics.bl.sm.ma-ted-458m" 20 | ) 21 | 22 | # Prepare Input Prompt 23 | protein_calmodulin = "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK" 24 | protein_calcineurin = 
"MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ" 25 | 26 | # Create and load sample 27 | sample_dict = dict() 28 | # Formatting prompt to match pre-training syntax 29 | sample_dict[ENCODER_INPUTS_STR] = ( 30 | f"<@TOKENIZER-TYPE=AA>{protein_calmodulin}{protein_calcineurin}" 31 | ) 32 | 33 | # Tokenize 34 | tokenizer_op( 35 | sample_dict=sample_dict, 36 | key_in=ENCODER_INPUTS_STR, 37 | key_out_tokens_ids=ENCODER_INPUTS_TOKENS, 38 | key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK, 39 | ) 40 | sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor( 41 | sample_dict[ENCODER_INPUTS_TOKENS] 42 | ) 43 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor( 44 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK] 45 | ) 46 | 47 | # Generate Prediction 48 | batch_dict = model.generate( 49 | [sample_dict], 50 | output_scores=True, 51 | return_dict_in_generate=True, 52 | max_new_tokens=5, 53 | ) 54 | 55 | # Get output 56 | generated_output = tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0]) 57 | print(f"{generated_output=}") 58 | -------------------------------------------------------------------------------- /mammal/examples/carcinogenicity/config.yaml: -------------------------------------------------------------------------------- 1 | name: carcinogenicity_finetune 2 | root: "." 3 | model_dir: ${root}/${name} 4 | seed: 4224 5 | 6 | task: 7 | _target_: mammal.examples.carcinogenicity.task.CarcinogenicityTask 8 | _partial_: True 9 | # details about the arguments below can be found on mammal.examples.carcinogenicity.task.CarcinogenicityTask() 10 | 11 | data_module_kwargs: 12 | # details about the arguments below can be found on mammal.examples.carcinogenicity.pl_data_module.CarcinogenicityDataModule() 13 | batch_size: 15 14 | drug_max_seq_length: 300 # Maximum drug length in the dataset is 292. 15 | encoder_input_max_seq_len: 320 # 20 chars buffer for special tokens. 
16 | labels_max_seq_len: 4 17 | 18 | 19 | # tokenizer 20 | tokenizer: 21 | tokenizer_path: ibm/biomed.omics.bl.sm.ma-ted-458m 22 | new_special_tokens: 23 | - "" 24 | 25 | model: 26 | mammal_kwargs: null # arguments for Mammal.__init__() 27 | pretrained_kwargs: # arguments for Mammal.from_pretrained() which triggered only if pretrained_model_name_or_path is not None 28 | pretrained_model_name_or_path: ibm/biomed.omics.bl.sm.ma-ted-458m 29 | 30 | # lightning module 31 | module: 32 | opt_callable: 33 | _target_: torch.optim.AdamW 34 | _partial_: true 35 | # arguments for torch.optim.AdamW() 36 | lr: 0.00001 37 | 38 | lr_sch_callable: 39 | _target_: mammal.lr_schedulers.cosine_annealing_with_warmup_lr_scheduler 40 | _partial_: True 41 | T_max: 10000 42 | num_warmup_steps: 300 43 | eta_min_factor: 0.1 44 | 45 | model_dir: ${model_dir} 46 | best_epoch_source: 47 | monitor: validation.metrics.carcinogenicity_acc # possible options are validation.metrics._ 48 | 49 | mode: max 50 | 51 | # train 52 | trainer: 53 | # arguments for pytorch_lightning.Trainer() 54 | max_epochs: 100 55 | default_root_dir: ${model_dir} 56 | accelerator: "auto" 57 | devices: 1 58 | num_nodes: 1 59 | num_sanity_val_steps: 0 60 | 61 | 62 | # experiment tracker 63 | track_clearml: # arguments for fuse.dl.lightning.pl_funcs.start_clearml_logger 64 | project_name: "mammal/opensource" 65 | task_name: ${name} 66 | tags: "mammal" 67 | reuse_last_task_id: True 68 | continue_last_task: False 69 | offline_mode: False 70 | 71 | evaluate : false #if true then it will use lightning's validate on the test dataloader 72 | 73 | hydra: 74 | run: 75 | dir: ${model_dir} 76 | job: 77 | chdir: False 78 | -------------------------------------------------------------------------------- /mammal/examples/scrna_cell_type/config.yaml: -------------------------------------------------------------------------------- 1 | name: cell_type_finetune 2 | root: "." 
3 | model_dir: ${root}/${name} 4 | seed: 2024 5 | 6 | task: 7 | _target_: mammal.examples.scrna_cell_type.task.CellTypeTask 8 | _partial_: True 9 | data_module_kwargs: 10 | data_path: "data/Zheng_68k_preprocessed.h5ad" # this should be absolute or relative to the directory with the example code 11 | # this is the name of the observation the model will try to predict 12 | label_name: "cell-type" 13 | batch_size: 20 14 | # tokenizer_op is provided later, dynamically 15 | train_dl_kwargs: # Dataloader constructor parameters 16 | num_workers: 8 17 | valid_dl_kwargs: # Dataloader constructor parameters 18 | num_workers: 8 19 | # data_preprocessing is provided later, dynamically 20 | input_max_seq_length: 500 21 | encoder_input_max_seq_len: 512 22 | labels_max_seq_len: 20 23 | 24 | # tokenizer 25 | tokenizer: 26 | tokenizer_path: ibm-research/biomed.omics.bl.sm.ma-ted-458m 27 | 28 | model: 29 | pretrained_kwargs: # arguments for Mammal.from_pretrained() which triggered only if pretrained_model_name_or_path is not None 30 | pretrained_model_name_or_path: ibm-research/biomed.omics.bl.sm.ma-ted-458m 31 | # config_overrides: 32 | # use_lora: True 33 | 34 | # lightning module 35 | module: 36 | opt_callable: 37 | _target_: torch.optim.AdamW 38 | _partial_: true # should get also parameters 39 | lr: 0.00001 40 | 41 | lr_sch_callable: 42 | _target_: mammal.lr_schedulers.cosine_annealing_with_warmup_lr_scheduler 43 | _partial_: True 44 | T_max: 10000 45 | num_warmup_steps: 300 46 | eta_min_factor: 0.1 47 | 48 | model_dir: ${model_dir} 49 | best_epoch_source: 50 | monitor: validation.metrics.cell_type_acc # see metrics.py:classification_metrics 51 | mode: max 52 | 53 | # train 54 | trainer: 55 | # arguments for pytorch_lightning.Trainer() 56 | max_epochs: 100 57 | default_root_dir: ${model_dir} 58 | num_sanity_val_steps: 0 59 | # val_check_interval: 0.1 60 | 61 | 62 | # experiment tracker 63 | track_clearml: # arguments for fuse.dl.lightning.pl_funcs.start_clearml_logger 64 | project_name: "mammal/opensource" 65 | task_name: ${name} 66 | tags: "mammal" 67 | reuse_last_task_id: True 68 | continue_last_task: False 69 | offline_mode: False 70 | 71 | evaluate : false #if true then it will use lightning's validate on the test dataloader 72 | 73 | hydra: 74 | run: 75 | dir: ${model_dir} 76 | job: 77 | chdir: False 78 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | t.py 131 | sol_test 132 | 133 | .vscode/settings.json 134 | .vscode/launch.json 135 | example_solubility_data 136 | -------------------------------------------------------------------------------- /mammal/examples/carcinogenicity/main_infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import click 4 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 5 | 6 | from mammal.examples.carcinogenicity.task import CarcinogenicityTask 7 | from mammal.keys import CLS_PRED, SCORES 8 | from mammal.model import Mammal 9 | 10 | 11 | @click.command() 12 | @click.argument("finetune_output_dir") 13 | @click.argument("drug_seq") 14 | @click.option( 15 | "--device", default="cpu", help="Specify the device to use (default: 'cpu')." 16 | ) 17 | def main(finetune_output_dir: str, drug_seq: str, device: str) -> None: 18 | click.echo(f"Using device: {device}") 19 | infer(finetune_output_dir=finetune_output_dir, drug_seq=drug_seq, device=device) 20 | 21 | 22 | def infer(finetune_output_dir: str, drug_seq: str, device: str) -> dict: 23 | """ 24 | :param finetune_output_dir: model_dir argument in fine-tuning 25 | :param drug_seq: smiles sequence of a drug 26 | """ 27 | # Load tokenizer from the checkpoint dir. 28 | # NOTE It's important to load the tokenizer from the fine-tuning phase, since we've introduced a new token to it. 29 | tokenizer_op = ModularTokenizerOp.from_pretrained( 30 | os.path.join(finetune_output_dir, "tokenizer") 31 | ) 32 | 33 | # Load model from the best checkpoint. 
34 | # NOTE The total order of the checkpoints is induced by the monitored metric (see config) 35 | nn_model = Mammal.from_pretrained( 36 | pretrained_model_name_or_path=os.path.join( 37 | finetune_output_dir, "best_epoch.ckpt" 38 | ) 39 | ) 40 | nn_model.eval() 41 | nn_model.to(device=device) 42 | 43 | # Format the input drug sequence value into a prompt that fits MAMMAL's training paradigm. 44 | sample_dict = {"drug_seq": drug_seq} 45 | sample_dict = CarcinogenicityTask.data_preprocessing( 46 | sample_dict=sample_dict, 47 | sequence_key="drug_seq", 48 | tokenizer_op=tokenizer_op, 49 | device=nn_model.device, 50 | ) 51 | 52 | # running in generate mode 53 | batch_dict = nn_model.generate( 54 | [sample_dict], 55 | output_scores=True, 56 | return_dict_in_generate=True, 57 | max_new_tokens=5, 58 | ) 59 | 60 | # Post-process the model's output 61 | ans = CarcinogenicityTask.process_model_output( 62 | tokenizer_op=tokenizer_op, 63 | decoder_output=batch_dict[CLS_PRED][0], 64 | decoder_output_scores=batch_dict[SCORES][0], 65 | ) 66 | 67 | # Print prediction 68 | print(ans) 69 | return ans 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /mammal/examples/protein_solubility/config.yaml: -------------------------------------------------------------------------------- 1 | name: mammal_solubility_finetune 2 | root: "." 3 | model_dir: ${root}/${name} 4 | seed: 1234 5 | 6 | task: 7 | _target_: mammal.examples.protein_solubility.task.ProteinSolubilityTask 8 | _partial_: True 9 | # details about the arguments below can be found on mammal.examples.protein_solubility.task.ProteinSolubilityTask() 10 | name: solubility_prediction 11 | seed: ${seed} 12 | data_module_kwargs: 13 | # details about the arguments below can be found on mammal.examples.protein_solubility.pl_data_module.ProteinSolubilityDataModule() 14 | data_path: ./example_solubility_data 15 | 16 | train_dl_kwargs: # Dataloader constructor parameters 17 | num_workers: 8 18 | valid_dl_kwargs: # Dataloader constructor parameters 19 | num_workers: 8 20 | 21 | batch_size: 6 22 | protein_max_seq_length: 1250 23 | encoder_input_max_seq_len: 1260 24 | labels_max_seq_len: 4 25 | 26 | # tokenizer 27 | tokenizer: 28 | tokenizer_path: ibm/biomed.omics.bl.sm.ma-ted-458m 29 | 30 | model: 31 | mammal_kwargs: null # arguments for Mammal.__init__() 32 | pretrained_kwargs: # arguments for Mammal.from_pretrained() which triggered only if pretrained_model_name_or_path is not None 33 | pretrained_model_name_or_path: ibm/biomed.omics.bl.sm.ma-ted-458m 34 | 35 | # lightning module 36 | module: 37 | opt_callable: 38 | _target_: torch.optim.AdamW 39 | _partial_: true 40 | # arguments for torch.optim.AdamW() 41 | lr: 0.00001 42 | 43 | lr_sch_callable: 44 | _target_: mammal.lr_schedulers.cosine_annealing_with_warmup_lr_scheduler 45 | _partial_: True 46 | # arguments for mammal.lr_schedulers.cosine_annealing_with_warmup_lr_scheduler() 47 | T_max: 20000 48 | eta_min_factor: 0.1 49 | 50 | model_dir: ${model_dir} 51 | best_epoch_source: 52 | # arguments for pytorch_lightning.callbacks.ModelCheckpoint() 53 | monitor: validation.metrics.solubility_prediction_acc # possible options are validation.metrics._ 54 | 55 | mode: max 56 | 57 | # train 58 | trainer: 59 | # arguments for pytorch_lightning.Trainer() 60 | max_epochs: 1000 61 | default_root_dir: ${model_dir} 62 | accelerator: "auto" 63 | devices: 1 64 | num_nodes: 1 65 | strategy: "ddp_find_unused_parameters_true" 66 | 
use_distributed_sampler: False # Must be set when using a batch sampler 67 | num_sanity_val_steps: 0 68 | # limit_train_batches: 128 69 | # limit_val_batches: 128 70 | 71 | # experiment tracker 72 | track_clearml: # arguments for fuse.dl.lightning.pl_funcs.start_clearml_logger 73 | project_name: "mammal/opensource" 74 | task_name: ${name} 75 | tags: "mammal" 76 | reuse_last_task_id: True 77 | continue_last_task: False 78 | offline_mode: False 79 | 80 | evaluate : false #if true then it will use lightning's validate on the test dataloader 81 | 82 | hydra: 83 | run: 84 | dir: ${model_dir} 85 | job: 86 | chdir: False 87 | -------------------------------------------------------------------------------- /mammal/examples/protein_solubility/main_infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import click 4 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 5 | 6 | from mammal.examples.protein_solubility.task import ProteinSolubilityTask 7 | from mammal.keys import CLS_PRED, SCORES 8 | from mammal.model import Mammal 9 | 10 | 11 | @click.command() 12 | @click.option( 13 | "--finetune_output_dir", 14 | default=None, 15 | help="Specify the model dir (default: None to download from huggingface).", 16 | ) 17 | @click.argument( 18 | "protein_seq", 19 | default="NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC", 20 | ) 21 | @click.option( 22 | "--device", default="cpu", help="Specify the device to use (default: 'cpu')." 23 | ) 24 | def main(finetune_output_dir: str, protein_seq: str, device: str): 25 | protein_solubility_infer( 26 | finetune_output_dir=finetune_output_dir, protein_seq=protein_seq, device=device 27 | ) 28 | 29 | 30 | def protein_solubility_infer( 31 | finetune_output_dir: str | None, protein_seq: str, device: str 32 | ): 33 | """ 34 | :param finetune_output_dir: model_dir argument in finetuning or None for downloading from huggingface 35 | :param protein_seq: amino acid sequence of a protein 36 | """ 37 | if finetune_output_dir is not None: 38 | # load tokenizer and model from finetune_output_dir 39 | tokenizer_op = ModularTokenizerOp.from_pretrained( 40 | os.path.join(finetune_output_dir, "tokenizer") 41 | ) 42 | nn_model = Mammal.from_pretrained( 43 | pretrained_model_name_or_path=os.path.join( 44 | finetune_output_dir, "best_epoch.ckpt" 45 | ) 46 | ) 47 | else: 48 | # load tokenizer and model from huggingface 49 | tokenizer_op = ModularTokenizerOp.from_pretrained( 50 | "ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility" 51 | ) 52 | nn_model = Mammal.from_pretrained( 53 | "ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility" 54 | ) 55 | nn_model.eval() 56 | nn_model.to(device=device) 57 | 58 | # convert to MAMMAL style 59 | sample_dict = {"protein_seq": protein_seq} 60 | sample_dict = ProteinSolubilityTask.data_preprocessing( 61 | sample_dict=sample_dict, 62 | protein_sequence_key="protein_seq", 63 | tokenizer_op=tokenizer_op, 64 | device=nn_model.device, 65 | ) 66 | 67 | # running in generate mode 68 | batch_dict = nn_model.generate( 69 | [sample_dict], 70 | output_scores=True, 71 | return_dict_in_generate=True, 72 | max_new_tokens=5, 73 | ) 74 | 75 | # Post-process the model's output 76 | ans = ProteinSolubilityTask.process_model_output( 77 | tokenizer_op=tokenizer_op, 78 | decoder_output=batch_dict[CLS_PRED][0], 79 | decoder_output_scores=batch_dict[SCORES][0], 80 | ) 81 | 82 | # Print prediction 83 | print(ans) 84 | return ans 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | 
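# A minimal programmatic usage sketch (for illustration only: it simply calls the function defined
# above with finetune_output_dir=None, which makes it pull the public
# ibm/biomed.omics.bl.sm.ma-ted-458m.protein_solubility checkpoint from Hugging Face):
#
#     from mammal.examples.protein_solubility.main_infer import protein_solubility_infer
#
#     ans = protein_solubility_infer(
#         finetune_output_dir=None,
#         protein_seq="NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC",
#         device="cpu",
#     )
#     # ans mirrors ProteinSolubilityTask.process_model_output (predicted class token and its scores)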
-------------------------------------------------------------------------------- /mammal/examples/dti_bindingdb_kd/config.yaml: -------------------------------------------------------------------------------- 1 | name: mammal_tdc_dti_bindingdb_kd 2 | root: "." 3 | model_dir: ${root}/${name} 4 | seed: 1234 5 | 6 | 7 | task: 8 | _target_: mammal.examples.dti_bindingdb_kd.task.DtiBindingdbKdTask 9 | _partial_: True 10 | # details about the arguments below can be found on mammal.examples.dti_bindingdb_kd.task.DtiBindingdbKdTask() 11 | name: dti_bindingdb_kd 12 | seed: ${seed} 13 | norm_y_mean: 5.79384684128215 14 | norm_y_std: 1.33808027428196 15 | 16 | data_module_kwargs: 17 | # details about the arguments below can be found on mammal.examples.dti_bindingdb_kd.pl_data_module.DtiBindingdbKdDataModule() 18 | load_datasets_kwargs: 19 | split_type: "cold_split" 20 | split_column: ["Drug", "Target"] 21 | train_dl_kwargs: # Dataloader constructor parameters 22 | num_workers: 8 23 | valid_dl_kwargs: # Dataloader constructor parameters 24 | num_workers: 8 25 | 26 | batch_size: 8 # over a100_80g 27 | target_max_seq_length: 1250 28 | drug_max_seq_length: 256 29 | encoder_input_max_seq_len: 1560 30 | 31 | 32 | # tokenizer 33 | tokenizer: 34 | tokenizer_path: ibm/biomed.omics.bl.sm.ma-ted-458m 35 | 36 | model: 37 | mammal_kwargs: null # arguments for Mammal.__init__() 38 | pretrained_kwargs: # arguments for Mammal.from_pretrained() which triggered only if pretrained_model_name_or_path is not None 39 | pretrained_model_name_or_path: ibm/biomed.omics.bl.sm.ma-ted-458m 40 | 41 | # lightning module 42 | module: 43 | opt_callable: 44 | _target_: torch.optim.AdamW 45 | _partial_: true 46 | # arguments for torch.optim.AdamW() 47 | lr: 0.00001 48 | 49 | lr_sch_callable: 50 | _target_: mammal.lr_schedulers.cosine_annealing_with_warmup_lr_scheduler 51 | _partial_: True 52 | # arguments for mammal.lr_schedulers.cosine_annealing_with_warmup_lr_scheduler() 53 | T_max: 100000 54 | eta_min_factor: 0.1 55 | 56 | model_dir: ${model_dir} 57 | best_epoch_source: 58 | # arguments for pytorch_lightning.callbacks.ModelCheckpoint() 59 | monitor: validation.losses.dti_bindingdb_kd_scalars_mse # possible options are validation.metrics._ 60 | 61 | mode: min 62 | 63 | # train 64 | trainer: 65 | # arguments for pytorch_lightning.Trainer() 66 | max_epochs: 1000 67 | default_root_dir: ${model_dir} 68 | accelerator: "auto" 69 | devices: 1 70 | num_nodes: 1 71 | strategy: "ddp_find_unused_parameters_true" 72 | use_distributed_sampler: False # Must be set when using a batch sampler 73 | num_sanity_val_steps: 0 74 | # limit_train_batches: 128 75 | # limit_val_batches: 128 76 | 77 | # experiment tracker 78 | track_clearml: # arguments for fuse.dl.lightning.pl_funcs.start_clearml_logger 79 | project_name: "mammal/opensource" 80 | task_name: ${name} 81 | tags: "mammal" 82 | reuse_last_task_id: True 83 | continue_last_task: False 84 | offline_mode: False 85 | 86 | evaluate : false #if true then it will use lightning's validate on the test dataloader 87 | 88 | hydra: 89 | run: 90 | dir: ${model_dir} 91 | job: 92 | chdir: False 93 | -------------------------------------------------------------------------------- /mammal/examples/tests/test_protein_solubility_prediction.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from pathlib import Path 3 | 4 | import hydra 5 | import pytest 6 | from hydra.core.global_hydra import GlobalHydra 7 | 8 | from 
mammal.examples.protein_solubility.main_infer import protein_solubility_infer 9 | from mammal.main_finetune import main as main_finetune 10 | 11 | TEST_CONFIG_DIRPATH = str(Path(__file__).parents[0] / "../protein_solubility") 12 | TEST_CONFIG_FILENAME = "config.yaml" 13 | 14 | 15 | @pytest.fixture(autouse=True, scope="session") 16 | def _clean_hydra() -> None: 17 | GlobalHydra.instance().clear() 18 | 19 | 20 | @pytest.fixture(scope="session") 21 | def model_dir(tmp_path_factory: pytest.TempPathFactory): 22 | if "ccc" not in socket.gethostname(): 23 | pytest.skip("Full tests requires resources") 24 | 25 | model_dir_path = tmp_path_factory.mktemp("test_protein_solubility") / "test" 26 | return model_dir_path 27 | 28 | 29 | def test_finetune(model_dir: str): 30 | print(model_dir) 31 | OVERRIDES = [ 32 | "track_clearml=null", # Travis cannot connect to ClearML at the moment. We might be able to fix it with a dedicated user + config credentials. 33 | "trainer.max_epochs=2", # Small number for a faster run. 34 | "+trainer.limit_train_batches=3", # Small number for a faster run. 35 | "+trainer.limit_val_batches=2", # Small number for a faster run. 36 | f"model_dir={model_dir}", 37 | "task.data_module_kwargs.train_dl_kwargs.num_workers=0", # Using parallelization cause co 38 | "task.data_module_kwargs.valid_dl_kwargs.num_workers=0", # Using parallelization cause co 39 | "root=.", 40 | "name=sol_test", 41 | ] 42 | with hydra.initialize_config_dir(TEST_CONFIG_DIRPATH, version_base="1.1"): 43 | _cfg = hydra.compose(TEST_CONFIG_FILENAME, overrides=OVERRIDES) 44 | cfg = hydra.utils.instantiate(_cfg) 45 | main_finetune(cfg) 46 | 47 | 48 | def test_evaluate(model_dir: str): 49 | OVERRIDES = [ 50 | "track_clearml=null", # Travis cannot connect to ClearML at the moment. We might be able to fix it with a dedicated user + config credentials. 51 | "trainer.max_epochs=1", # Small number for a faster run. 52 | "+trainer.limit_test_batches=10", # Small number for a faster run. 
53 | f"model_dir={model_dir}", 54 | "task.data_module_kwargs.train_dl_kwargs.num_workers=0", # Using parallelization cause co 55 | "task.data_module_kwargs.valid_dl_kwargs.num_workers=0", # Using parallelization cause co 56 | "root=.", 57 | "name=sol_test", 58 | "evaluate=True", 59 | f"model.pretrained_kwargs.pretrained_model_name_or_path={model_dir}/best_epoch.ckpt", 60 | ] 61 | with hydra.initialize_config_dir(TEST_CONFIG_DIRPATH, version_base="1.1"): 62 | _cfg = hydra.compose(TEST_CONFIG_FILENAME, overrides=OVERRIDES) 63 | cfg = hydra.utils.instantiate(_cfg) 64 | main_finetune(cfg) 65 | 66 | 67 | def test_infer(model_dir: str): 68 | protein_solubility_infer( 69 | finetune_output_dir=model_dir, 70 | protein_seq="NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC", 71 | device="cpu", 72 | ) 73 | -------------------------------------------------------------------------------- /mammal/examples/tests/test_drug_carcinogenicity_classification.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from pathlib import Path 3 | 4 | import hydra 5 | import pytest 6 | from hydra.core.global_hydra import GlobalHydra 7 | 8 | from mammal.examples.carcinogenicity.main_infer import infer 9 | from mammal.main_finetune import main as main_finetune 10 | from mammal.model import Mammal 11 | 12 | TEST_CONFIG_DIRPATH = str(Path(__file__).parents[0] / "../carcinogenicity") 13 | TEST_CONFIG_FILENAME = "config.yaml" 14 | 15 | 16 | @pytest.fixture(autouse=True, scope="session") 17 | def _clean_hydra() -> None: 18 | GlobalHydra.instance().clear() 19 | 20 | 21 | @pytest.fixture(scope="session") 22 | def tmp_model_dir(tmp_path_factory): 23 | model_dir_path = tmp_path_factory.mktemp("test_carcinogenicity") / "test" 24 | return model_dir_path 25 | 26 | 27 | @pytest.fixture(scope="session") 28 | def finetuned_model_dir(tmp_model_dir: str): 29 | if "ccc" not in socket.gethostname(): 30 | pytest.skip("Full tests requires resources") 31 | model_dir = tmp_model_dir 32 | print(f"\n{model_dir=}") 33 | OVERRIDES = [ 34 | "track_clearml=null", # Travis cannot connect to ClearML at the moment. We might be able to fix it with a dedicated user + config credentials. 35 | "trainer.max_epochs=2", # Small number for a faster run. 36 | "+trainer.limit_train_batches=3", # Small number for a faster run. 37 | "+trainer.limit_val_batches=2", # Small number for a faster run. 38 | f"model_dir={model_dir}", 39 | "root=.", 40 | "name=carcinogenicity_test", 41 | ] 42 | with hydra.initialize_config_dir(TEST_CONFIG_DIRPATH, version_base="1.1"): 43 | _cfg = hydra.compose(TEST_CONFIG_FILENAME, overrides=OVERRIDES) 44 | main_finetune(_cfg) 45 | return model_dir 46 | 47 | 48 | def test_finetune(finetuned_model_dir: Path): 49 | # the actual work is done in the fixture, here we just read the finetuned model. 50 | model = Mammal.from_pretrained( 51 | pretrained_model_name_or_path=str(finetuned_model_dir / "best_epoch.ckpt") 52 | ) 53 | assert model is not None 54 | 55 | 56 | def test_evaluate(finetuned_model_dir: str): 57 | OVERRIDES = [ 58 | "track_clearml=null", # Travis cannot connect to ClearML at the moment. We might be able to fix it with a dedicated user + config credentials. 59 | "trainer.max_epochs=1", # Small number for a faster run. 60 | "+trainer.limit_test_batches=10", # Small number for a faster run. 
61 | f"model_dir={finetuned_model_dir}", 62 | "root=.", 63 | "name=carcinogenicity_test", 64 | "evaluate=True", 65 | f"model.pretrained_kwargs.pretrained_model_name_or_path={finetuned_model_dir}/best_epoch.ckpt", 66 | ] 67 | with hydra.initialize_config_dir(TEST_CONFIG_DIRPATH, version_base="1.1"): 68 | _cfg = hydra.compose(TEST_CONFIG_FILENAME, overrides=OVERRIDES) 69 | # main_finetune does "evaluate" if evalute=True, as is set above. 70 | main_finetune(_cfg) 71 | 72 | 73 | def test_infer(finetuned_model_dir: str): 74 | infer( 75 | finetune_output_dir=finetuned_model_dir, 76 | drug_seq="CC(CCl)OC(C)CCl", 77 | device="cpu", 78 | ) 79 | -------------------------------------------------------------------------------- /mammal/lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from peft import LoraConfig, TaskType, get_peft_model 3 | from peft.utils import PeftType 4 | from transformers import PreTrainedModel 5 | 6 | 7 | def get_lora_model( 8 | model: PreTrainedModel, 9 | peft_type: str | PeftType | None = None, 10 | auto_mapping: dict | None = None, 11 | base_model_name_or_path: str | None = None, 12 | revision: str | None = None, 13 | task_type: str | TaskType | None = None, 14 | inference_mode: bool = False, 15 | r: int = 8, 16 | target_modules: list[str] | str | None = None, 17 | lora_alpha: int = 8, 18 | lora_dropout: float = 0, 19 | fan_in_fan_out: bool = False, 20 | bias: str = "none", 21 | modules_to_save: list[str] | None = None, 22 | init_lora_weights: bool = True, 23 | layers_to_transform: list[int] | None = None, 24 | layers_pattern: str | None = None, 25 | ) -> torch.nn.Module: 26 | """ 27 | Freeze model params and make lora config. Then convert model to lora. 28 | 29 | Args: 30 | model (`PreTrainedModel`): The model to convert to Lora. 31 | r (`int`): Lora attention dimension. 32 | target_modules (`Union[List[str],str]`): The names of the modules to apply Lora to. 33 | lora_alpha (`int`): The alpha parameter for Lora scaling. 34 | lora_dropout (`float`): The dropout probability for Lora layers. 35 | fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out). 36 | For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.: 37 | bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only' 38 | modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable 39 | and saved in the final checkpoint. 40 | layers_to_transform (`Union[List[int],int]`): 41 | The layer indexes to transform, if this argument is specified, it will apply the LoRA transformations on 42 | the layer indexes that are specified in this list. If a single integer is passed, it will apply the LoRA 43 | transformations on the layer at this index. 44 | layers_pattern (`str`): 45 | The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer 46 | pattern is not in the common layers pattern. 
47 | """ 48 | 49 | # build lora config 50 | config = LoraConfig( 51 | peft_type=peft_type, 52 | auto_mapping=auto_mapping, 53 | base_model_name_or_path=base_model_name_or_path, 54 | revision=revision, 55 | task_type=task_type, 56 | inference_mode=inference_mode, 57 | r=r, 58 | target_modules=target_modules, 59 | lora_alpha=lora_alpha, 60 | lora_dropout=lora_dropout, 61 | fan_in_fan_out=fan_in_fan_out, 62 | bias=bias, 63 | modules_to_save=modules_to_save, 64 | init_lora_weights=init_lora_weights, 65 | layers_to_transform=layers_to_transform, 66 | layers_pattern=layers_pattern, 67 | ) 68 | 69 | model = get_peft_model(model, config) 70 | model.print_trainable_parameters() 71 | return model 72 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # configuration approach followed: 2 | # - whenever possible, prefer pyproject.toml 3 | # - for configurations insufficiently supported by pyproject.toml, use setup.cfg instead 4 | # - setup.py discouraged; minimal stub included only for compatibility with legacy tools 5 | 6 | 7 | [build-system] 8 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] 9 | build-backend = "setuptools.build_meta" 10 | 11 | [project] 12 | name = "biomed-multi-alignment" 13 | description = "MAMMAL (Molecular Aligned Multi-Modal Architecture and Language), a flexible, multi-domain architecture with an adaptable task prompt syntax." 14 | authors = [ 15 | {name="IBM Research"}, 16 | {name="Moshe Raboh", email="moshiko.raboh@ibm.com"} 17 | ] 18 | version = "0.2.1" 19 | readme = "README.md" 20 | license = {file = "LICENSE.txt"} 21 | # due to how PEP 440 defines version matching, prefer [incl, excl) definitions like below: 22 | requires-python = ">=3.10, <3.13" 23 | dependencies = [ 24 | "fuse-med-ml==0.4.0", 25 | "tensorflow>=2.17", 26 | "peft", 27 | "tabulate", 28 | "clearml", 29 | "hydra-core", 30 | "pytest", 31 | ] 32 | 33 | [project.optional-dependencies] 34 | examples = [ 35 | "PyTDC", 36 | "anndata", 37 | "click", 38 | ] 39 | 40 | [project.urls] 41 | repository = "https://github.com/BiomedSciAI/biomed-multi-alignment" 42 | 43 | [tool.setuptools.packages] 44 | find = {} 45 | 46 | [tool.ruff] 47 | target-version = "py310" 48 | extend-include = ["*.ipynb"] 49 | 50 | # Activate all the rules that are pyupgrade-related 51 | lint.select = [ 52 | # "UP", # pyupgrade 53 | "D", # pydocstyle 54 | "PT", # pytest style checking 55 | "C4", # comprehensions style checking 56 | "PD", # pandas style checking 57 | "F", # pyflakes: is-literal 58 | "W605", # pycodestyle: invalid-escape-sequence 59 | "I", # isort 60 | ] 61 | # On top of the Google convention, disable `D417`, which requires 62 | # documentation for every function parameter. 
63 | lint.ignore = [ 64 | "D100", # pydocstyle: Missing module docstring 65 | "D101", # pydocstyle: Missing module-level docstring 66 | "D102", # pydocstyle: Missing docstring in public module 67 | "D103", # pydocstyle: Missing class docstring 68 | "D105", # pydocstyle: Missing docstring in magic method 69 | "D107", # pydocstyle: Missing parameter descriptions in the docstring 70 | "D203", # pydocstyle: 1 blank line required before class docstring 71 | "D205", # pydocstyle: 1 blank line required between summary line and description 72 | "D212", # pydocstyle: Multi-line docstring summary should start at the first line 73 | "D401", # pydocstyle: First line should be in imperative mood 74 | "D417", # pydocstyle: Missing argument descriptions in the docstring 75 | "F841", # flake8: unused variable 76 | "PD011", # pandas do not use .values (false positives causing bugs in torch code) 77 | "PD015", # Use .merge method instead of pd.merge function. They have equivalent functionality. 78 | "PT011", #TODO remove 79 | "UP035", # TODO types. remove 80 | ] 81 | [lint.per-file-ignores] 82 | "__init__.py" = ["I001"] 83 | 84 | 85 | [tool.coverage.report] 86 | 87 | exclude_lines = ["pragma: no cover", "abc.abstractmethod", "@abstract"] 88 | 89 | #[tool.coverage.run] 90 | #omit = ["gene_benchmark/tests/*"] 91 | 92 | [tool.mypy] 93 | disable_error_code = [ 94 | "index", 95 | "override", 96 | "arg-type", 97 | "union-attr" 98 | ] 99 | -------------------------------------------------------------------------------- /mammal/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | from torch.optim.lr_scheduler import ConstantLR, LinearLR, MultiStepLR, SequentialLR 4 | from transformers import get_inverse_sqrt_schedule 5 | 6 | 7 | def inverse_sqrt_with_warmup_lr_scheduler( 8 | optimizer: Optimizer, 9 | num_warmup_steps: int = 10000, 10 | timescale: int = 1, 11 | last_epoch: int = -1, 12 | ) -> SequentialLR: 13 | """ 14 | T5 learning rate scheduler with the original hyper parameter (not necessarily fits out use case) 15 | 16 | ''' 17 | During pre-training, we use an “inverse square root” learning rate schedule: 1/sqrt(max(n, k)) 18 | where n is the current training iteration and k is the number of warm-up steps (set to 10^4 in all of our experiments). 19 | This sets a constant learning rate of 0.01 for the first 104 steps, then exponentially decays the learning rate until pre-training is over. 
20 |     '''
21 |     """
22 |     constant_sch = ConstantLR(optimizer, factor=1.0, total_iters=num_warmup_steps)
23 |     inverse_sqrt_schedule = get_inverse_sqrt_schedule(
24 |         optimizer=optimizer,
25 |         num_warmup_steps=0,
26 |         timescale=timescale,
27 |         last_epoch=last_epoch,
28 |     )
29 |     return SequentialLR(
30 |         optimizer,
31 |         schedulers=[constant_sch, inverse_sqrt_schedule],
32 |         milestones=[num_warmup_steps],
33 |     )
34 |
35 |
36 | def cosine_annealing_with_warmup_lr_scheduler(
37 |     optimizer: Optimizer,
38 |     *,
39 |     num_warmup_steps: int = 2000,
40 |     start_factor: float = 0.3333333333333333,
41 |     T_max: int = 40000,
42 |     eta_min_factor: float = 0.1,
43 | ) -> SequentialLR:
44 |     """
45 |     Cosine annealing with warmup, followed by a constant learning rate.
46 |     num_warmup_steps: warmup period - the number of steps until the learning rate reaches its maximum
47 |     start_factor: the factor to start with in the warmup period
48 |     T_max: the number of steps until the cosine reaches its minimum
49 |     eta_min_factor: minimum learning-rate factor of the cosine scheduler
50 |     """
51 |     linear_sch = LinearLR(
52 |         optimizer, start_factor=start_factor, total_iters=num_warmup_steps
53 |     )
54 |     assert (
55 |         len(optimizer.param_groups) == 1
56 |     ), f"this learning rate scheduler supports a single param group, got {optimizer.param_groups=}"
57 |
58 |     initial_lr = [group["initial_lr"] for group in optimizer.param_groups][0]
59 |     eta_min = eta_min_factor * initial_lr
60 |     cosine_lr_sch = torch.optim.lr_scheduler.CosineAnnealingLR(
61 |         optimizer, T_max=T_max, eta_min=eta_min
62 |     )
63 |     multi_step_lr_sch = MultiStepLR(
64 |         optimizer=optimizer, milestones=[0], gamma=eta_min_factor
65 |     )
66 |     return SequentialLR(
67 |         optimizer,
68 |         schedulers=[linear_sch, cosine_lr_sch, multi_step_lr_sch],
69 |         milestones=[num_warmup_steps, T_max],
70 |     )
71 |
72 |
73 | def multistep_with_warmup_lr_scheduler(
74 |     optimizer: Optimizer,
75 |     *,
76 |     num_warmup_steps: int = 2000,
77 |     start_factor: float = 0.3333333333333333,
78 |     milestones: list[int],
79 |     gamma: float = 0.1,
80 | ) -> SequentialLR:
81 |     """
82 |     Multistep LR with warmup - used in fine-tuning.
83 |     """
84 |     linear_sch = LinearLR(
85 |         optimizer, start_factor=start_factor, total_iters=num_warmup_steps
86 |     )
87 |     multi_step_lr_sch = MultiStepLR(
88 |         optimizer=optimizer,
89 |         milestones=[m - num_warmup_steps for m in milestones],
90 |         gamma=gamma,
91 |     )
92 |     return SequentialLR(
93 |         optimizer,
94 |         schedulers=[linear_sch, multi_step_lr_sch],
95 |         milestones=[num_warmup_steps],
96 |     )
97 |
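# Example (a minimal usage sketch relying on the module-level `torch` import; the tiny
# module, learning rate, warmup length and milestones below are made-up illustration
# values, not MAMMAL's shipped defaults):
if __name__ == "__main__":
    _model = torch.nn.Linear(8, 2)
    _optimizer = torch.optim.AdamW(_model.parameters(), lr=1e-4)
    _lr_scheduler = multistep_with_warmup_lr_scheduler(
        _optimizer,
        num_warmup_steps=100,
        milestones=[1000, 2000],
        gamma=0.1,
    )
    for _step in range(5):
        _optimizer.step()  # in a real loop this follows loss.backward()
        _lr_scheduler.step()  # one scheduler step per optimizer step
    print(_lr_scheduler.get_last_lr())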
--------------------------------------------------------------------------------
/mammal/keys.py:
--------------------------------------------------------------------------------
1 | """
2 | List the keys available for each task / expected from each task in this t5 multitask generic implementation
3 | """
4 |
5 | # DATA
6 | SAMPLE_ID = "data.sample_id"
7 |
8 | # expected outputs for each task - data pipeline
9 | ENCODER_INPUTS_TOKENS = "data.encoder_input_token_ids"  # encoder input token ids
10 | ENCODER_INPUTS_STR = "data.query.encoder_input"  # the original string representation of the encoder input - used for debug
11 | ENCODER_INPUTS_ATTENTION_MASK = "data.encoder_input_attention_mask"  # attention mask of the tokenized encoder input (output of the tokenizer)
12 | DECODER_INPUTS_STR = "data.query.decoder_input"  # the original string representation of the decoder input - used for debug
13 | DECODER_INPUTS_TOKENS = "data.decoder_input_token_ids"  # decoder input token ids (decoder start token followed by labels token ids)
14 | DECODER_INPUTS_ATTENTION_MASK = "data.decoder_input_attention_mask"  # attention mask of the tokenized decoder input (output of the tokenizer)
15 | LABELS_TOKENS = "data.labels_token_ids"  # labels token ids
16 | LABELS_ATTENTION_MASK = "data.labels_attention_mask"  # attention mask of the tokenized labels (output of the tokenizer)
17 | LABELS_STR = (
18 |     "data.query.labels"  # the original string representation of the labels - used for debug
19 | )
20 |
21 | # adding custom embeddings
22 | ENCODER_INPUT_ADD_EMBEDDINGS = "data.encoder_input_add_embeddings"  # optional, can be used to add custom embeddings (in addition to token_ids)
23 |
24 |
25 | # the list of keys each task (data_module) must add to sample_dict for training in encoder-only mode
26 | DATA_KEYS_ENCODER = [
27 |     ENCODER_INPUTS_TOKENS,
28 |     ENCODER_INPUTS_ATTENTION_MASK,
29 |     ENCODER_INPUTS_STR,
30 |     LABELS_TOKENS,
31 |     LABELS_ATTENTION_MASK,
32 |     LABELS_STR,
33 |     SAMPLE_ID,
34 | ]
35 |
36 | # the list of keys each task (data_module) must add to sample_dict for training in encoder-decoder mode
37 | DATA_KEYS = DATA_KEYS_ENCODER + [DECODER_INPUTS_TOKENS, DECODER_INPUTS_ATTENTION_MASK]
38 |
39 | # MODEL
40 | # expected model outputs
41 | LOGITS = "model.out.logits"
42 | SCORES = "model.out.scores"  # model output after softmax
43 | CLS_PRED = "model.out.cls_pred"  # argmax() over the logits - the teacher-forcing prediction
44 | CE_LOSS = "model.out.loss"  # cross-entropy loss calculated in the t5 model
45 | ENCODER_LAST_HIDDEN_STATE = "model.out.encoder_last_hidden_state"  # last hidden state of the encoder
46 |
47 | ########################################
48 | #### related to scalars inputs/outputs
49 | ########################################
50 |
51 |
52 | # logits of the scalars output prediction head - a single scalar is predicted per input element
53 | # active only in encoder-only mode!
54 | SCALARS_PREDICTION_HEAD_LOGITS = "model.out.scalars_prediction_logits"
55 |
56 | ENCODER_INPUTS_SCALARS = "data.encoder_input.scalars"
57 | # a float tensor with the values of the scalars. A default value is used for elements that are not scalars. Its length is the number of input elements
58 | ENCODER_INPUTS_SCALARS_VALUES = ENCODER_INPUTS_SCALARS + ".values"
59 | # a boolean tensor marking which input elements are scalars (True) and which are not (False). Its length is the number of input elements
60 | ENCODER_INPUTS_SCALARS_VALID_MASK = ENCODER_INPUTS_SCALARS + ".valid_mask"
61 |
62 |
63 | LABELS_SCALARS = "data.labels.scalars"
64 | # a float tensor with the values of the scalars. A default value is used for elements that are not scalars. Its length is the number of input elements
65 | LABELS_SCALARS_VALUES = LABELS_SCALARS + ".values"
66 | # a boolean tensor marking which label elements are scalars (True) and which are not (False). Its length is the number of input elements
67 | LABELS_SCALARS_VALID_MASK = LABELS_SCALARS + ".valid_mask"
--------------------------------------------------------------------------------
/tutorials/begginer_inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# Inference using MAMMAL"
8 |    ]
9 |   },
10 |   {
11 |    "cell_type": "markdown",
12 |    "metadata": {},
13 |    "source": [
14 |     "Install the `biomed-multi-alignment` package. One can also clone the repository and install it in editable mode."
15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "!pip install biomed-multi-alignment" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "Run simplest inference script" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "import torch\n", 40 | "from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp\n", 41 | "from mammal.model import Mammal\n", 42 | "from mammal.keys import *" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "# Check if CUDA is available\n", 52 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 53 | "print(f\"Using device: {device}\")\n", 54 | "\n", 55 | "# Load Model and set it to evaluation mode\n", 56 | "model = Mammal.from_pretrained(\"ibm/biomed.omics.bl.sm.ma-ted-458m\")\n", 57 | "model.eval()\n", 58 | "model.to(device=device)\n", 59 | "\n", 60 | "\n", 61 | "# Load Tokenizer\n", 62 | "tokenizer_op = ModularTokenizerOp.from_pretrained(\"ibm/biomed.omics.bl.sm.ma-ted-458m\")\n", 63 | "\n", 64 | "# Prepare Input Prompt\n", 65 | "protein_calmodulin = \"MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK\"\n", 66 | "protein_calcineurin = \"MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ\"\n", 67 | "\n", 68 | "# Create and load sample\n", 69 | "sample_dict = dict()\n", 70 | "# Formatting prompt to match pre-training syntax\n", 71 | "sample_dict[ENCODER_INPUTS_STR] = f\"<@TOKENIZER-TYPE=AA>{protein_calmodulin}{protein_calcineurin}\"\n", 72 | "\n", 73 | "# Tokenize\n", 74 | "tokenizer_op(\n", 75 | " sample_dict=sample_dict,\n", 76 | " key_in=ENCODER_INPUTS_STR,\n", 77 | " key_out_tokens_ids=ENCODER_INPUTS_TOKENS,\n", 78 | " key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK,\n", 79 | ")\n", 80 | "sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(sample_dict[ENCODER_INPUTS_TOKENS]).to(device=device)\n", 81 | "sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(sample_dict[ENCODER_INPUTS_ATTENTION_MASK]).to(device=device)\n", 82 | "\n", 83 | "# Generate Prediction\n", 84 | "batch_dict = model.generate(\n", 85 | " [sample_dict],\n", 86 | " output_scores=True,\n", 87 | " return_dict_in_generate=True,\n", 88 | " max_new_tokens=5,\n", 89 | ")\n", 90 | "\n", 91 | "# Get output\n", 92 | "generated_output = tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0])\n", 93 | "print(f\"{generated_output=}\")" 94 | ] 95 | } 96 | ], 97 | "metadata": { 98 | "language_info": { 99 | "name": "python" 100 | } 101 | }, 102 | "nbformat": 4, 103 | "nbformat_minor": 2 104 | } 105 | -------------------------------------------------------------------------------- /mammal/examples/tests/test_dti_bindingdb_kd.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from pathlib import Path 3 | 4 | import hydra 5 | import pytest 6 | from hydra.core.global_hydra import GlobalHydra 7 | 8 | from mammal.examples.dti_bindingdb_kd.main_infer import dti_bindingdb_kd_infer 9 | from mammal.main_finetune import main as main_finetune 10 | 11 | TEST_CONFIG_DIRPATH = str(Path(__file__).parents[0] / 
"../dti_bindingdb_kd") 12 | TEST_CONFIG_FILENAME = "config.yaml" 13 | 14 | 15 | @pytest.fixture(autouse=True, scope="session") 16 | def _clean_hydra() -> None: 17 | GlobalHydra.instance().clear() 18 | 19 | 20 | @pytest.fixture(scope="session") 21 | def model_dir(tmp_path_factory: pytest.TempPathFactory) -> str: 22 | if "ccc" not in socket.gethostname(): 23 | pytest.skip("Full tests requires resources") 24 | 25 | model_dir_path = tmp_path_factory.mktemp("test_dti_bindingdb_kd") / "test" 26 | return model_dir_path 27 | 28 | 29 | @pytest.mark.xfail(reason="tokenizer only available on the CCC for now") 30 | def test_finetune(model_dir: str) -> None: 31 | print(model_dir) 32 | OVERRIDES = [ 33 | "track_clearml=null", # Travis cannot connect to ClearML at the moment. We might be able to fix it with a dedicated user + config credentials. 34 | "trainer.accelerator=auto", 35 | "trainer.max_epochs=2", # Small number for a faster run. 36 | "+trainer.limit_train_batches=3", # Small number for a faster run. This is not applicable for validation mb per epoch! (WHY?) 37 | "+trainer.limit_val_batches=2", # Small number for a faster run. This is not applicable for validation mb per epoch! (WHY?) 38 | "task.data_module_kwargs.batch_size=1", 39 | f"model_dir={model_dir}", 40 | "task.data_module_kwargs.train_dl_kwargs.num_workers=0", # Using parallelization cause co 41 | "task.data_module_kwargs.valid_dl_kwargs.num_workers=0", # Using parallelization cause co 42 | "root=.", 43 | "name=dti_bindingdb_kd_test", 44 | ] 45 | with hydra.initialize_config_dir(TEST_CONFIG_DIRPATH, version_base="1.1"): 46 | _cfg = hydra.compose(TEST_CONFIG_FILENAME, overrides=OVERRIDES) 47 | cfg = hydra.utils.instantiate(_cfg) 48 | main_finetune(cfg) 49 | 50 | 51 | @pytest.mark.xfail(reason="tokenizer only available on the CCC for now") 52 | def test_evaluate(model_dir: str): 53 | OVERRIDES = [ 54 | "track_clearml=null", # Travis cannot connect to ClearML at the moment. We might be able to fix it with a dedicated user + config credentials. 55 | # "train.trainer.accelerator=cpu", # Travis doesn't have GPU. 56 | "trainer.accelerator=auto", 57 | "trainer.max_epochs=1", 58 | "task.data_module_kwargs.batch_size=1", 59 | "+trainer.limit_test_batches=10", # Small number for a faster run. This is not applicable for validation mb per epoch! (WHY?) 
60 | f"model_dir={model_dir}", 61 | "task.data_module_kwargs.train_dl_kwargs.num_workers=0", # Using parallelization cause co 62 | "task.data_module_kwargs.valid_dl_kwargs.num_workers=0", # Using parallelization cause co 63 | "root=.", 64 | "name=dti_bindingdb_kd_test", 65 | "evaluate=True", 66 | f"model.pretrained_kwargs.pretrained_model_name_or_path={model_dir}/best_epoch.ckpt", 67 | ] 68 | with hydra.initialize_config_dir(TEST_CONFIG_DIRPATH, version_base="1.1"): 69 | _cfg = hydra.compose(TEST_CONFIG_FILENAME, overrides=OVERRIDES) 70 | cfg = hydra.utils.instantiate(_cfg) 71 | main_finetune(cfg) 72 | 73 | 74 | @pytest.mark.xfail(reason="tokenizer only available on the CCC for now") 75 | def test_infer(model_dir: str): 76 | dti_bindingdb_kd_infer( 77 | finetune_output_dir=model_dir, 78 | target_seq="NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC", 79 | drug_seq="CC(=O)NCCC1=CNc2c1cc(OC)cc2", 80 | norm_y_mean=0.0, 81 | norm_y_std=1.0, 82 | device="cpu", 83 | ) 84 | -------------------------------------------------------------------------------- /.github/workflows/clone.yml: -------------------------------------------------------------------------------- 1 | name: GitHub Clone Count Update Everyday 2 | 3 | on: 4 | schedule: 5 | - cron: "0 */24 * * *" 6 | workflow_dispatch: 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v2 14 | 15 | - name: gh login 16 | run: echo "${{ secrets.SECRET_TOKEN }}" | gh auth login --with-token 17 | 18 | - name: parse latest clone count 19 | run: | 20 | curl --user "${{ github.actor }}:${{ secrets.SECRET_TOKEN }}" \ 21 | -H "Accept: application/vnd.github.v3+json" \ 22 | https://api.github.com/repos/${{ github.repository }}/traffic/clones \ 23 | > clone.json 24 | 25 | - name: create gist and download previous count 26 | id: set_id 27 | run: | 28 | if gh secret list | grep -q "GIST_ID" 29 | then 30 | echo "GIST_ID found" 31 | echo "GIST=${{ secrets.GIST_ID }}" >> $GITHUB_OUTPUT 32 | curl https://gist.githubusercontent.com/${{ github.actor }}/${{ secrets.GIST_ID }}/raw/clone.json > clone_before.json 33 | if cat clone_before.json | grep '404: Not Found'; then 34 | echo "GIST_ID not valid anymore. Creating another gist..." 35 | gist_id=$(gh gist create clone.json | awk -F / '{print $NF}') 36 | echo $gist_id | gh secret set GIST_ID 37 | echo "GIST=$gist_id" >> $GITHUB_OUTPUT 38 | cp clone.json clone_before.json 39 | git rm --ignore-unmatch CLONE.md 40 | fi 41 | else 42 | echo "GIST_ID not found. Creating a gist..." 43 | gist_id=$(gh gist create clone.json | awk -F / '{print $NF}') 44 | echo $gist_id | gh secret set GIST_ID 45 | echo "GIST=$gist_id" >> $GITHUB_OUTPUT 46 | cp clone.json clone_before.json 47 | fi 48 | 49 | - name: update clone.json 50 | run: | 51 | curl https://raw.githubusercontent.com/MShawon/github-clone-count-badge/master/main.py > main.py 52 | python3 main.py 53 | 54 | - name: Update gist with latest count 55 | run: | 56 | content=$(sed -e 's/\\/\\\\/g' -e 's/\t/\\t/g' -e 's/\"/\\"/g' -e 's/\r//g' "clone.json" | sed -E ':a;N;$!ba;s/\r{0,1}\n/\\n/g') 57 | echo '{"description": "${{ github.repository }} clone statistics", "files": {"clone.json": {"content": "'"$content"'"}}}' > post_clone.json 58 | curl -s -X PATCH \ 59 | --user "${{ github.actor }}:${{ secrets.SECRET_TOKEN }}" \ 60 | -H "Content-Type: application/json" \ 61 | -d @post_clone.json https://api.github.com/gists/${{ steps.set_id.outputs.GIST }} > /dev/null 2>&1 62 | 63 | if [ ! 
-f CLONE.md ]; then 64 | shields="https://img.shields.io/badge/dynamic/json?color=success&label=Clone&query=count&url=" 65 | url="https://gist.githubusercontent.com/${{ github.actor }}/${{ steps.set_id.outputs.GIST }}/raw/clone.json" 66 | repo="https://github.com/MShawon/github-clone-count-badge" 67 | echo ''> CLONE.md 68 | echo ' 69 | **Markdown** 70 | 71 | ```markdown' >> CLONE.md 72 | echo "[![GitHub Clones]($shields$url&logo=github)]($repo)" >> CLONE.md 73 | echo ' 74 | ``` 75 | 76 | **HTML** 77 | ```html' >> CLONE.md 78 | echo "GitHub Clones" >> CLONE.md 79 | echo '```' >> CLONE.md 80 | 81 | git add CLONE.md 82 | git config --global user.name "GitHub Action" 83 | git config --global user.email "action@github.com" 84 | git commit -m "create clone count badge" 85 | fi 86 | 87 | - name: Push 88 | uses: ad-m/github-push-action@master 89 | with: 90 | github_token: ${{ secrets.GITHUB_TOKEN }} 91 | -------------------------------------------------------------------------------- /mammal_mcp/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | .env 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 112 | .pdm.toml 113 | .pdm-python 114 | .pdm-build/ 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | -------------------------------------------------------------------------------- /mammal/examples/dti_bindingdb_kd/main_infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import click 4 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 5 | 6 | from mammal.examples.dti_bindingdb_kd.task import DtiBindingdbKdTask 7 | from mammal.model import Mammal 8 | 9 | 10 | @click.command() 11 | @click.option( 12 | "--finetune_output_dir", 13 | default=None, 14 | help="Specify the model dir (default: None to download from huggingface).", 15 | ) 16 | @click.argument( 17 | "target_seq", 18 | default="NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC", 19 | ) 20 | @click.argument( 21 | "drug_seq", 22 | default="CC(=O)NCCC1=CNc2c1cc(OC)cc2", 23 | ) 24 | @click.argument("norm_y_mean", default=5.79384684128215, type=float) 25 | @click.argument("norm_y_std", default=1.33808027428196, type=float) 26 | @click.option( 27 | "--device", default="cpu", help="Specify the device to use (default: 'cpu')." 
28 | ) 29 | def main( 30 | finetune_output_dir: str | None, 31 | target_seq: str, 32 | drug_seq: str, 33 | norm_y_mean: float, 34 | norm_y_std: float, 35 | device: str, 36 | ): 37 | dti_bindingdb_kd_infer( 38 | finetune_output_dir=finetune_output_dir, 39 | target_seq=target_seq, 40 | drug_seq=drug_seq, 41 | norm_y_mean=norm_y_mean, 42 | norm_y_std=norm_y_std, 43 | device=device, 44 | ) 45 | 46 | 47 | def dti_bindingdb_kd_infer( 48 | finetune_output_dir: str, 49 | target_seq: str, 50 | drug_seq: str, 51 | norm_y_mean: float, 52 | norm_y_std: float, 53 | device: str, 54 | ): 55 | """ 56 | :param finetune_output_dir: model_dir argument in fine-tuning or None for downloading from huggingface 57 | :param target_seq: amino acid sequence of a target 58 | :param drug_seq: smiles representation of a drug 59 | :param norm_y_mean: specify the mean and std values used in fine-tuning 60 | :param norm_y_std: specify the mean and std values used in fine-tuning 61 | """ 62 | if finetune_output_dir is not None: 63 | # load tokenizer and model from finetune_output_dir 64 | tokenizer_op = ModularTokenizerOp.from_pretrained( 65 | os.path.join(finetune_output_dir, "tokenizer") 66 | ) 67 | nn_model = Mammal.from_pretrained( 68 | pretrained_model_name_or_path=os.path.join( 69 | finetune_output_dir, "best_epoch.ckpt" 70 | ) 71 | ) 72 | else: 73 | # load tokenizer and model from huggingface 74 | tokenizer_op = ModularTokenizerOp.from_pretrained( 75 | "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd" 76 | ) 77 | nn_model = Mammal.from_pretrained( 78 | "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd" 79 | ) 80 | nn_model.eval() 81 | nn_model.to(device=device) 82 | 83 | # convert to MAMMAL style 84 | sample_dict = {"target_seq": target_seq, "drug_seq": drug_seq} 85 | sample_dict = DtiBindingdbKdTask.data_preprocessing( 86 | sample_dict=sample_dict, 87 | tokenizer_op=tokenizer_op, 88 | target_sequence_key="target_seq", 89 | drug_sequence_key="drug_seq", 90 | norm_y_mean=None, 91 | norm_y_std=None, 92 | device=nn_model.device, 93 | ) 94 | 95 | # forward pass - encoder_only mode which supports scalars predictions 96 | batch_dict = nn_model.forward_encoder_only([sample_dict]) 97 | 98 | # Post-process the model's output 99 | batch_dict = DtiBindingdbKdTask.process_model_output( 100 | batch_dict, 101 | scalars_preds_processed_key="model.out.dti_bindingdb_kd", 102 | norm_y_mean=norm_y_mean, 103 | norm_y_std=norm_y_std, 104 | ) 105 | ans = { 106 | "model.out.dti_bindingdb_kd": float(batch_dict["model.out.dti_bindingdb_kd"][0]) 107 | } 108 | 109 | # Print prediction 110 | print(ans) 111 | return ans 112 | 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /mammal/examples/carcinogenicity/pl_data_module.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | import pytorch_lightning as pl 4 | from fuse.data.datasets.dataset_default import DatasetDefault 5 | from fuse.data.ops.ops_read import OpReadDataframe 6 | from fuse.data.pipelines.pipeline_default import PipelineDefault 7 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 8 | from fuse.data.utils.collates import CollateDefault 9 | from tdc.single_pred.tox import Tox 10 | from torch.utils.data.dataloader import DataLoader 11 | 12 | from mammal.keys import * # noqa 13 | 14 | 15 | class CarcinogenicityDataModule(pl.LightningDataModule): 16 | def __init__( 17 | self, 18 | *, 19 
| batch_size: int, 20 | tokenizer_op: ModularTokenizerOp, 21 | drug_max_seq_length: int, 22 | encoder_input_max_seq_len: int, 23 | data_preprocessing: Callable, 24 | labels_max_seq_len: int, 25 | ) -> None: 26 | super().__init__() 27 | self.tokenizer_op = tokenizer_op 28 | self.drug_max_seq_length = drug_max_seq_length 29 | self.encoder_input_max_seq_len = encoder_input_max_seq_len 30 | self.labels_max_seq_len = labels_max_seq_len 31 | self.batch_size = batch_size 32 | self.data_preprocessing = data_preprocessing 33 | self.pad_token_id = self.tokenizer_op.get_token_id("") 34 | 35 | def setup(self, stage: str) -> None: 36 | self.ds_dict = load_datasets() 37 | 38 | task_pipeline = [ 39 | ( 40 | self.data_preprocessing, 41 | dict( 42 | sequence_key="data.drug", 43 | label_key="data.label", 44 | tokenizer_op=self.tokenizer_op, 45 | encoder_input_max_seq_len=self.encoder_input_max_seq_len, 46 | labels_max_seq_len=self.labels_max_seq_len, 47 | ), 48 | ), 49 | ] 50 | 51 | for ds in self.ds_dict.values(): 52 | ds.dynamic_pipeline.extend(task_pipeline) 53 | 54 | def train_dataloader(self) -> DataLoader: 55 | train_loader = DataLoader( 56 | dataset=self.ds_dict["train"], 57 | batch_size=self.batch_size, 58 | collate_fn=CollateDefault(), 59 | shuffle=True, 60 | ) 61 | return train_loader 62 | 63 | def val_dataloader(self) -> DataLoader: 64 | val_loader = DataLoader( 65 | self.ds_dict["valid"], 66 | batch_size=self.batch_size, 67 | collate_fn=CollateDefault(), 68 | ) 69 | 70 | return val_loader 71 | 72 | def test_dataloader(self) -> DataLoader: 73 | test_loader = DataLoader( 74 | self.ds_dict["test"], 75 | batch_size=self.batch_size, 76 | collate_fn=CollateDefault(), 77 | ) 78 | 79 | return test_loader 80 | 81 | def predict_dataloader(self) -> DataLoader: 82 | return self.test_dataloader() 83 | 84 | 85 | def load_datasets(split_method: str = "random") -> dict[str, DatasetDefault]: 86 | data = Tox(name="Carcinogens_Lagunin") 87 | split = data.get_split(method=split_method) 88 | 89 | ds_dict = {} 90 | for set_name in ["train", "valid", "test"]: 91 | data_df = split[set_name] 92 | print(f"{set_name} set size is {len(data_df)}") 93 | size = len(data_df) 94 | 95 | dynamic_pipeline = PipelineDefault( 96 | "carcinogenicity", 97 | [ 98 | ( 99 | OpReadDataframe( 100 | data_df, 101 | key_column=None, 102 | rename_columns={"Drug": "data.drug", "Y": "data.label"}, 103 | ), 104 | dict(), 105 | ), 106 | ], 107 | ) 108 | 109 | ds = DatasetDefault(sample_ids=size, dynamic_pipeline=dynamic_pipeline) 110 | ds.create() 111 | ds_dict[set_name] = ds 112 | 113 | return ds_dict 114 | 115 | 116 | if __name__ == "__main__": 117 | ds = load_datasets() 118 | print(ds["train"][0]) 119 | print(ds["test"][0]) 120 | -------------------------------------------------------------------------------- /mammal_mcp/dependencies.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from contextlib import asynccontextmanager 4 | 5 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 6 | 7 | # bmfm imports 8 | from mammal.model import Mammal 9 | 10 | logging.basicConfig( 11 | level=logging.WARNING, 12 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 13 | handlers=[logging.StreamHandler()], 14 | ) 15 | 16 | # Create a logger 17 | logger = logging.getLogger("mammal") 18 | logger.setLevel(logging.DEBUG) 19 | 20 | assets = {} 21 | 22 | 23 | @asynccontextmanager 24 | async def lifespan(): 25 | if os.getenv("PROTEIN_PROTEIN_INTERACTION") == 
"true": 26 | # Load Model 27 | logger.info("downloading: ibm/biomed.omics.bl.sm.ma-ted-458m") 28 | model = Mammal.from_pretrained( 29 | "ibm/biomed.omics.bl.sm.ma-ted-458m", cache_dir="model_cache" 30 | ) 31 | assets["model"] = model.eval() 32 | logger.info("completed the download: ibm/biomed.omics.bl.sm.ma-ted-458m") 33 | 34 | logger.info("downloading for tokenizer: ibm/biomed.omics.bl.sm.ma-ted-458m") 35 | assets["tokenizer_op"] = ModularTokenizerOp.from_pretrained( 36 | "ibm/biomed.omics.bl.sm.ma-ted-458m", cache_dir="model_cache" 37 | ) 38 | 39 | if os.getenv("PROTEIN_SOLUBILITY") == "true": 40 | logger.info( 41 | "downloading: ibm-research/biomed.omics.bl.sm.ma-ted-458m.protein_solubility" 42 | ) 43 | protein_solubility_model = Mammal.from_pretrained( 44 | "ibm-research/biomed.omics.bl.sm.ma-ted-458m.protein_solubility", 45 | cache_dir="model_cache", 46 | ) 47 | assets["protein_solubility_model"] = protein_solubility_model.eval() 48 | assets["protein_solubility_tokenizer_op"] = ModularTokenizerOp.from_pretrained( 49 | "ibm-research/biomed.omics.bl.sm.ma-ted-458m.protein_solubility", 50 | cache_dir="model_cache", 51 | ) 52 | 53 | if os.getenv("PROTEIN_DRUG_INTERACTION_MODEL") == "true": 54 | logger.info("downloading: ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd") 55 | protein_drug_interaction_model = Mammal.from_pretrained( 56 | "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd", 57 | cache_dir="model_cache", 58 | ) 59 | assets["protein_drug_interaction_model"] = protein_drug_interaction_model.eval() 60 | assets["protein_drug_interaction_tokenizer_op"] = ( 61 | ModularTokenizerOp.from_pretrained( 62 | "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd", 63 | cache_dir="model_cache", 64 | ) 65 | ) 66 | 67 | if os.getenv("TCR_EPITOPE_BINDING") == "true": 68 | logger.info( 69 | "downloading: ibm-research/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind" 70 | ) 71 | 72 | tcr_epitope_model = Mammal.from_pretrained( 73 | "ibm-research/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind", 74 | cache_dir="model_cache", 75 | ) 76 | #  set to eval/inference mode 77 | tcr_epitope_model.eval() 78 | #  set model to model mode (use 'cuda' if GPU avail) 79 | tcr_epitope_model.to(device="cpu") 80 | assets["tcr_epitope_model"] = tcr_epitope_model 81 | 82 | assets["tcr_epitope_model_tokenizer_op"] = ModularTokenizerOp.from_pretrained( 83 | "ibm-research/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind", 84 | cache_dir="model_cache", 85 | ) 86 | 87 | if ( 88 | os.getenv("DRUG_TARGET_BINDING") == "true" 89 | or os.getenv("DRUG_TARGET_BINDING_FASTA") == "true" 90 | ): 91 | logger.info("downloading: ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd") 92 | 93 | drug_target_model = Mammal.from_pretrained( 94 | "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd", 95 | cache_dir="model_cache", 96 | ) 97 | #  set to eval/inference mode 98 | drug_target_model.eval() 99 | assets["drug_target_model"] = drug_target_model 100 | 101 | # download tokeniser 102 | assets["drug_target_model_tokeniser_op"] = ModularTokenizerOp.from_pretrained( 103 | "ibm/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd", 104 | cache_dir="model_cache", 105 | ) 106 | 107 | logger.info("Assets loaded") 108 | 109 | yield 110 | # Clean up the assets 111 | assets.clear() 112 | -------------------------------------------------------------------------------- /mammal/examples/molnet/molnet_infer.py: -------------------------------------------------------------------------------- 1 | import click 2 | import numpy as np 3 | import 
torch 4 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 5 | 6 | from mammal.keys import ( 7 | CLS_PRED, 8 | ENCODER_INPUTS_ATTENTION_MASK, 9 | ENCODER_INPUTS_STR, 10 | ENCODER_INPUTS_TOKENS, 11 | SCORES, 12 | ) 13 | from mammal.model import Mammal 14 | 15 | TASK_NAMES = ["BBBP", "TOXICITY", "FDA_APPR"] 16 | 17 | 18 | @click.command() 19 | @click.argument("task_name", default="BBBP") 20 | @click.argument( 21 | "smiles_seq", 22 | default="C(Cl)Cl", 23 | ) 24 | @click.option( 25 | "--device", default="cpu", help="Specify the device to use (default: 'cpu')." 26 | ) 27 | def main(task_name: str, smiles_seq: str, device: str): 28 | task_dict = load_model(task_name=task_name, device=device) 29 | result = task_infer(task_dict=task_dict, smiles_seq=smiles_seq) 30 | print(f"The prediction for {smiles_seq=} is {result}") 31 | 32 | 33 | def load_model(task_name: str, device: str) -> dict: 34 | match task_name: 35 | case "BBBP": 36 | path = "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_bbbp" 37 | case "TOXICITY": 38 | path = "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_tox" 39 | case "FDA_APPR": 40 | path = "ibm/biomed.omics.bl.sm.ma-ted-458m.moleculenet_clintox_fda" 41 | case _: 42 | print(f"The {task_name=} is incorrect") 43 | 44 | # Load Model and set to evaluation mode 45 | model = Mammal.from_pretrained(path) 46 | model.eval() 47 | model.to(device=device) 48 | 49 | # Load Tokenizer 50 | tokenizer_op = ModularTokenizerOp.from_pretrained(path) 51 | 52 | task_dict = dict( 53 | task_name=task_name, 54 | model=model, 55 | tokenizer_op=tokenizer_op, 56 | ) 57 | return task_dict 58 | 59 | 60 | def process_model_output( 61 | tokenizer_op: ModularTokenizerOp, 62 | decoder_output: np.ndarray, 63 | decoder_output_scores: np.ndarray, 64 | ) -> dict: 65 | """ 66 | Extract predicted class and scores 67 | """ 68 | negative_token_id = tokenizer_op.get_token_id("<0>") 69 | positive_token_id = tokenizer_op.get_token_id("<1>") 70 | label_id_to_int = { 71 | negative_token_id: 0, 72 | positive_token_id: 1, 73 | } 74 | classification_position = 1 75 | 76 | if decoder_output_scores is not None: 77 | scores = decoder_output_scores[classification_position, positive_token_id] 78 | 79 | ans = dict( 80 | pred=label_id_to_int.get(int(decoder_output[classification_position]), -1), 81 | score=scores.item(), 82 | ) 83 | return ans 84 | 85 | 86 | def task_infer(task_dict: dict, smiles_seq: str) -> dict: 87 | task_name = task_dict["task_name"] 88 | model = task_dict["model"] 89 | tokenizer_op = task_dict["tokenizer_op"] 90 | 91 | if task_name not in TASK_NAMES: 92 | print(f"The {task_name=} is incorrect. 
Valid names are {TASK_NAMES}") 93 | 94 | sample_dict = create_sample_dict(task_name, smiles_seq, tokenizer_op, model) 95 | # Generate Prediction 96 | batch_dict = get_predictions(model, sample_dict) 97 | 98 | # Post-process the model's output 99 | result = process_model_output( 100 | tokenizer_op=tokenizer_op, 101 | decoder_output=batch_dict[CLS_PRED][0], 102 | decoder_output_scores=batch_dict[SCORES][0], 103 | ) 104 | return result 105 | 106 | 107 | def create_sample_dict(task_name, smiles_seq, tokenizer_op, model): 108 | 109 | # Create and load sample 110 | sample_dict = dict() 111 | # Formatting prompt to match pre-training syntax 112 | sample_dict[ENCODER_INPUTS_STR] = ( 113 | f"<@TOKENIZER-TYPE=SMILES><{task_name}><@TOKENIZER-TYPE=SMILES@MAX-LEN=2100>{smiles_seq}" 114 | ) 115 | 116 | # Tokenize 117 | tokenizer_op( 118 | sample_dict=sample_dict, 119 | key_in=ENCODER_INPUTS_STR, 120 | key_out_tokens_ids=ENCODER_INPUTS_TOKENS, 121 | key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK, 122 | ) 123 | sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor( 124 | sample_dict[ENCODER_INPUTS_TOKENS], device=model.device 125 | ) 126 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor( 127 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK], device=model.device 128 | ) 129 | return sample_dict 130 | 131 | 132 | def get_predictions(model, sample_dict): 133 | return model.generate( 134 | [sample_dict], 135 | output_scores=True, 136 | return_dict_in_generate=True, 137 | max_new_tokens=5, 138 | ) 139 | 140 | 141 | if __name__ == "__main__": 142 | main() 143 | -------------------------------------------------------------------------------- /mammal/examples/tcr_epitope_binding/main_infer.py: -------------------------------------------------------------------------------- 1 | import click 2 | import numpy as np 3 | import torch 4 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 5 | 6 | from mammal.keys import ( 7 | CLS_PRED, 8 | ENCODER_INPUTS_ATTENTION_MASK, 9 | ENCODER_INPUTS_STR, 10 | ENCODER_INPUTS_TOKENS, 11 | SCORES, 12 | ) 13 | from mammal.model import Mammal 14 | 15 | 16 | @click.command() 17 | @click.argument( 18 | "tcr_beta_seq", 19 | default="GAVVSQHPSWVICKSGTSVKIECRSLDFQATTMFWYRQFPKQSLMLMATSNEGSKATYEQGVEKDKFLINHASLTLSTLTVTSAHPEDSSFYICSASEGTSSYEQYFGPGTRLTVT", # alternative binder 1 20 | # Alternative binder 2: NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT 21 | ) 22 | @click.argument( 23 | "epitope_seq", 24 | default="FLKEKGGL", # alternative binder 1 25 | # Alternative binder 2: LLQTGIHVRVSQPSL 26 | ) 27 | @click.option( 28 | "--device", default="cpu", help="Specify the device to use (default: 'cpu')." 
29 | ) 30 | def main(tcr_beta_seq: str, epitope_seq: str, device: str): 31 | model, tokenizer_op = load_model(device=device) 32 | result = task_infer( 33 | model=model, 34 | tokenizer_op=tokenizer_op, 35 | tcr_beta_seq=tcr_beta_seq, 36 | epitope_seq=epitope_seq, 37 | ) 38 | print(f"The prediction for {epitope_seq} and {tcr_beta_seq} is {result}") 39 | 40 | 41 | def load_model( 42 | device: str, 43 | model_path: str = "ibm-research/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind", # change to "ibm/biomed.omics.bl.sm.ma-ted-458m" to try on the base model 44 | tokenizer_path: str = "ibm-research/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind", 45 | ) -> tuple["Mammal", "ModularTokenizerOp"]: 46 | 47 | # Load Model and set to evaluation mode 48 | model = Mammal.from_pretrained( 49 | pretrained_model_name_or_path=model_path, allow_config_mismatch=True 50 | ) 51 | model.eval() 52 | model.to(device=device) 53 | 54 | # Load Tokenizer 55 | tokenizer_op = ModularTokenizerOp.from_pretrained(tokenizer_path) 56 | 57 | return model, tokenizer_op 58 | 59 | 60 | def process_model_output( 61 | tokenizer_op: ModularTokenizerOp, 62 | decoder_output: np.ndarray, 63 | decoder_output_scores: np.ndarray, 64 | ) -> dict: 65 | """ 66 | Extract predicted class and scores 67 | """ 68 | negative_token_id = tokenizer_op.get_token_id("<0>") 69 | positive_token_id = tokenizer_op.get_token_id("<1>") 70 | label_id_to_int = { 71 | negative_token_id: 0, 72 | positive_token_id: 1, 73 | } 74 | classification_position = 1 75 | 76 | if decoder_output_scores is not None: 77 | scores = decoder_output_scores[classification_position, positive_token_id] 78 | 79 | ans = dict( 80 | pred=label_id_to_int.get(int(decoder_output[classification_position]), -1), 81 | score=scores.item(), 82 | ) 83 | return ans 84 | 85 | 86 | def task_infer( 87 | model: "Mammal", 88 | tokenizer_op: ModularTokenizerOp, 89 | tcr_beta_seq: str, 90 | epitope_seq: str, 91 | ) -> dict: 92 | treat_inputs_as_general_proteins = False 93 | 94 | # Create and load sample 95 | sample_dict = dict() 96 | # Formatting prompt to match pre-training syntax 97 | 98 | if treat_inputs_as_general_proteins: 99 | # Treat inputs as general proteins: 100 | sample_dict[ENCODER_INPUTS_STR] = ( 101 | f"<@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=AA>{tcr_beta_seq}<@TOKENIZER-TYPE=AA>{epitope_seq}" 102 | ) 103 | else: 104 | # Treat inputs as TCR beta chain and epitope 105 | sample_dict[ENCODER_INPUTS_STR] = ( 106 | f"<@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=AA>{tcr_beta_seq}<@TOKENIZER-TYPE=AA>{epitope_seq}" 107 | ) 108 | 109 | # Tokenize 110 | tokenizer_op( 111 | sample_dict=sample_dict, 112 | key_in=ENCODER_INPUTS_STR, 113 | key_out_tokens_ids=ENCODER_INPUTS_TOKENS, 114 | key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK, 115 | ) 116 | sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor( 117 | sample_dict[ENCODER_INPUTS_TOKENS], device=model.device 118 | ) 119 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor( 120 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK], device=model.device 121 | ) 122 | 123 | # Generate Prediction 124 | batch_dict = model.generate( 125 | [sample_dict], 126 | output_scores=True, 127 | return_dict_in_generate=True, 128 | max_new_tokens=5, 129 | ) 130 | 131 | # Post-process the model's output 132 | result = process_model_output( 133 | tokenizer_op=tokenizer_op, 134 | decoder_output=batch_dict[CLS_PRED][0], 135 | decoder_output_scores=batch_dict[SCORES][0], 136 | ) 137 | return result 138 | 139 | 140 | if __name__ == "__main__": 141 | main() 142 | 
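# Example of calling the helpers above from Python instead of the CLI (a minimal sketch;
# the placeholder variables my_tcr_beta_seq / my_epitope_seq stand for any amino-acid
# strings, e.g. the CLI defaults at the top of this script):
#
#   model, tokenizer_op = load_model(device="cpu")
#   result = task_infer(
#       model=model,
#       tokenizer_op=tokenizer_op,
#       tcr_beta_seq=my_tcr_beta_seq,
#       epitope_seq=my_epitope_seq,  # e.g. "FLKEKGGL"
#   )
#   # result is a dict of the form {"pred": 0 or 1, "score": softmax score of the positive <1> token}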
-------------------------------------------------------------------------------- /mammal/examples/dti_bindingdb_kd/pl_data_module.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | 3 | import pytorch_lightning as pl 4 | from fuse.data.datasets.dataset_default import DatasetDefault 5 | from fuse.data.ops.ops_read import OpReadDataframe 6 | from fuse.data.pipelines.pipeline_default import PipelineDefault 7 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 8 | from fuse.data.tokenizers.modular_tokenizer.special_tokens import special_wrap_input 9 | from fuse.data.utils.collates import CollateDefault 10 | from tdc.multi_pred.dti import DTI 11 | from torch.utils.data.dataloader import DataLoader 12 | 13 | from mammal.keys import * # noqa 14 | 15 | 16 | class DtiBindingdbKdDataModule(pl.LightningDataModule): 17 | def __init__( 18 | self, 19 | *, 20 | batch_size: int, 21 | tokenizer_op: ModularTokenizerOp, 22 | train_dl_kwargs: dict, 23 | valid_dl_kwargs: dict, 24 | seed: int, 25 | data_preprocessing: Callable, 26 | target_max_seq_length: int, 27 | drug_max_seq_length: int, 28 | encoder_input_max_seq_len: int, 29 | load_datasets_kwargs: dict, 30 | ) -> None: 31 | super().__init__() 32 | self.tokenizer_op = tokenizer_op 33 | self.target_max_seq_length = target_max_seq_length 34 | self.drug_max_seq_length = drug_max_seq_length 35 | self.encoder_input_max_seq_len = encoder_input_max_seq_len 36 | self.batch_size = batch_size 37 | self.train_dl_kwargs = train_dl_kwargs 38 | self.valid_dl_kwargs = valid_dl_kwargs 39 | self.seed = seed 40 | self.data_preprocessing = data_preprocessing 41 | self.load_datasets_kwargs = load_datasets_kwargs 42 | 43 | self.pad_token_id = self.tokenizer_op.get_token_id(special_wrap_input("PAD")) 44 | 45 | def setup(self, stage: str) -> None: 46 | self.ds_dict = load_datasets(**self.load_datasets_kwargs) 47 | 48 | task_pipeline = [ 49 | ( 50 | # Prepare the input string(s) in modular tokenizer input format 51 | self.data_preprocessing, 52 | dict( 53 | target_sequence_key="Target", 54 | drug_sequence_key="Drug", 55 | ground_truth_key="Y", 56 | tokenizer_op=self.tokenizer_op, 57 | target_max_seq_length=self.target_max_seq_length, 58 | drug_max_seq_length=self.drug_max_seq_length, 59 | encoder_input_max_seq_len=self.encoder_input_max_seq_len, 60 | ), 61 | ), 62 | ] 63 | 64 | for ds in self.ds_dict.values(): 65 | ds.dynamic_pipeline.extend(task_pipeline) 66 | 67 | def train_dataloader(self) -> DataLoader: 68 | train_loader = DataLoader( 69 | dataset=self.ds_dict["train"], 70 | batch_size=self.batch_size, 71 | collate_fn=CollateDefault(add_to_batch_dict={"forward_mode": "encoder"}), 72 | shuffle=True, 73 | **self.train_dl_kwargs, 74 | ) 75 | return train_loader 76 | 77 | def val_dataloader(self) -> DataLoader: 78 | val_loader = DataLoader( 79 | self.ds_dict["valid"], 80 | batch_size=self.batch_size, 81 | collate_fn=CollateDefault(add_to_batch_dict={"forward_mode": "encoder"}), 82 | **self.valid_dl_kwargs, 83 | ) 84 | 85 | return val_loader 86 | 87 | def test_dataloader(self) -> DataLoader: 88 | test_loader = DataLoader( 89 | self.ds_dict["test"], 90 | batch_size=self.batch_size, 91 | collate_fn=CollateDefault(add_to_batch_dict={"forward_mode": "encoder"}), 92 | **self.valid_dl_kwargs, 93 | ) 94 | 95 | return test_loader 96 | 97 | def predict_dataloader(self) -> DataLoader: 98 | return self.test_dataloader() 99 | 100 | 101 | def load_datasets( 102 | split_type: str = "cold_split", 
split_column: list[str] | str = ["Drug", "Target"] 103 | ) -> dict[str, DatasetDefault]: 104 | """ 105 | Automatically downloads (using tdc) the data and create dataset iterator for "train", "val" and "test". 106 | :return: dictionary that maps fold name "train", "val" and "test" to a dataset iterator 107 | """ 108 | 109 | data = DTI(name="BindingDB_Kd") 110 | data.harmonize_affinities(mode="max_affinity") 111 | data.convert_to_log(form="binding") 112 | split = data.get_split(method=split_type, column_name=split_column) 113 | ds_dict = {} 114 | for set_name in ["train", "valid", "test"]: 115 | set_df = split[set_name] 116 | print(f"{set_name} set size is {len(set_df)}") 117 | print(f"{set_name=} {set_df.Y.mean()=} {set_df.Y.std()=}") 118 | dynamic_pipeline = PipelineDefault( 119 | "dti", 120 | [ 121 | ( 122 | OpReadDataframe( 123 | set_df, 124 | key_column=None, 125 | columns_to_extract=["Target", "Drug", "Y"], 126 | ), 127 | dict(), 128 | ), 129 | ], 130 | ) 131 | 132 | ds = DatasetDefault(sample_ids=len(set_df), dynamic_pipeline=dynamic_pipeline) 133 | ds.create() 134 | ds_dict[set_name] = ds 135 | 136 | return ds_dict 137 | 138 | 139 | if __name__ == "__main__": 140 | load_datasets() 141 | -------------------------------------------------------------------------------- /mammal/examples/tests/test_main_finetune.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from pathlib import Path 3 | 4 | import hydra 5 | import pytest 6 | import pytorch_lightning as pl 7 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 8 | from fuse.utils.multiprocessing.helpers import num_available_cores 9 | from omegaconf import OmegaConf 10 | 11 | from mammal.main_finetune import * 12 | 13 | # pylint: disable=W0621 14 | 15 | """_summary_ 16 | 17 | Testing examples 18 | """ 19 | 20 | TEST_CONFIG_DIRPATH = str(Path(__file__).parents[0] / "../protein_solubility") 21 | TEST_CONFIG_FILENAME = "config.yaml" 22 | 23 | STATIC_OVERRIDES = [ 24 | "track_clearml=null", # Travis cannot connect to ClearML at the moment. We might be able to fix it with a dedicated user + config credentials. 25 | "trainer.max_epochs=1", # Small number for a faster run. 26 | "+trainer.limit_train_batches=2", # Small number for a faster run. 27 | "+trainer.limit_val_batches=3", # Small number for a faster run. 
28 | "+trainer.enable_checkpointing=False", # Do not checkpoint - saves memory 29 | "model_dir=null", 30 | ] 31 | 32 | OVERRIDES = [ 33 | "task.data_module_kwargs.train_dl_kwargs.num_workers=0", # Using parallelization cause co 34 | "task.data_module_kwargs.valid_dl_kwargs.num_workers=0", # Using parallelization cause co 35 | "root=.", 36 | "name=sol_test", 37 | "+tokenizer.new_special_tokens=['','special_token2',yet_another_special_token]", 38 | ] + STATIC_OVERRIDES 39 | 40 | 41 | @pytest.fixture(scope="session") 42 | def cfg_dict(): 43 | with hydra.initialize_config_dir(TEST_CONFIG_DIRPATH, version_base="1.1"): 44 | _cfg = hydra.compose(TEST_CONFIG_FILENAME, overrides=OVERRIDES) 45 | yield _cfg 46 | 47 | 48 | @pytest.fixture(scope="session") 49 | def cfg_obj(cfg_dict): 50 | OmegaConf.register_new_resolver("num_cores_auto", num_available_cores, replace=True) 51 | cfg_obj = hydra.utils.instantiate(cfg_dict) 52 | return cfg_obj 53 | 54 | 55 | @pytest.fixture(scope="session") 56 | def cfg(cfg_obj): 57 | return cfg_obj 58 | 59 | 60 | @pytest.fixture(scope="session") 61 | def clearml_logger(): 62 | return None 63 | 64 | 65 | def test_context(cfg_dict): 66 | assert cfg_dict 67 | 68 | 69 | def test_context_obj(cfg): 70 | assert cfg 71 | 72 | 73 | def seed(seed_value: int) -> int: 74 | pl.seed_everything(seed_value, workers=True) 75 | 76 | return seed_value 77 | 78 | 79 | def test_seed(): 80 | original_seed_value = 12345 81 | seed_value = seed(original_seed_value) 82 | assert seed_value == original_seed_value 83 | 84 | 85 | @pytest.fixture(scope="session") 86 | def tokenizer_op(cfg_dict): 87 | # return ModularTokenizerOp.from_pretrained(cfg_dict.tokenizer.tokenizer_path) 88 | # The tokenizer is loaded with the extra special tokens, so we can check they are avaible 89 | return load_and_update_tokenizer_op(cfg_dict) 90 | 91 | 92 | def test_tokenizer(tokenizer_op): 93 | assert isinstance(tokenizer_op, ModularTokenizerOp) 94 | special_tokens = [ 95 | "", 96 | "", 97 | "", 98 | "", 99 | "", 100 | "", 101 | "", 102 | "", 103 | "special_token2", 104 | "yet_another_special_token", 105 | ] 106 | for token in special_tokens: 107 | tokenizer_op.get_token_id(token) # throws assert if fails 108 | never_seen_before = "never_seen_before" 109 | with pytest.raises(AssertionError): 110 | tokenizer_op.get_token_id(never_seen_before) 111 | num_new = tokenizer_op.add_new_special_tokens( 112 | [ 113 | never_seen_before, 114 | ] 115 | ) 116 | assert num_new == 1 117 | tokenizer_op.get_token_id(never_seen_before) 118 | 119 | 120 | @pytest.fixture(scope="session") 121 | def current_train_session_metadata(): 122 | return {} 123 | 124 | 125 | @pytest.fixture(scope="session") 126 | def test_task(cfg_obj, clearml_logger, tokenizer_op): 127 | _task_list = cfg_obj.task( 128 | tokenizer_op=tokenizer_op, 129 | logger=clearml_logger, 130 | ) 131 | return _task_list 132 | 133 | 134 | @pytest.fixture(scope="session") 135 | def pl_data_module(test_task): 136 | """get lightning data module""" 137 | return test_task.data_module() 138 | 139 | 140 | @pytest.fixture(scope="session") 141 | def pl_module(test_task, cfg): 142 | """get lightning module""" 143 | model = Mammal.from_pretrained( 144 | cfg.model.pretrained_kwargs.pretrained_model_name_or_path 145 | ) 146 | _pl_module = module( 147 | model=model, 148 | task=test_task, 149 | **OmegaConf.to_container(cfg.module, resolve=True), 150 | ) 151 | return _pl_module 152 | 153 | 154 | @pytest.mark.skipif( 155 | "ccc" not in socket.gethostname(), 156 | reason="Train consumes too much memory 
for a Travis run.", 157 | ) 158 | def test_evaluate(cfg, pl_data_module, pl_module): 159 | pl_trainer = pl.Trainer(**cfg.trainer) 160 | 161 | pl_data_module.setup("test") 162 | out = pl_trainer.validate( 163 | model=pl_module, 164 | dataloaders=pl_data_module.test_dataloader(), 165 | ) 166 | print(out) 167 | 168 | 169 | @pytest.mark.skipif( 170 | "ccc" not in socket.gethostname(), 171 | reason="Train consumes too much memory for a Travis run.", 172 | ) 173 | def test_train(cfg, pl_data_module, pl_module): 174 | pl_trainer = pl.Trainer(**cfg.trainer) 175 | pl_trainer.fit(model=pl_module, datamodule=pl_data_module) 176 | -------------------------------------------------------------------------------- /mammal/main_finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections.abc import Callable 3 | from functools import partial 4 | 5 | import hydra 6 | import pytorch_lightning as pl 7 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 8 | from fuse.dl.lightning.pl_module import LightningModuleDefault 9 | from fuse.utils import NDict 10 | from omegaconf import DictConfig, OmegaConf 11 | 12 | from mammal.model import Mammal 13 | from mammal.task import MammalTask 14 | 15 | 16 | def save_in_model_dir( 17 | model_dir: str, model: Mammal, tokenizer_op: ModularTokenizerOp 18 | ) -> None: 19 | """ 20 | Save model configuration and tokenizer in model_dir before starting the finetunning session 21 | :param model_dir: location to store the files 22 | :param model: the model to save the configuration for 23 | :param tokenizer_op: the tokenizer to save 24 | """ 25 | 26 | if model_dir is None: 27 | return 28 | 29 | os.makedirs(model_dir, exist_ok=True) 30 | 31 | # save model config 32 | model._save_pretrained(model_dir, save_config_only=True) 33 | 34 | # tokenizer 35 | tokenizer_op.save_pretrained(os.path.join(model_dir, "tokenizer")) 36 | 37 | 38 | def configure_optimizers( 39 | module: LightningModuleDefault, opt_callable: Callable, lr_sch_callable: Callable 40 | ) -> dict: 41 | """ 42 | A callback use by lightning module to set the learning rate scheduler and the optimizer 43 | :param module: the lightning module 44 | :param opt_callable: a callable that creates an optimizer given the model parameters 45 | :param lr_sch_callable: a callable that creates a learning rate scheduler given the optimizer 46 | 47 | """ 48 | opt = opt_callable(module.trainer.model.parameters()) 49 | lr_sch = lr_sch_callable(opt) 50 | return { 51 | "optimizer": opt, 52 | "lr_scheduler": {"scheduler": lr_sch, "interval": "step"}, 53 | } 54 | 55 | 56 | def module( 57 | model: Mammal, 58 | task: MammalTask, 59 | opt_callable: Callable, 60 | lr_sch_callable: Callable, 61 | **kwargs, 62 | ) -> pl.LightningModule: 63 | """ 64 | Create lightning module 65 | :param task: the task to finetune for 66 | :param opt_callable: a callable that creates an optimizer given the model parameters 67 | :param lr_sch_callable: a callable that creates a learning rate scheduler given the optimizer 68 | :param kwargs: additional LightningModuleDefault arguments 69 | """ 70 | optimizers_and_lr_schs_callable = partial( 71 | configure_optimizers, opt_callable=opt_callable, lr_sch_callable=lr_sch_callable 72 | ) 73 | return LightningModuleDefault( 74 | model=model, 75 | losses=task.losses(), 76 | validation_metrics=task.validation_metrics(), 77 | train_metrics=task.train_metrics(), 78 | optimizers_and_lr_schs=optimizers_and_lr_schs_callable, 79 | **kwargs, 80 | ) 81 | 
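# Example of wiring the pieces above manually, without the hydra config (a minimal,
# illustrative sketch - the optimizer, scheduler and their hyperparameters below are
# assumptions for demonstration, not the project's shipped defaults, and `my_task`
# stands for any concrete MammalTask instance):
#
#   from functools import partial
#   import torch
#   from mammal.lr_schedulers import multistep_with_warmup_lr_scheduler
#
#   model = Mammal.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m")
#   pl_module = module(
#       model=model,
#       task=my_task,
#       opt_callable=partial(torch.optim.AdamW, lr=1e-4),
#       lr_sch_callable=partial(
#           multistep_with_warmup_lr_scheduler, num_warmup_steps=100, milestones=[1000]
#       ),
#       model_dir=None,  # extra kwargs are passed through to LightningModuleDefault
#   )
#   # pl_module can then be passed to pytorch_lightning.Trainer.fit together with task.data_module()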
82 | 83 | @hydra.main(version_base="1.2", config_path=None, config_name=None) 84 | def main(cfg: DictConfig): 85 | cfg = hydra.utils.instantiate(cfg) 86 | 87 | # print configuration 88 | NDict(OmegaConf.to_container(cfg, resolve=True)).print_tree(True) 89 | 90 | # connect to clearml - if configured 91 | if "track_clearml" in cfg and cfg["track_clearml"] is not None: 92 | try: 93 | from fuse.dl.lightning.pl_funcs import start_clearml_logger 94 | 95 | clearml_task = start_clearml_logger(**cfg.track_clearml) 96 | except Exception as e: 97 | print("Tracking using clearml failed: continue without tracking") 98 | print(e) 99 | clearml_task = None 100 | 101 | if clearml_task is not None: 102 | clearml_logger = clearml_task.get_logger() 103 | else: 104 | clearml_logger = None # will be None in dist training and rank != 0 105 | else: 106 | clearml_logger = None 107 | 108 | # seed 109 | pl.seed_everything(seed=cfg.seed, workers=True) 110 | 111 | # tokenizer 112 | tokenizer_op = load_and_update_tokenizer_op(cfg) 113 | 114 | # model 115 | model = Mammal.from_pretrained(**cfg.model.pretrained_kwargs) 116 | print(model) 117 | 118 | # initialize task 119 | task: MammalTask = cfg.task(tokenizer_op=tokenizer_op, logger=clearml_logger) 120 | 121 | # lightning data module 122 | pl_data_module = task.data_module() 123 | 124 | # lightning module 125 | pl_module = module( 126 | task=task, 127 | model=model, 128 | **OmegaConf.to_container(cfg.module, resolve=True), 129 | ) 130 | 131 | # create lightning trainer. 132 | pl_trainer = pl.Trainer(**cfg.trainer) 133 | 134 | if cfg.evaluate: 135 | pl_data_module.setup("test") 136 | out = pl_trainer.test( 137 | model=pl_module, 138 | dataloaders=pl_data_module.test_dataloader(), 139 | ) 140 | print(out) 141 | else: 142 | # save model_config and tokenizer in output_dir 143 | save_in_model_dir(cfg.module.model_dir, model, tokenizer_op) 144 | pl_trainer.fit(model=pl_module, datamodule=pl_data_module) 145 | 146 | 147 | def load_and_update_tokenizer_op(cfg): 148 | tokenizer_op = ModularTokenizerOp.from_pretrained(cfg.tokenizer.tokenizer_path) 149 | 150 | if "new_special_tokens" in cfg.tokenizer and len(cfg.tokenizer.new_special_tokens): 151 | num_new_tokens_added = tokenizer_op.add_new_special_tokens( 152 | cfg.tokenizer.new_special_tokens 153 | ) 154 | if num_new_tokens_added: 155 | # TODO: write better message 156 | print( 157 | 10 * "****", 158 | f" Added { num_new_tokens_added} special tokens to the tokenizer ", 159 | 10 * "****", 160 | ) 161 | 162 | return tokenizer_op 163 | 164 | 165 | if __name__ == "__main__": 166 | main() 167 | -------------------------------------------------------------------------------- /mammal/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from mammal.keys import * 4 | 5 | 6 | class LossHead(torch.nn.Module): 7 | """ 8 | Cross Entropy Loss 9 | """ 10 | 11 | def __init__( 12 | self, 13 | *, # prevent positional args 14 | loss_type: str = "ce", 15 | loss_weight: float = 1.0, 16 | sample_token_weights_key: str | None = None, 17 | ignore_index: int | None = -100, 18 | pred_key: str = LOGITS, 19 | labels_key: str = LABELS_TOKENS, 20 | verify_no_weight_at_ignore: bool = True, 21 | ) -> None: 22 | """ 23 | :param verify_no_weight_at_ignore: verifies that weights at positions where the loss is ignored (via ignore_index) equal 0, otherwise the normalization of loss by sum of weights is skewed. 
24 |         """
25 | 
26 |         super().__init__()
27 |         self.loss_type = loss_type
28 |         self.loss_weight = loss_weight
29 |         self.sample_token_weights_key = sample_token_weights_key
30 |         self.ignore_index = ignore_index
31 |         self.labels_key = labels_key
32 |         self.verify_no_weight_at_ignore = verify_no_weight_at_ignore
33 |         self.pred_key = pred_key
34 | 
35 |         if self.loss_type == "ce":
36 |             self.loss_function = torch.nn.CrossEntropyLoss(
37 |                 ignore_index=self.ignore_index, reduction="none"
38 |             )
39 |         else:
40 |             raise NotImplementedError(self.loss_type)
41 | 
42 |     def forward(self, batch_dict: dict) -> torch.Tensor:
43 |         preds = batch_dict[self.pred_key]
44 |         targets = batch_dict[self.labels_key]
45 | 
46 |         if self.sample_token_weights_key:
47 |             if self.sample_token_weights_key in batch_dict:
48 |                 sample_token_weights = batch_dict[self.sample_token_weights_key]
49 |             else:
50 | 
51 |                 sample_token_weights = None
52 |         else:
53 |             sample_token_weights = None
54 | 
55 |         # concatenate the tokens of all samples - the loss is computed at the level of a single token
56 |         n_classes = preds.shape[-1] if len(preds.shape) > 2 else 1
57 |         preds = preds.reshape(-1, n_classes).squeeze(dim=1)
58 |         targets = targets.reshape(-1)
59 |         if sample_token_weights is not None:
60 |             weights = sample_token_weights.reshape(-1)
61 | 
62 |         losses = self.loss_function(preds, targets)
63 | 
64 |         if sample_token_weights is None:
65 |             losses = losses[targets != self.ignore_index]
66 |             loss = losses.mean()
67 |         else:
68 |             losses[weights == 0.0] = (
69 |                 0.0  # set nan loss values to zero wherever the weight is zero
70 |             )
71 |             if self.verify_no_weight_at_ignore:
72 |                 assert (
73 |                     weights[targets == self.ignore_index].abs().sum() == 0
74 |                 ), "You are using non-zero weights at ignore-index positions when calculating the loss - this is most likely skewing your loss evaluation.\nTo turn off this assertion pass verify_no_weight_at_ignore=False to 'mammal.losses.LossHead'"
75 |             loss = (losses * weights).sum() / weights.sum()
76 | 
77 |         return loss * self.loss_weight
78 | 
79 | 
80 | class RMSELoss(torch.nn.Module):
81 |     def __init__(self, reduction: str = "mean", eps: float = 1e-6) -> None:
82 |         super().__init__()
83 |         self.mse = torch.nn.MSELoss(reduction=reduction)
84 |         self.eps = eps
85 | 
86 |     def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
87 |         loss = (self.mse(inputs, targets) + self.eps) ** 0.5
88 |         return loss
89 | 
90 | 
91 | class ScalarsPredictionsLoss(torch.nn.Module):
92 |     """
93 |     Scalars prediction loss.
Might be MSE (default), RMSE or MAE 94 | """ 95 | 96 | def __init__( 97 | self, 98 | *, # prevent positional args 99 | loss_type: str, 100 | loss_weight: float | None = None, 101 | pred_key: str = SCALARS_PREDICTION_HEAD_LOGITS, 102 | labels_scalars_values_key: str = LABELS_SCALARS_VALUES, 103 | labels_scalars_valid_mask_key: str = LABELS_SCALARS_VALID_MASK, 104 | ) -> None: 105 | 106 | super().__init__() 107 | self.loss_type = loss_type 108 | self.loss_weight = 1.0 if loss_weight is None else loss_weight 109 | self.pred_key = pred_key 110 | self.labels_scalars_values_key = labels_scalars_values_key 111 | self.labels_scalars_valid_mask_key = labels_scalars_valid_mask_key 112 | 113 | if self.loss_type == "mse": 114 | self.loss_function = torch.nn.MSELoss(reduction="none") 115 | elif self.loss_type == "mae": 116 | self.loss_function = torch.nn.L1Loss(reduction="none") 117 | elif self.loss_type == "rmse": 118 | self.loss_function = RMSELoss(reduction="none") 119 | else: 120 | raise NotImplementedError(self._loss) 121 | 122 | def forward(self, batch_dict: dict) -> torch.Tensor: 123 | if ( 124 | (not _legit(batch_dict, self.pred_key)) 125 | or (not _legit(batch_dict, self.labels_scalars_values_key)) 126 | or (not _legit(batch_dict, self.labels_scalars_valid_mask_key)) 127 | ): 128 | return 0.0 129 | 130 | preds = batch_dict[self.pred_key] 131 | targets = batch_dict[self.labels_scalars_values_key] 132 | valid = batch_dict[self.labels_scalars_valid_mask_key] 133 | 134 | assert ( 135 | preds.shape == targets.shape 136 | ), f"preds shape ({preds.shape}) is expected to be the same as targets shape ({targets.shape})" 137 | assert ( 138 | targets.shape == valid.shape 139 | ), f"targets shape ({valid.shape}) is expected to be the same as targets shape ({valid.shape})" 140 | 141 | # take only the valid elements 142 | preds = preds[valid.bool()] 143 | targets = targets[valid.bool()] 144 | 145 | curr_loss = self.loss_function(preds, targets) 146 | 147 | return curr_loss.mean() * self.loss_weight 148 | 149 | 150 | def _legit(batch_dict: dict, key: str) -> bool: 151 | if key not in batch_dict: 152 | return False 153 | if batch_dict[key] is None: 154 | return False 155 | if len(batch_dict[key]) == 0: 156 | return False 157 | return True 158 | -------------------------------------------------------------------------------- /mammal/examples/scrna_cell_type/scRNA_infer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from pprint import pprint 3 | 4 | import anndata 5 | import click 6 | import numpy as np 7 | import torch 8 | from anndata_op import OpReadAnnData 9 | from fuse.data.datasets.dataset_default import DatasetDefault 10 | from fuse.data.pipelines.pipeline_default import PipelineDefault 11 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 12 | from fuse.data.utils.collates import CollateDefault 13 | from pl_data_module import CellTypeDataModule 14 | from task import CellTypeTask 15 | from torch.utils.data.dataset import Dataset 16 | 17 | from mammal.keys import ( 18 | CLS_PRED, 19 | SCORES, 20 | ) 21 | from mammal.model import Mammal 22 | 23 | 24 | @click.command() 25 | @click.argument("task_name", default="cell_type") 26 | @click.option( 27 | "--model-path", 28 | "-m", 29 | help="Specify the model dir.", 30 | ) 31 | @click.option( 32 | "--tokenizer_path", 33 | default=None, 34 | type=str, 35 | help="Specify the tokenizer path.", 36 | ) 37 | @click.option( 38 | "--h5ad-file-path", 39 | "-i", 40 | type=str, 41 | 
help="Specify the A5HD (AnnData) input file.", 42 | default="data/Zheng_68k_processed.h5ad", 43 | ) 44 | @click.option("--sample_id", "-s", type=int, default=0) 45 | @click.option( 46 | "--verbose", 47 | "-v", 48 | count=True, 49 | default=1, 50 | ) 51 | @click.option( 52 | "--test-h5ad-file", 53 | "-T", 54 | is_flag=True, 55 | default=False, 56 | ) 57 | @click.option( 58 | "--device", default="cpu", help="Specify the device to use (default: 'cpu')." 59 | ) 60 | def main( 61 | task_name: str, 62 | h5ad_file_path: str, 63 | sample_id: int, 64 | model_path: str, 65 | tokenizer_path: str, 66 | verbose: int, 67 | test_h5ad_file: bool, 68 | device: str, 69 | ): 70 | try: 71 | anndata_object = anndata.read_h5ad(h5ad_file_path) 72 | except Exception as e: 73 | raise ValueError( 74 | f"Failed to read {h5ad_file_path} as a serialized AnnData " 75 | ) from e 76 | 77 | tokenizer_op, nn_model = get_tokenizer_and_model(tokenizer_path, model_path, device) 78 | 79 | # convert to MAMMAL style 80 | 81 | dynamic_pipeline = PipelineDefault( 82 | "cell_type", 83 | [ 84 | (OpReadAnnData(data=anndata_object), {"prefix": "scrna"}), 85 | ], 86 | ) 87 | 88 | data_source = DatasetDefault( 89 | sample_ids=anndata_object.shape[0], dynamic_pipeline=dynamic_pipeline 90 | ) 91 | data_source.create() 92 | 93 | sample_dict = create_sample_dict(task_name, data_source, sample_id, tokenizer_op) 94 | if test_h5ad_file or verbose > 2: 95 | print("/n/n sample dict:/n") 96 | pprint(dict(sample_dict)) 97 | if test_h5ad_file and "data.label" not in sample_dict: 98 | raise ValueError( 99 | "sample_dict['data.label'] is missing - data can not be used for training" 100 | ) 101 | unique_values, counts = np.unique( 102 | sample_dict["scrna.scrna"].data, return_counts=True 103 | ) 104 | n_values = len(unique_values) 105 | if n_values > 11: 106 | print("-" * 60) 107 | print( 108 | f"The data has {n_values} different expression bins which is typical for data that did not pass though the preprocessing." 
109 | ) 110 | for value, count in zip(unique_values, counts): 111 | print(f"Value {value} appears {count} times") 112 | print("\n" * 4) 113 | 114 | batch_dict = CollateDefault(skip_keys=CellTypeDataModule.skip_keys)([sample_dict]) 115 | 116 | if test_h5ad_file or verbose > 1: 117 | key_to_print = ["data.query.encoder_input", "data.encoder_input_token_ids"] 118 | for key in key_to_print: 119 | print(f"{key}: {batch_dict[key]}") 120 | 121 | if test_h5ad_file or verbose > 2: 122 | n_zero = torch.sum(batch_dict["data.encoder_input_token_ids"] == 0).item() 123 | total_length = torch.sum(batch_dict["data.encoder_input_attention_mask"]).item() 124 | print( 125 | f"{n_zero} unknown from {total_length} tokens ({round((n_zero*100)/total_length,2)}%)" 126 | ) 127 | 128 | # run the model 129 | batch_dict = get_predictions(nn_model, batch_dict=batch_dict) 130 | ans = process_model_output(tokenizer_op, batch_dict) 131 | ans = { 132 | k: v.detach().numpy() if isinstance(v, torch.Tensor) else v 133 | for k, v in ans.items() 134 | } 135 | if test_h5ad_file or verbose: 136 | print(ans) 137 | 138 | 139 | def process_model_output(tokenizer_op, batch_dict): 140 | return CellTypeTask.process_model_output( 141 | tokenizer_op=tokenizer_op, 142 | decoder_output=batch_dict[CLS_PRED][0], 143 | decoder_output_scores=batch_dict[SCORES][0], 144 | ) 145 | 146 | 147 | def get_tokenizer_and_model(tokenizer_path, model_path, device): 148 | if tokenizer_path is None: 149 | tokenizer_path = Path(model_path) 150 | tokenizer_op = ModularTokenizerOp.from_pretrained(tokenizer_path) 151 | nn_model = Mammal.from_pretrained( 152 | pretrained_model_name_or_path=model_path, 153 | ) 154 | nn_model.eval() 155 | nn_model.to(device=device) 156 | 157 | return tokenizer_op, nn_model 158 | 159 | 160 | def create_sample_dict( 161 | task_name: str, data_source: Dataset, sample_id: int, tokenizer_op 162 | ): 163 | sequence_key = "scrna.scrna" 164 | # Create and load sample 165 | sample_dict = data_source[sample_id] 166 | 167 | sample_dict = CellTypeTask.data_preprocessing( 168 | sample_dict, 169 | sequence_key=sequence_key, 170 | input_max_seq_length=1260, 171 | encoder_input_max_seq_len=1260, 172 | labels_max_seq_len=4, 173 | tokenizer_op=tokenizer_op, 174 | ) 175 | 176 | return sample_dict 177 | 178 | 179 | def get_predictions(model, batch_dict): 180 | return model.generate( 181 | batch_dict, 182 | output_scores=True, 183 | return_dict_in_generate=True, 184 | max_new_tokens=5, 185 | ) 186 | 187 | 188 | def print_result(batch_dict, scalars_preds_processed_key): 189 | 190 | value = batch_dict[scalars_preds_processed_key] 191 | ans = {"scalar_result": value} 192 | 193 | # Print prediction 194 | 195 | batch_dict["scalar_result"] = value 196 | print(f"estimated value: {ans}") 197 | 198 | 199 | def process_sample(tokenizer_op, nn_model, sample_dict): 200 | # running in generate mode 201 | batch_dict = nn_model.generate( 202 | [sample_dict], 203 | output_scores=True, 204 | return_dict_in_generate=True, 205 | max_new_tokens=5, 206 | ) 207 | 208 | return batch_dict 209 | 210 | 211 | if __name__ == "__main__": 212 | main() 213 | -------------------------------------------------------------------------------- /mammal/examples/protein_solubility/pl_data_module.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from collections.abc import Callable 4 | 5 | import pandas as pd 6 | import pytorch_lightning as pl 7 | import wget 8 | from fuse.data.datasets.dataset_default import 
DatasetDefault 9 | from fuse.data.ops.ops_read import OpReadDataframe 10 | from fuse.data.pipelines.pipeline_default import PipelineDefault 11 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 12 | from fuse.data.utils.collates import CollateDefault 13 | from torch.utils.data.dataloader import DataLoader 14 | 15 | from mammal.keys import * # noqa 16 | 17 | 18 | class ProteinSolubilityDataModule(pl.LightningDataModule): 19 | def __init__( 20 | self, 21 | *, 22 | data_path: str, 23 | batch_size: int, 24 | tokenizer_op: ModularTokenizerOp, 25 | train_dl_kwargs: dict, 26 | valid_dl_kwargs: dict, 27 | seed: int, 28 | data_preprocessing: Callable, 29 | protein_max_seq_length: int, 30 | encoder_input_max_seq_len: int, 31 | labels_max_seq_len: int, 32 | ) -> None: 33 | """_summary_ 34 | Args: 35 | data_path (str): path to the raw data, if not exist, will download the data to the given path. 36 | batch_size (int): batch size 37 | tokenizer_op (ModularTokenizerOp): tokenizer op 38 | encoder_inputs_max_seq_len: max tokenizer sequence length for the encoder inputs, 39 | labels_max_seq_len: max tokenizer sequence length for the labels, 40 | train_dl_kwargs (dict): train dataloader constructor parameters 41 | valid_dl_kwargs (dict): validation dataloader constructor parameters 42 | seed (int): random seed 43 | """ 44 | super().__init__() 45 | self.data_path = data_path 46 | self.tokenizer_op = tokenizer_op 47 | self.protein_max_seq_length = protein_max_seq_length 48 | self.encoder_input_max_seq_len = encoder_input_max_seq_len 49 | self.labels_max_seq_len = labels_max_seq_len 50 | self.batch_size = batch_size 51 | self.train_dl_kwargs = train_dl_kwargs 52 | self.valid_dl_kwargs = valid_dl_kwargs 53 | self.seed = seed 54 | self.data_preprocessing = data_preprocessing 55 | 56 | self.pad_token_id = self.tokenizer_op.get_token_id("") 57 | 58 | def setup(self, stage: str) -> None: 59 | self.ds_dict = load_datasets(self.data_path) 60 | 61 | task_pipeline = [ 62 | ( 63 | # Prepare the input string(s) in modular tokenizer input format 64 | self.data_preprocessing, 65 | dict( 66 | protein_sequence_key="data.protein", 67 | solubility_label_key="data.label", 68 | tokenizer_op=self.tokenizer_op, 69 | protein_max_seq_length=self.protein_max_seq_length, 70 | encoder_input_max_seq_len=self.encoder_input_max_seq_len, 71 | labels_max_seq_len=self.labels_max_seq_len, 72 | ), 73 | ), 74 | ] 75 | 76 | for ds in self.ds_dict.values(): 77 | ds.dynamic_pipeline.extend(task_pipeline) 78 | 79 | def train_dataloader(self) -> DataLoader: 80 | train_loader = DataLoader( 81 | dataset=self.ds_dict["train"], 82 | batch_size=self.batch_size, 83 | collate_fn=CollateDefault(), 84 | shuffle=True, 85 | **self.train_dl_kwargs, 86 | ) 87 | return train_loader 88 | 89 | def val_dataloader(self) -> DataLoader: 90 | val_loader = DataLoader( 91 | self.ds_dict["val"], 92 | batch_size=self.batch_size, 93 | collate_fn=CollateDefault(), 94 | **self.valid_dl_kwargs, 95 | ) 96 | 97 | return val_loader 98 | 99 | def test_dataloader(self) -> DataLoader: 100 | test_loader = DataLoader( 101 | self.ds_dict["test"], 102 | batch_size=self.batch_size, 103 | collate_fn=CollateDefault(), 104 | **self.valid_dl_kwargs, 105 | ) 106 | 107 | return test_loader 108 | 109 | def predict_dataloader(self) -> DataLoader: 110 | return self.test_dataloader() 111 | 112 | 113 | _SOLUBILITY_URL = "https://zenodo.org/api/records/1162886/files-archive" 114 | 115 | 116 | def load_datasets(data_path: str) -> dict[str, DatasetDefault]: 117 | """ 118 | 
Automatically downloads the data and create dataset iterator for "train", "val" and "test". 119 | paper: https://academic.oup.com/bioinformatics/article/34/15/2605/4938490 120 | Data retrieved from: https://zenodo.org/records/1162886 121 | The benchmark requires classifying protein sequences into binary labels - Soluble or Insoluble (1 or 0). 122 | :param data_path: path to a directory to store the raw data 123 | :return: dictionary that maps fold name "train", "val" and "test" to a dataset iterator 124 | """ 125 | 126 | if not os.path.exists(data_path): 127 | os.makedirs(data_path) 128 | 129 | raw_data_path = os.path.join(data_path, "sameerkhurana10-DSOL_rv0.2-20562ad/data") 130 | if not os.path.exists(raw_data_path): 131 | wget.download(_SOLUBILITY_URL, data_path) 132 | file_path = os.path.join(data_path, "1162886.zip") 133 | shutil.unpack_archive(file_path, extract_dir=data_path) 134 | inner_file_path = os.path.join( 135 | data_path, "sameerkhurana10", "DSOL_rv0.2-v0.3.zip" 136 | ) 137 | shutil.unpack_archive(inner_file_path, extract_dir=data_path) 138 | assert os.path.exists( 139 | raw_data_path 140 | ), f"Error: download complete but {raw_data_path} doesn't exist" 141 | 142 | # read files 143 | df_dict = {} 144 | for set_name in ["train", "val", "test"]: 145 | input_df = pd.read_csv( 146 | os.path.join(raw_data_path, f"{set_name}_src"), names=["data.protein"] 147 | ) 148 | labels_df = pd.read_csv( 149 | os.path.join(raw_data_path, f"{set_name}_tgt"), names=["data.label"] 150 | ) 151 | df_dict[set_name] = (input_df, labels_df) 152 | 153 | ds_dict = {} 154 | for set_name in ["train", "val", "test"]: 155 | input_df, labels_df = df_dict[set_name] 156 | size = len(labels_df) 157 | print(f"{set_name} set size is {size}") 158 | dynamic_pipeline = PipelineDefault( 159 | "solubility", 160 | [ 161 | (OpReadDataframe(input_df, key_column=None), dict()), 162 | (OpReadDataframe(labels_df, key_column=None), dict()), 163 | ], 164 | ) 165 | 166 | ds = DatasetDefault(sample_ids=size, dynamic_pipeline=dynamic_pipeline) 167 | ds.create() 168 | ds_dict[set_name] = ds 169 | 170 | return ds_dict 171 | 172 | 173 | if __name__ == "__main__": 174 | load_datasets("data") 175 | -------------------------------------------------------------------------------- /mammal/examples/carcinogenicity/task.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 7 | 8 | from mammal.examples.carcinogenicity.pl_data_module import CarcinogenicityDataModule 9 | from mammal.keys import ( 10 | CLS_PRED, 11 | DECODER_INPUTS_ATTENTION_MASK, 12 | DECODER_INPUTS_STR, 13 | DECODER_INPUTS_TOKENS, 14 | ENCODER_INPUTS_ATTENTION_MASK, 15 | ENCODER_INPUTS_STR, 16 | ENCODER_INPUTS_TOKENS, 17 | LABELS_ATTENTION_MASK, 18 | LABELS_STR, 19 | LABELS_TOKENS, 20 | SCORES, 21 | ) 22 | from mammal.metrics import classification_metrics 23 | from mammal.task import ( 24 | MammalTask, 25 | MetricBase, 26 | ) 27 | 28 | 29 | class CarcinogenicityTask(MammalTask): 30 | def __init__( 31 | self, 32 | *, 33 | tokenizer_op: ModularTokenizerOp, 34 | data_module_kwargs: dict, 35 | logger: Any | None = None, 36 | ) -> None: 37 | super().__init__( 38 | name="carcinogenicity", 39 | logger=logger, 40 | tokenizer_op=tokenizer_op, 41 | ) 42 | self._data_module_kwargs = data_module_kwargs 43 | 44 | self.preds_key = CLS_PRED 45 | self.scores_key = SCORES 46 | 
self.labels_key = LABELS_TOKENS 47 | 48 | def data_module(self) -> pl.LightningDataModule: 49 | return CarcinogenicityDataModule( 50 | tokenizer_op=self._tokenizer_op, 51 | data_preprocessing=self.data_preprocessing, 52 | **self._data_module_kwargs, 53 | ) 54 | 55 | def train_metrics(self) -> dict[str, MetricBase]: 56 | metrics = super().train_metrics() 57 | metrics.update( 58 | classification_metrics( 59 | self.name(), 60 | class_position=1, 61 | tokenizer_op=self._tokenizer_op, 62 | class_tokens=["<0>", "<1>"], 63 | ) 64 | ) 65 | 66 | return metrics 67 | 68 | def validation_metrics(self) -> dict[str, MetricBase]: 69 | validation_metrics = super().validation_metrics() 70 | validation_metrics.update( 71 | classification_metrics( 72 | self.name(), 73 | class_position=1, 74 | tokenizer_op=self._tokenizer_op, 75 | class_tokens=["<0>", "<1>"], 76 | ) 77 | ) 78 | return validation_metrics 79 | 80 | @staticmethod 81 | def data_preprocessing( 82 | sample_dict: dict, 83 | *, 84 | sequence_key: str, 85 | label_key: int | None = None, 86 | drug_max_seq_length: int = 1250, 87 | encoder_input_max_seq_len: int | None = 1260, 88 | labels_max_seq_len: int | None = 4, 89 | tokenizer_op: ModularTokenizerOp, 90 | device: str | torch.device = "cpu", 91 | ) -> dict: 92 | drug_sequence = sample_dict[sequence_key] 93 | label = sample_dict.get(label_key, None) 94 | 95 | sample_dict[ENCODER_INPUTS_STR] = ( 96 | f"<@TOKENIZER-TYPE=SMILES><@TOKENIZER-TYPE=SMILES@MAX-LEN={drug_max_seq_length}>{drug_sequence}" 97 | ) 98 | tokenizer_op( 99 | sample_dict=sample_dict, 100 | key_in=ENCODER_INPUTS_STR, 101 | key_out_tokens_ids=ENCODER_INPUTS_TOKENS, 102 | key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK, 103 | max_seq_len=encoder_input_max_seq_len, 104 | ) 105 | sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor( 106 | sample_dict[ENCODER_INPUTS_TOKENS], device=device 107 | ) 108 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor( 109 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK], device=device 110 | ) 111 | 112 | if label is not None: 113 | pad_id = tokenizer_op.get_token_id("") 114 | ignore_token_value = -100 115 | sample_dict[LABELS_STR] = ( 116 | f"<@TOKENIZER-TYPE=SMILES><{label}>" 117 | ) 118 | tokenizer_op( 119 | sample_dict=sample_dict, 120 | key_in=LABELS_STR, 121 | key_out_tokens_ids=LABELS_TOKENS, 122 | key_out_attention_mask=LABELS_ATTENTION_MASK, 123 | max_seq_len=labels_max_seq_len, 124 | ) 125 | sample_dict[LABELS_TOKENS] = torch.tensor( 126 | sample_dict[LABELS_TOKENS], device=device 127 | ) 128 | sample_dict[LABELS_ATTENTION_MASK] = torch.tensor( 129 | sample_dict[LABELS_ATTENTION_MASK], device=device 130 | ) 131 | # replace pad_id with -100 to 132 | sample_dict[LABELS_TOKENS][ 133 | (sample_dict[LABELS_TOKENS][..., None] == torch.tensor(pad_id)) 134 | .any(-1) 135 | .nonzero() 136 | ] = ignore_token_value 137 | 138 | sample_dict[DECODER_INPUTS_STR] = ( 139 | f"<@TOKENIZER-TYPE=SMILES><{label}>" 140 | ) 141 | tokenizer_op( 142 | sample_dict=sample_dict, 143 | key_in=DECODER_INPUTS_STR, 144 | key_out_tokens_ids=DECODER_INPUTS_TOKENS, 145 | key_out_attention_mask=DECODER_INPUTS_ATTENTION_MASK, 146 | max_seq_len=labels_max_seq_len, 147 | ) 148 | sample_dict[DECODER_INPUTS_TOKENS] = torch.tensor( 149 | sample_dict[DECODER_INPUTS_TOKENS], device=device 150 | ) 151 | sample_dict[DECODER_INPUTS_ATTENTION_MASK] = torch.tensor( 152 | sample_dict[DECODER_INPUTS_ATTENTION_MASK], device=device 153 | ) 154 | 155 | return sample_dict 156 | 157 | @staticmethod 158 | def process_model_output( 159 | 
tokenizer_op: ModularTokenizerOp, 160 | decoder_output: np.ndarray, 161 | decoder_output_scores: np.ndarray, 162 | ) -> dict: 163 | negative_token_id = tokenizer_op.get_token_id("<0>") 164 | positive_token_id = tokenizer_op.get_token_id("<1>") 165 | label_id_to_int = { 166 | negative_token_id: 0, 167 | positive_token_id: 1, 168 | } 169 | classification_position = 1 170 | 171 | if decoder_output_scores is not None: 172 | not_normalized_score = decoder_output_scores[ 173 | classification_position, positive_token_id 174 | ] 175 | normalized_score = not_normalized_score / ( 176 | not_normalized_score 177 | + decoder_output_scores[classification_position, negative_token_id] 178 | + 1e-10 179 | ) 180 | 181 | ans = dict( 182 | pred=label_id_to_int.get(int(decoder_output[classification_position]), -1), 183 | not_normalized_scores=not_normalized_score, 184 | normalized_scores=normalized_score, 185 | ) 186 | 187 | return ans 188 | -------------------------------------------------------------------------------- /mammal/examples/scrna_cell_type/data/Zheng68k_to_anndata.py: -------------------------------------------------------------------------------- 1 | # script to pack zheng68k data downloaded from x10genomics into an AnnData/h5ad file. 2 | 3 | # This example follows [Zheng]() for identification of white blood cell types from single cell RNA expression data. 4 | 5 | 6 | # ## Outline of process 7 | # The finetune process requires the input data to be in scRNA-sec AnnData format (saved as an h5ad file) with cell types 8 | # as labels. If the data is not packed as AnnData, as is the case when downloading from the 9 | # 10xGenomics site as explained below, it need to first be packed into one and saved to the disk. 10 | # cell types, if present, should be stored in the `adata.obs['cell_type']` observation. 11 | 12 | # This script assumes that it is run from the data directory, 13 | # which is typically `bmfm-mammal-release/mammal/examples/scrna_cell_type/data` 14 | 15 | import os 16 | import subprocess 17 | from pathlib import Path 18 | 19 | import anndata 20 | import click 21 | import pandas as pd 22 | from scipy.io import mmread 23 | 24 | ### Obtaining the source data: 25 | # The main data is available online, for example in the [10xGenomics](https://www.10xgenomics.com/) cite. 26 | # The labels are based on the data in [LINK](https://www.10xgenomics.com/datasets/fresh-68-k-pbm-cs-donor-a-1-standard-1-1-0) 27 | # From this site download the file `fresh_68k_pbmc_donor_a_filtered_gene_bc_matrices.tar.gz` and place it in the data directory. 28 | 29 | 30 | DEFULT_LABELS_FILE = ( 31 | "zheng17_bulk_lables.txt" # yes, the original file is named this way. 32 | ) 33 | GZIP_FILE_NAME = "fresh_68k_pbmc_donor_a_filtered_gene_bc_matrices.tar.gz" 34 | RAW_H5AD_FILE = "Zheng_68k.h5ad" 35 | RAW_DATA_SUBDIR = Path("filtered_matrices_mex/hg19") 36 | 37 | 38 | @click.command() 39 | @click.option( 40 | "--output-h5ad-file", 41 | "-o", 42 | default=None, 43 | help="name of output H5AD file. 
default is adding '_preprocessed' to the input file", 44 | ) 45 | @click.option( 46 | "--data-dir", 47 | help="dirname for the downloaded and constructed data files", 48 | default=".", 49 | ) 50 | @click.option( 51 | "--labels_file", 52 | "-l", 53 | default=DEFULT_LABELS_FILE, 54 | ) 55 | @click.option( 56 | "--labels_key", 57 | "-k", 58 | default="cell_type", 59 | help="key to use for the cell type labels in the AnnData observations.", 60 | ) 61 | @click.option( 62 | "--verbose", "-v", is_flag=True, default=False, help="be verbose (default: off)." 63 | ) 64 | def main( 65 | output_h5ad_file: str, 66 | data_dir: os.PathLike, 67 | labels_file: os.PathLike, 68 | labels_key: str, 69 | verbose: bool = False, 70 | ): 71 | # all work is done in the data dir 72 | os.chdir(data_dir) 73 | barcode_file = RAW_DATA_SUBDIR / "barcodes.tsv" 74 | genes_file = RAW_DATA_SUBDIR / "genes.tsv" 75 | matrix_file = RAW_DATA_SUBDIR / "matrix.mtx" 76 | 77 | if not RAW_DATA_SUBDIR.exists(): 78 | # check if the file exists 79 | if not os.path.exists(GZIP_FILE_NAME): 80 | print( 81 | f"please download the file {GZIP_FILE_NAME} from https://www.10xgenomics.com/datasets/fresh-68-k-pbm-cs-donor-a-1-standard-1-1-0 into this data directory and then run this script again from that directory" 82 | ) 83 | raise FileNotFoundError( 84 | f"Both the {GZIP_FILE_NAME} and the raw data directory {RAW_DATA_SUBDIR} extracted from it not found under the current directory" 85 | ) 86 | else: 87 | if verbose: 88 | print(f"extracting files from {GZIP_FILE_NAME}") 89 | subprocess.run(["tar", "xvzf", GZIP_FILE_NAME], check=True) 90 | 91 | if labels_file is not None: # if we do not want to add labels 92 | if not os.path.exists(labels_file): 93 | if ( 94 | labels_file == DEFULT_LABELS_FILE 95 | ): # special case - we can download this file if needed. 96 | labels_file_url = "https://raw.githubusercontent.com/scverse/scanpy_usage/refs/heads/master/170503_zheng17/data/zheng17_bulk_lables.txt" 97 | if verbose: 98 | print(f"Missing cell-type-labels file {labels_file}") 99 | print(f"downloading it from {labels_file_url}") 100 | subprocess.run(["wget", labels_file_url], check=True) 101 | if verbose: 102 | print("downloaded") 103 | else: 104 | raise FileNotFoundError("please supply labels file") 105 | raw_adata = create_anndata_from_csv( 106 | barcode_file, 107 | genes_file, 108 | matrix_file, 109 | labels_file=labels_file, 110 | labels_key=labels_key, 111 | ) 112 | 113 | # Save result anndata object to disk 114 | raw_adata.write_h5ad(output_h5ad_file) 115 | 116 | if verbose: 117 | print(f"the raw {output_h5ad_file} is ready") 118 | print( 119 | "to use this h5ad file please filter, normalize and bin the counts. You can use process_h5ad_data.py to do this." 120 | ) 121 | 122 | # This is a standard AnnData file, and can be replaced by your own data. 123 | 124 | 125 | def create_anndata_from_csv( 126 | barcode_file, 127 | genes_file, 128 | matrix_file, 129 | labels_file, 130 | labels_key, 131 | ) -> anndata.AnnData: 132 | """Construct an h5ad file (an anndata object dump) from its components. 133 | 134 | 135 | Args: 136 | raw_h5ad_file (os.PathLike): name of file to save constructed AnnData into 137 | barcode_file (os.PathLike): this file holds the mapping from the sample index to the cell identifier 138 | genes_file (os.PathLike): Mapping from feature index to gene name 139 | matrix_file (os.PathLike): The actual data, is (sparse) matrix form. 140 | labels_file (os.PathLike, optional): File containing the cell types for each file. 
141 |         labels_key (str): name of the observation under which to place the labels in the AnnData object.
142 |         verbose (bool, optional): verbose output. Defaults to False.
143 | 
144 |     Returns:
145 |         anndata.AnnData: the generated anndata object
146 |     """
147 |     mmx = mmread(matrix_file)
148 | 
149 |     #### Create an AnnData object wrapping the read data
150 | 
151 |     # Notice that this code transposes the data to the correct orientation
152 | 
153 |     anndata_object = anndata.AnnData(X=mmx.transpose().tocsr())
154 | 
155 |     # Cell identifiers
156 |     observation_names = pd.read_csv(barcode_file, header=None, sep="\t")
157 |     # names of genes
158 |     genes = pd.read_csv(genes_file, header=None, sep="\t")
159 | 
160 |     # use the gene names as variable names in the AnnData object
161 |     anndata_object.var_names = genes[1]
162 | 
163 |     # use the cell barcodes as names for the samples
164 |     anndata_object.obs_names = observation_names[0]
165 | 
166 |     if labels_file is not None:
167 |         # cell types (this is actually just one column)
168 |         cell_type_labels = pd.read_csv(labels_file, header=None, sep="\t")
169 |         # use cell types as labels for the samples
170 |         anndata_object.obs[labels_key] = cell_type_labels.squeeze().to_numpy()
171 | 
172 |     return anndata_object
173 | 
174 | 
175 | if __name__ == "__main__":
176 |     main()
177 | 
--------------------------------------------------------------------------------
/mammal_mcp/README.md:
--------------------------------------------------------------------------------
1 | # Mammal MCP server - making MAMMAL tasks accessible to AI Agents
2 | 
3 | A service that provides the `ibm/biomed.omics.bl.sm.ma-ted-458m` model tasks to AI Agents.
4 | 
5 | ## Overview
6 | 
7 | MAMMAL (ibm/biomed.omics.bl.sm.ma-ted-458m) is a biomedical foundation model (BMFM) trained by IBM; the details can be found here -> https://github.com/BiomedSciAI/biomed-multi-alignment.
8 | 
9 | This repository is a FastMCP server which creates entrypoints for AI Agents to run inference for the tasks currently supported by MAMMAL.
10 | 
11 | ## Getting started
12 | 
13 | Change to the `mammal_mcp` directory:
14 | 
15 | ```sh
16 | cd mammal_mcp
17 | ```
18 | 
19 | Create the environment file:
20 | 
21 | ```sh
22 | cp .env.example .env
23 | ```
24 | 
25 | The environment variables in this file control which modalities will be available to the agent.
26 | 
27 | 
28 | The first time you run the server you need to download the models, so set every task that you will subsequently use to `true` in the `.env` file and run:
29 | 
30 | ```sh
31 | uv run python -m server
32 | ```
33 | 
34 | Then wait for all the models to be downloaded before quitting the server.
35 | 
36 | ## Running the server using STDIO (default)
37 | 
38 | ### Integration into Claude Desktop
39 | 
40 | **If using Claude as your MCP client, DO NOT add any confidential or personal data into the system.**
41 | 
42 | One of the easiest ways to experiment with the tools provided by mammal-mcp is to use Claude Desktop.
43 | 
44 | For that, update your Claude Desktop config file (located at `~/Library/Application Support/Claude/claude_desktop_config.json`) with the JSON below:
45 | 
46 | ```json
47 | {
48 |   "mcpServers": {
49 |     "mammal": {
50 |       "command": "",
51 |       "args": [
52 |         "--directory",
53 |         "",
54 |         "run",
55 |         "server.py"
56 |       ]
57 |     }
58 |   }
59 | }
60 | 
61 | 
62 | ```
63 | 
64 | **Change both placeholders in this JSON: the `command` value (the full path to `uv`) and the `--directory` value (the full path to the `mammal_mcp` directory).**
65 | 
66 | Quit and re-start Claude Desktop, then use an appropriate prompt (see the examples below).
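### Optional: sanity-checking a task directly from Python

The MCP tools wrap the same task code that ships in the main `mammal` package, so you can sanity-check a task outside of any agent. The sketch below is illustrative only: it assumes the fine-tuned solubility checkpoint also serves its tokenizer from the same path, and it reuses the example protein sequence from the drug-target prompt further down.

```python
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp

from mammal.examples.protein_solubility.task import ProteinSolubilityTask
from mammal.keys import CLS_PRED, SCORES
from mammal.model import Mammal

path = "ibm-research/biomed.omics.bl.sm.ma-ted-458m.protein_solubility"
model = Mammal.from_pretrained(pretrained_model_name_or_path=path).eval()
tokenizer_op = ModularTokenizerOp.from_pretrained(path)

# Build the model query for a single protein sequence
sample_dict = ProteinSolubilityTask.data_preprocessing(
    {"data.protein": "NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC"},
    protein_sequence_key="data.protein",
    tokenizer_op=tokenizer_op,
)

# Run generation and decode the <0>/<1> solubility prediction
batch_dict = model.generate(
    [sample_dict], output_scores=True, return_dict_in_generate=True, max_new_tokens=5
)
print(
    ProteinSolubilityTask.process_model_output(
        tokenizer_op=tokenizer_op,
        decoder_output=batch_dict[CLS_PRED][0],
        decoder_output_scores=batch_dict[SCORES][0],
    )
)
```

The other tasks follow the same pattern, each with its own `task.py` preprocessing and fine-tuned checkpoint.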
67 | 
68 | ### Integration into [MCPHost](https://github.com/mark3labs/mcphost) (using [Ollama](https://ollama.com/))
69 | 
70 | MCPHost is a host application that enables LLMs to interact with external tools through MCP.
71 | It supports Claude 3.5 Sonnet and Ollama models; we'll use Ollama here as an example of running a local LLM.
72 | 
73 | To install MCPHost, first install [go](https://go.dev/) if you don't already have it.
74 | Then:
75 | 
76 | ```sh
77 | go install github.com/mark3labs/mcphost@latest
78 | ```
79 | 
80 | mcphost will be installed into go's bin folder. You'll need to add that folder to your PATH to launch it:
81 | 
82 | ```sh
83 | PATH="$(go env GOPATH)/bin:$PATH"
84 | ```
85 | 
86 | Install [Ollama](https://ollama.com/) on your desktop, and download the model you'd like to use. For example:
87 | 
88 | ```sh
89 | ollama run qwen3
90 | ```
91 | 
92 | Prepare a configuration JSON file (for example `mammal_mcp.json`) as follows:
93 | 
94 | ```json
95 | {
96 |   "mcpServers": {
97 |     "mammal": {
98 |       "command": "",
99 |       "args": [
100 |         "--directory",
101 |         "",
102 |         "run",
103 |         "server.py"
104 |       ]
105 |     }
106 |   }
107 | }
108 | ```
109 | 
110 | Or, if you already have a configuration JSON file for MCPHost, add the "mammal" entry above as a member of "mcpServers".
111 | 
112 | Finally, launch MCPHost with the LLM you'd like to use and the configuration file you prepared above. For example:
113 | 
114 | ```sh
115 | mcphost -m ollama:qwen3 --config 
116 | ```
117 | 
118 | ## Pre-trained task usage (in Claude)
119 | 
120 | For whichever task you want to use (we recommend using no more than 2 models at a time), first set that task to `true` in your `.env` file and then run:
121 | 
122 | ```sh
123 | uv run python -m server
124 | ```
125 | 
126 | This will pre-download the required models for your task. Once complete, you will see the message `Assets loaded`. You can kill the server at this point.
127 | 
128 | When you start Claude for any task, you will get JSON-parsing-related error messages. These can be ignored.
129 | 
130 | ### 1. Protein-protein interaction prediction
131 | 
132 | - Binary classification task to predict protein-protein interaction using the pre-trained model `ibm-research/biomed.omics.bl.sm.ma-ted-458m`.
133 | 
134 | - Expected inputs are either the amino acid sequences of the two proteins or the protein names.
135 | 
136 | - Ensure `PROTEIN_PROTEIN_INTERACTION` is set to `true` in the `.env` file.
137 | 
138 | Example prompt:
139 | 
140 | ```
141 | Do proteins VPS35 and VPS26 interact together?
142 | ```
143 | 
144 | ### 2. Protein solubility prediction
145 | 
146 | - Binary classification task to predict protein solubility using the fine-tuned model `ibm-research/biomed.omics.bl.sm.ma-ted-458m.protein_solubility`.
147 | 
148 | - Expected input is either the amino acid sequence of the protein or the protein name.
149 | 
150 | - Ensure `PROTEIN_SOLUBILITY` is set to `true` in the `.env` file.
151 | 
152 | Example prompt:
153 | 
154 | ```
155 | Is protein VPS35 soluble in aqueous solutions?
156 | ```
157 | ### 3. Drug-target binding prediction
158 | 
159 | - Prediction of drug-target binding affinity using the fine-tuned model `ibm-research/biomed.omics.bl.sm.ma-ted-458m.dti_bindingdb_pkd`.
160 | 
161 | - Expected inputs are the amino acid sequence of the target and the SMILES representation of the drug.
Binding affinity is predicted as pKd (the negative logarithm of the dissociation constant, which reflects the strength of the interaction between a small molecule and a protein).
162 | 
163 | - Ensure `DRUG_TARGET_BINDING` is set to `true` in the `.env` file.
164 | 
165 | Example prompt:
166 | 
167 | ```
168 | What is the predicted binding affinity between the drug with the SMILES sequence "CC(=O)NCCC1=CNc2c1cc(OC)cc2" and target protein with amino acid sequence "NLMKRCTRGFRKLGKCTTLEEEKCKTLYPRGQCTCSDSKMNTHSCDCKSC"
169 | ```
170 | 
171 | ### 4. TCR-epitope binding
172 | 
173 | - Binary classification task predicting binding between T-cell receptor and epitope sequences using the fine-tuned model `ibm-research/biomed.omics.bl.sm.ma-ted-458m.tcr_epitope_bind`.
174 | 
175 | - Expected inputs are the amino acid sequences of the epitope and the T-cell receptor.
176 | 
177 | - Ensure `TCR_EPITOPE_BINDING` is set to `true` in the `.env` file.
178 | 
179 | Example prompt:
180 | 
181 | ```
182 | does the tcr with the following sequence NAGVTQTPKFQVLKTGQSMTLQCAQDMNHEYMSWYRQDPGMGLRLIHYSVGAGITDQGEVPNGYNVSRSTTEDFPLRLLSAAPSQTSVYFCASSYSWDRVLEQYFGPGTRLTVT bind to the epitope with following sequence LLQTGIHVRVSQPSL?
183 | ```
184 | 
185 | ## Running the server using Streamable-HTTP
186 | 
187 | Set the `STREAMABLE_HTTP` environment variable in the `.env` file to `true`.
188 | 
189 | Then run:
190 | 
191 | ```sh
192 | uv run python -m server
193 | ```
194 | 
195 | The server should start on http://127.0.0.1:8001 (if you want to change the port from 8001, modify it in the `.env` file) after loading all the models for the tasks you have selected.
196 | 
--------------------------------------------------------------------------------
/mammal/task.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Any
3 | 
4 | import pytorch_lightning as pl
5 | import torch
6 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp
7 | from fuse.eval import MetricBase
8 | from fuse.eval.metrics.sequence_gen.metrics_seq_gen_common import MetricPerplexity
9 | 
10 | from mammal.keys import *  # noqa
11 | from mammal.losses import LossHead, ScalarsPredictionsLoss
12 | from mammal.metrics import MetricSeqAccuracy
13 | 
14 | 
15 | class MammalTask:
16 |     """
17 |     A class that holds all the requirements to define a new task.
18 |     A new task is expected to inherit from this class and override all the necessary methods.
19 |     """
20 | 
21 |     def __init__(
22 |         self,
23 |         *,
24 |         name: str,
25 |         logger: Any,
26 |         tokenizer_op: ModularTokenizerOp,
27 |         scalars_loss_weight: float = 1.0,
28 |         scalars_loss_type: str = "mse",
29 |     ) -> None:
30 |         """
31 |         Args:
32 |         :param scalars_loss_type: controls which loss is used for scalar predictions. Supported options are: ['mse', 'mae', 'rmse']
33 |         Note: for full control over the calculated loss(es), you can override the method "losses"
34 |         :param scalars_loss_weight: multiplication factor for the scalars-prediction loss.
35 |         """
36 |         self._logger = logger
37 |         self._tokenizer_op = tokenizer_op
38 |         self._name = name
39 |         self._scalars_loss_weight = scalars_loss_weight
40 |         self._scalars_loss_type = scalars_loss_type
41 | 
42 |     def name(self) -> str:
43 |         return self._name
44 | 
45 |     def data_module(self) -> pl.LightningDataModule:
46 |         """
47 |         Return a lightning data module for the task.
48 |         The dataloaders implemented in this datamodule iterate over batches.
49 |         Each batch is represented by a dict.
50 |         In the dictionary the following key-value pairs must be set:
51 | 
52 |         mammal.keys.ENCODER_INPUTS_STR # the original string representation of the encoder input - used for debugging
53 |         mammal.keys.ENCODER_INPUTS_TOKENS # encoder input token ids
54 |         mammal.keys.ENCODER_INPUTS_ATTENTION_MASK # attention mask of the tokenized encoder input (output of the tokenizer)
55 | 
56 |         mammal.keys.LABELS_STR # the original string representation of the labels - used for debugging
57 |         mammal.keys.LABELS_TOKENS # labels token ids
58 |         mammal.keys.LABELS_ATTENTION_MASK # attention mask of the tokenized labels (output of the tokenizer)
59 | 
60 |         In encoder-decoder mode, the following are also expected to be set:
61 |         mammal.keys.DECODER_INPUTS_STR # the original string representation of the decoder input - used for debugging
62 |         mammal.keys.DECODER_INPUTS_TOKENS # decoder input token ids (decoder start token followed by labels token ids)
63 |         mammal.keys.DECODER_INPUTS_ATTENTION_MASK # attention mask of the tokenized decoder input (output of the tokenizer)
64 | 
65 |         """
66 |         raise NotImplementedError()
67 | 
68 |     def losses(self) -> dict[str, torch.nn.Module]:
69 |         """
70 |         Returns a dictionary of losses. The total loss will be the sum of all losses.
71 |         Each loss element is represented by a pytorch module that gets a batch represented by a dictionary.
72 |         This default implementation works for most cases:
73 |         it is the sum of a cross-entropy loss applied to any label != -100 and an MSE loss over any scalar values.
74 |         """
75 |         all_losses = {}
76 | 
77 |         loss_object = LossHead(loss_type="ce")
78 |         all_losses[f"{self.name()}_ce"] = loss_object
79 | 
80 |         # scalars
81 |         loss_object = ScalarsPredictionsLoss(
82 |             loss_type=self._scalars_loss_type,
83 |             loss_weight=self._scalars_loss_weight,
84 |         )
85 |         all_losses[f"{self.name()}_scalars_mse"] = loss_object
86 | 
87 |         return all_losses
88 | 
89 |     def train_metrics(self) -> dict[str, MetricBase]:
90 |         """
91 |         Fuse metrics for the train set
92 |         """
93 |         return self.get_metrics(is_train=True)
94 | 
95 |     def validation_metrics(self) -> dict[str, MetricBase]:
96 |         """
97 |         Fuse metrics for the validation set
98 |         """
99 |         return self.get_metrics(is_train=False)
100 | 
101 |     def get_metrics(self, is_train: bool) -> dict[str, MetricBase]:
102 |         """
103 |         Default metrics to use: perplexity and token accuracy
104 |         """
105 |         return OrderedDict(
106 |             [
107 |                 (
108 |                     f"{self.name()}_perplexity",
109 |                     MetricPerplexity(
110 |                         preds=SCORES, target=LABELS_TOKENS, ignore_index=-100
111 |                     ),
112 |                 ),
113 |                 (
114 |                     f"{self.name()}_token_acc",
115 |                     MetricSeqAccuracy(pred=CLS_PRED, target=LABELS_TOKENS),
116 |                 ),
117 |             ]
118 |         )
119 | 
120 |     @staticmethod
121 |     def data_preprocessing(sample_dict: dict, *args: list, **kwargs: dict) -> dict:
122 |         """
123 |         The point of this method is to take a task-specific input (in a form that is easy to provide, for example AA sequences)
124 |         and to construct a query for the model.
125 |         A query is built from encoder_input and, if available for training, also from labels and decoder_input.
126 |         See examples in mammal/examples/protein_solubility/task.py
127 | 
128 |         This function is also responsible for tokenizing the query and for setting any label that should not participate in the loss to -100.
129 | 
130 |         The function gets a sample_dict with all the raw sample data and should add the following keys.
131 | 132 | mammal.keys.ENCODER_INPUTS_STR # the original string representation of encoder input - used for debug 133 | mammal.keys.ENCODER_INPUTS_TOKENS # encoder input token ids 134 | mammal.keys.ENCODER_INPUTS_ATTENTION_MASK # attention mask of the tokenized encoder input (output of the tokenizer) 135 | 136 | And if available for training also: 137 | mammal.keys.DECODER_INPUTS_STR # the original string representation of decoder input - used for debug 138 | mammal.keys.DECODER_INPUTS_TOKENS # decoder input token ids (decoder start token followed by labels token ids) 139 | mammal.keys.DECODER_INPUTS_ATTENTION_MASK # attention mask of the tokenized decoder input (output of the tokenizer) 140 | 141 | mammal.keys.LABELS_STR # the original string representation of labels - used for debug 142 | mammal.keys.LABELS_TOKENS # labels token ids 143 | mammal.keys.LABELS_ATTENTION_MASK # attention mask of the tokenized labels (output of the tokenizer) 144 | 145 | """ 146 | raise NotImplementedError() 147 | 148 | @staticmethod 149 | def process_model_output( 150 | tokenizer_op: ModularTokenizerOp, verbose: bool = False, **kwargs: dict 151 | ) -> dict: 152 | """ 153 | The point of this method is to process model output in a way that extract the key meaningful values from it. 154 | Some task will not expect encoder_output (encoder-decoder tasks) and some task will not expect decoder_output (encoder-only tasks) 155 | logits is expected in tasks that have predictive aspects, for example in binding binary classification. 156 | 157 | See examples in mammal/examples/protein_solubility/task.py 158 | 159 | """ 160 | raise NotImplementedError() 161 | -------------------------------------------------------------------------------- /mammal/metrics.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from functools import partial 3 | 4 | import numpy as np 5 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 6 | from fuse.eval.metrics.classification.metrics_classification_common import ( 7 | MetricAccuracy, 8 | MetricAUCROC, 9 | MetricMCC, 10 | ) 11 | from fuse.eval.metrics.metrics_common import MetricBase, MetricDefault 12 | from fuse.eval.metrics.regression.metrics import ( 13 | MetricMAE, 14 | MetricMSE, 15 | MetricPearsonCorrelation, 16 | MetricR2, 17 | MetricRMSE, 18 | MetricSpearmanCorrelation, 19 | ) 20 | 21 | from mammal.keys import * 22 | 23 | """ 24 | Generic LM metrics used by default for each task in this generic multitask t5 implementation 25 | """ 26 | 27 | 28 | class MetricSeqAccuracy(MetricDefault): 29 | """ 30 | Accuracy of a generated sequence. (Token Accuracy) 31 | """ 32 | 33 | def __init__( 34 | self, 35 | pred: str | None = None, 36 | target: str | None = None, 37 | ignore_index: int | None = -100, 38 | **kwargs: dict, 39 | ): 40 | """ 41 | :param pred: batch_dict key to class! 
predictions 42 | :param target: batch_dict to labels 43 | :param ignore_index: ignore this label values in the returned score 44 | """ 45 | super().__init__( 46 | metric_func=self.sequence_accuracy, pred=pred, target=target, **kwargs 47 | ) 48 | self.ignore_index = ignore_index 49 | 50 | def sequence_accuracy( 51 | self, 52 | pred: list[np.ndarray], 53 | target: list[np.ndarray], 54 | sample_weight: list[np.ndarray] | None = None, 55 | ) -> float: 56 | if isinstance(pred, list): 57 | pred = np.concatenate(pred) 58 | target = np.concatenate(target) 59 | if sample_weight is not None: 60 | sample_weight = np.concatenate(sample_weight) 61 | 62 | assert pred.shape == target.shape, f"shape does not match {pred.shape=} <> {target.shape=}" # type: ignore[attr-defined] 63 | 64 | indices = target != self.ignore_index 65 | are_same_indicators = pred[indices] == target[indices] 66 | if sample_weight is None: 67 | return np.sum(are_same_indicators) / np.sum(indices) 68 | 69 | sample_weight = sample_weight[indices] 70 | return np.sum(are_same_indicators * sample_weight) / (np.sum(sample_weight)) 71 | 72 | 73 | def classification_metrics( 74 | name: str, 75 | class_position: int, 76 | class_tokens: list[str], 77 | tokenizer_op: ModularTokenizerOp, 78 | scores_key: str = SCORES, 79 | cls_preds_key: str = CLS_PRED, 80 | labels_key: str = LABELS_TOKENS, 81 | ) -> dict[str, MetricBase]: 82 | """ 83 | Recommended metrics for classification AUC, Accuracy and MCC 84 | :param name: task name 85 | :param class_position: position (index) in the labels/predictions sequence length of the classification token 86 | :param class_tokens: list of possible class tokens 87 | :param tokenizer_op: tokenizer instance 88 | :param scores_key: batch_dict key that points to predictions scores 89 | :param cls_preds_key: batch dict key that points to predictions 90 | :param labels: batch dict key that points to labels 91 | """ 92 | metrics = {} 93 | 94 | class_token_ids = [ 95 | tokenizer_op.get_token_id(class_token) for class_token in class_tokens 96 | ] 97 | token_id_to_class_index = { 98 | token_id: cls_index for cls_index, token_id in enumerate(class_token_ids) 99 | } 100 | 101 | # this mode assumes the predicted class is always on a specific label 102 | pre_collect_fn = partial( 103 | extract_classification_predictions_and_labels, 104 | class_token_ids=class_token_ids, 105 | token_id_to_class_index=token_id_to_class_index, 106 | seq_pos=class_position, 107 | ) 108 | 109 | metrics[f"{name}_aucroc"] = MetricAUCROC( 110 | pred=scores_key, 111 | target=labels_key, 112 | batch_pre_collect_process_func=pre_collect_fn, 113 | ) 114 | metrics[f"{name}_acc"] = MetricAccuracy( 115 | pred=cls_preds_key, 116 | target=labels_key, 117 | batch_pre_collect_process_func=pre_collect_fn, 118 | ) 119 | metrics[f"{name}_mcc"] = MetricMCC( 120 | pred=cls_preds_key, 121 | target=labels_key, 122 | batch_pre_collect_process_func=pre_collect_fn, 123 | ) 124 | return metrics 125 | 126 | 127 | # extract specific positions from a batch_dict according to label/out_key 128 | def extract_classification_predictions_and_labels( 129 | batch_dict: dict, 130 | *, 131 | class_token_ids: list[int], 132 | token_id_to_class_index: dict[int, int], 133 | labels_key: str = LABELS_TOKENS, 134 | cls_preds_key: str = CLS_PRED, 135 | scores_key: str = SCORES, 136 | seq_pos: int = 1, 137 | ) -> dict: 138 | """ 139 | Extract the predictions and labels and convert them from vocabulary space to class index space 140 | This function currently optimized for a single gpu. 
For multi-gpu the returned tensors should be moved back to gpu. 141 | :param class_token_ids: list of ids of the class tokens 142 | :param token_id_to_class_index: mapping from token-id to index in class_token_ids. 143 | :param labels_key: batch_dict key which points to labels 144 | :param cls_preds_key: batch_dict key which points to cls_preds 145 | :param scores_key: batch_dict key which points to scores 146 | :param seq_pos: the position of the class token in labels 147 | """ 148 | device_labels = batch_dict[labels_key].device 149 | device_cls_preds = batch_dict[cls_preds_key].device 150 | 151 | labels = batch_dict[labels_key][:, seq_pos].cpu() 152 | cls_preds = batch_dict[cls_preds_key][:, seq_pos].contiguous().cpu() 153 | scores = batch_dict[scores_key][:, seq_pos, class_token_ids].contiguous() 154 | 155 | classification_labels = labels.apply_( 156 | lambda x: token_id_to_class_index.get(x, len(class_token_ids)) 157 | ) 158 | classification_cls_preds = cls_preds.apply_( 159 | lambda x: token_id_to_class_index.get(x, len(class_token_ids)) 160 | ) 161 | 162 | return { 163 | labels_key: classification_labels.to(device=device_labels), 164 | cls_preds_key: classification_cls_preds.to(device=device_cls_preds), 165 | scores_key: scores, 166 | } 167 | 168 | 169 | def regression_metrics( 170 | name: str, 171 | pred_scalars_key: str = SCALARS_PREDICTION_HEAD_LOGITS, 172 | target_scalars_key: str = LABELS_SCALARS_VALUES, 173 | process_func: Callable | None = None, 174 | ) -> dict[str, MetricBase]: 175 | """ 176 | Typical metrics for regression tasks: includes MetricPearsonCorrelation, MetricSpearmanCorrelation, MetricMAE, MetricMSE, MetricRMSE, MetricR2 177 | :param pred_scalars_key: key to scalar prediction (after it was extracted from model output) 178 | :param target_scalars_key: key to ground truth scalar. 179 | :param process_func: a function that extract the actual relevant scalar from model output and store it in batch_dict. 
180 | """ 181 | metrics = {} 182 | metrics[f"{name}_pcorr"] = MetricPearsonCorrelation( 183 | pred=pred_scalars_key, 184 | target=target_scalars_key, 185 | batch_pre_collect_process_func=process_func, 186 | mask=None, 187 | ) 188 | metrics[f"{name}_spearcorr"] = MetricSpearmanCorrelation( 189 | pred=pred_scalars_key, 190 | target=target_scalars_key, 191 | batch_pre_collect_process_func=process_func, 192 | mask=None, 193 | ) 194 | metrics[f"{name}_mae"] = MetricMAE( 195 | pred=pred_scalars_key, 196 | target=target_scalars_key, 197 | batch_pre_collect_process_func=process_func, 198 | ) 199 | metrics[f"{name}_mse"] = MetricMSE( 200 | pred=pred_scalars_key, 201 | target=target_scalars_key, 202 | batch_pre_collect_process_func=process_func, 203 | ) 204 | 205 | metrics[f"{name}_rmse"] = MetricRMSE( 206 | pred=pred_scalars_key, 207 | target=target_scalars_key, 208 | batch_pre_collect_process_func=process_func, 209 | ) 210 | 211 | metrics[f"{name}_r2"] = MetricR2( 212 | pred=pred_scalars_key, 213 | target=target_scalars_key, 214 | batch_pre_collect_process_func=process_func, 215 | ) 216 | 217 | return metrics 218 | -------------------------------------------------------------------------------- /mammal/examples/protein_solubility/task.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 7 | 8 | from mammal.examples.protein_solubility.pl_data_module import ( 9 | ProteinSolubilityDataModule, 10 | ) 11 | from mammal.keys import ( 12 | CLS_PRED, 13 | DECODER_INPUTS_ATTENTION_MASK, 14 | DECODER_INPUTS_STR, 15 | DECODER_INPUTS_TOKENS, 16 | ENCODER_INPUTS_ATTENTION_MASK, 17 | ENCODER_INPUTS_STR, 18 | ENCODER_INPUTS_TOKENS, 19 | LABELS_ATTENTION_MASK, 20 | LABELS_STR, 21 | LABELS_TOKENS, 22 | SCORES, 23 | ) 24 | from mammal.metrics import classification_metrics 25 | from mammal.task import ( 26 | MammalTask, 27 | MetricBase, 28 | ) 29 | 30 | 31 | class ProteinSolubilityTask(MammalTask): 32 | def __init__( 33 | self, 34 | *, 35 | name: str, 36 | tokenizer_op: ModularTokenizerOp, 37 | data_module_kwargs: dict, 38 | seed: int, 39 | logger: Any | None = None, 40 | ) -> None: 41 | super().__init__( 42 | name=name, 43 | logger=logger, 44 | tokenizer_op=tokenizer_op, 45 | ) 46 | self._data_module_kwargs = data_module_kwargs 47 | self._seed = seed 48 | 49 | self.preds_key = CLS_PRED 50 | self.scores_key = SCORES 51 | self.labels_key = LABELS_TOKENS 52 | 53 | def data_module(self) -> pl.LightningDataModule: 54 | return ProteinSolubilityDataModule( 55 | tokenizer_op=self._tokenizer_op, 56 | seed=self._seed, 57 | data_preprocessing=self.data_preprocessing, 58 | **self._data_module_kwargs, 59 | ) 60 | 61 | def train_metrics(self) -> dict[str, MetricBase]: 62 | metrics = super().train_metrics() 63 | metrics.update( 64 | classification_metrics( 65 | self.name(), 66 | class_position=1, 67 | tokenizer_op=self._tokenizer_op, 68 | class_tokens=["<0>", "<1>"], 69 | ) 70 | ) 71 | 72 | return metrics 73 | 74 | def validation_metrics(self) -> dict[str, MetricBase]: 75 | validation_metrics = super().validation_metrics() 76 | validation_metrics.update( 77 | classification_metrics( 78 | self.name(), 79 | class_position=1, 80 | tokenizer_op=self._tokenizer_op, 81 | class_tokens=["<0>", "<1>"], 82 | ) 83 | ) 84 | return validation_metrics 85 | 86 | @staticmethod 87 | def data_preprocessing( 88 | sample_dict: dict, 89 | 
*, 90 | protein_sequence_key: str, 91 | tokenizer_op: ModularTokenizerOp, 92 | solubility_label_key: int | None = None, 93 | protein_max_seq_length: int = 1250, 94 | encoder_input_max_seq_len: int | None = 1260, 95 | labels_max_seq_len: int | None = 4, 96 | device: str | torch.device = "cpu", 97 | ) -> dict: 98 | """ 99 | :param sample_dict: a dictionary with raw data 100 | :param protein_sequence_key: sample_dict key which points to protein sequence 101 | :param solubility_label_key: sample_dict key which points to label 102 | :param protein_max_seq_length: max sequence length of a protein. Will be used to truncate the protein 103 | :param encoder_input_max_seq_len: max sequence length of labels. Will be used to truncate/pad the encoder_input. 104 | :param labels_max_seq_len: max sequence length of labels. Will be used to truncate/pad the labels. 105 | :param tokenizer_op: tokenizer op 106 | 107 | """ 108 | protein_sequence = sample_dict[protein_sequence_key] 109 | solubility_label = sample_dict.get(solubility_label_key, None) 110 | 111 | sample_dict[ENCODER_INPUTS_STR] = ( 112 | f"<@TOKENIZER-TYPE=AA><@TOKENIZER-TYPE=AA@MAX-LEN={protein_max_seq_length}>{protein_sequence}" 113 | ) 114 | tokenizer_op( 115 | sample_dict=sample_dict, 116 | key_in=ENCODER_INPUTS_STR, 117 | key_out_tokens_ids=ENCODER_INPUTS_TOKENS, 118 | key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK, 119 | max_seq_len=encoder_input_max_seq_len, 120 | ) 121 | sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor( 122 | sample_dict[ENCODER_INPUTS_TOKENS], device=device 123 | ) 124 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor( 125 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK], device=device 126 | ) 127 | 128 | if solubility_label is not None: 129 | pad_id = tokenizer_op.get_token_id("") 130 | ignore_token_value = -100 131 | sample_dict[LABELS_STR] = ( 132 | f"<@TOKENIZER-TYPE=AA><{solubility_label}>" 133 | ) 134 | tokenizer_op( 135 | sample_dict=sample_dict, 136 | key_in=LABELS_STR, 137 | key_out_tokens_ids=LABELS_TOKENS, 138 | key_out_attention_mask=LABELS_ATTENTION_MASK, 139 | max_seq_len=labels_max_seq_len, 140 | ) 141 | sample_dict[LABELS_TOKENS] = torch.tensor( 142 | sample_dict[LABELS_TOKENS], device=device 143 | ) 144 | sample_dict[LABELS_ATTENTION_MASK] = torch.tensor( 145 | sample_dict[LABELS_ATTENTION_MASK], device=device 146 | ) 147 | # replace pad_id with -100 to 148 | pad_id_tns = torch.tensor(pad_id) 149 | sample_dict[LABELS_TOKENS][ 150 | (sample_dict[LABELS_TOKENS][..., None] == pad_id_tns).any(-1).nonzero() 151 | ] = ignore_token_value 152 | 153 | sample_dict[DECODER_INPUTS_STR] = ( 154 | f"<@TOKENIZER-TYPE=AA><{solubility_label}>" 155 | ) 156 | tokenizer_op( 157 | sample_dict=sample_dict, 158 | key_in=DECODER_INPUTS_STR, 159 | key_out_tokens_ids=DECODER_INPUTS_TOKENS, 160 | key_out_attention_mask=DECODER_INPUTS_ATTENTION_MASK, 161 | max_seq_len=labels_max_seq_len, 162 | ) 163 | sample_dict[DECODER_INPUTS_TOKENS] = torch.tensor( 164 | sample_dict[DECODER_INPUTS_TOKENS], device=device 165 | ) 166 | sample_dict[DECODER_INPUTS_ATTENTION_MASK] = torch.tensor( 167 | sample_dict[DECODER_INPUTS_ATTENTION_MASK], device=device 168 | ) 169 | 170 | return sample_dict 171 | 172 | @staticmethod 173 | def process_model_output( 174 | tokenizer_op: ModularTokenizerOp, 175 | decoder_output: np.ndarray, 176 | decoder_output_scores: np.ndarray, 177 | ) -> dict: 178 | """ 179 | Extract predicted solubility class and scores 180 | expecting decoder output to be <0> or <1> 181 | note - the normalized version will 
calculate the positive ('<1>') score divided by the sum of the scores for both '<0>' and '<1>' 182 | BE CAREFUL as both negative and positive absolute scores can be drastically low, and normalized score could be very high. 183 | outputs a dictionary containing: 184 | dict( 185 | predicted_token_str = #... e.g. '<1>' 186 | not_normalized_score = #the score for the positive token... e.g. 0.01 187 | normalized_score = #... (positive_token_score) / (positive_token_score+negative_token_score) 188 | ) 189 | if there is any error in parsing the model output, None is returned. 190 | """ 191 | 192 | negative_token_id = tokenizer_op.get_token_id("<0>") 193 | positive_token_id = tokenizer_op.get_token_id("<1>") 194 | label_id_to_int = { 195 | negative_token_id: 0, 196 | positive_token_id: 1, 197 | } 198 | classification_position = 1 199 | 200 | if decoder_output_scores is not None: 201 | not_normalized_score = decoder_output_scores[ 202 | classification_position, positive_token_id 203 | ] 204 | normalized_score = not_normalized_score / ( 205 | not_normalized_score 206 | + decoder_output_scores[classification_position, negative_token_id] 207 | + 1e-10 208 | ) 209 | ans = dict( 210 | pred=label_id_to_int.get(int(decoder_output[classification_position]), -1), 211 | not_normalized_scores=not_normalized_score, 212 | normalized_scores=normalized_score, 213 | ) 214 | 215 | return ans 216 | -------------------------------------------------------------------------------- /mammal/examples/dti_bindingdb_kd/task.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 7 | 8 | from mammal.examples.dti_bindingdb_kd.pl_data_module import ( 9 | DtiBindingdbKdDataModule, 10 | ) 11 | from mammal.keys import * 12 | from mammal.metrics import regression_metrics 13 | from mammal.task import ( 14 | MammalTask, 15 | MetricBase, 16 | ) 17 | 18 | 19 | class DtiBindingdbKdTask(MammalTask): 20 | def __init__( 21 | self, 22 | *, 23 | name: str, 24 | tokenizer_op: ModularTokenizerOp, 25 | data_module_kwargs: dict, 26 | seed: int, 27 | logger: Any | None = None, 28 | norm_y_mean: float = 0.0, 29 | norm_y_std: float = 1.0, 30 | ) -> None: 31 | """ 32 | :param name: task name. used for to log metrics and losses 33 | :param tokenizer op: the tokenizer used 34 | :param data_module_kwargs: arguments for data module constructor 35 | :param seed: seed for random operations. 36 | :param logger: typically clearml logger. Optional. 37 | :param norm_y_mean: Used to normalize the values. Metrics will still be calculated with the original values for a fair evaluation. 38 | Default value means - no normalization 39 | :param norm_y_std: Used to normalize the values. Metrics will still be calculated with the original values for a fair evaluation. 
40 | Default value means - no normalization 41 | 42 | """ 43 | super().__init__( 44 | name=name, 45 | logger=logger, 46 | tokenizer_op=tokenizer_op, 47 | ) 48 | self._data_module_kwargs = data_module_kwargs 49 | self._seed = seed 50 | 51 | self.preds_key = CLS_PRED 52 | self.scores_key = SCORES 53 | self.labels_key = LABELS_TOKENS 54 | 55 | self.norm_y_mean = norm_y_mean 56 | self.norm_y_std = norm_y_std 57 | 58 | def data_module(self) -> pl.LightningDataModule: 59 | return DtiBindingdbKdDataModule( 60 | tokenizer_op=self._tokenizer_op, 61 | seed=self._seed, 62 | data_preprocessing=partial( 63 | self.data_preprocessing, 64 | norm_y_mean=self.norm_y_mean, 65 | norm_y_std=self.norm_y_std, 66 | ), 67 | **self._data_module_kwargs, 68 | ) 69 | 70 | def train_metrics(self) -> dict[str, MetricBase]: 71 | metrics = super().train_metrics() 72 | metrics.update( 73 | regression_metrics( 74 | self.name(), 75 | process_func=partial( 76 | self.process_model_output, 77 | norm_y_mean=self.norm_y_mean, 78 | norm_y_std=self.norm_y_std, 79 | ), 80 | pred_scalars_key="model.out.dti_bindingdb_kd", 81 | target_scalars_key="Y", 82 | ) 83 | ) 84 | 85 | return metrics 86 | 87 | def validation_metrics(self) -> dict[str, MetricBase]: 88 | validation_metrics = super().validation_metrics() 89 | validation_metrics.update( 90 | regression_metrics( 91 | self.name(), 92 | process_func=partial( 93 | self.process_model_output, 94 | norm_y_mean=self.norm_y_mean, 95 | norm_y_std=self.norm_y_std, 96 | ), 97 | pred_scalars_key="model.out.dti_bindingdb_kd", 98 | target_scalars_key="Y", 99 | ) 100 | ) 101 | return validation_metrics 102 | 103 | @staticmethod 104 | def data_preprocessing( 105 | sample_dict: dict, 106 | *, 107 | target_sequence_key: str, 108 | drug_sequence_key: str, 109 | ground_truth_key: int | None = None, 110 | target_max_seq_length: int = 1250, 111 | drug_max_seq_length: int = 256, 112 | encoder_input_max_seq_len: int = 1512, 113 | tokenizer_op: ModularTokenizerOp, 114 | norm_y_mean: float, 115 | norm_y_std: float, 116 | device: str | torch.device = "cpu", 117 | ) -> dict[str, Any]: 118 | """ 119 | :param norm_y_mean: Used to normalize the values. Metrics will still be calculated with the original values for a fair evaluation. 120 | Default value means - no normalization 121 | :param norm_y_std: Used to normalize the values. Metrics will still be calculated with the original values for a fair evaluation. 
122 | Default value means - no normalization 123 | """ 124 | target_sequence = sample_dict[target_sequence_key] 125 | drug_sequence = sample_dict[drug_sequence_key] 126 | ground_truth_value = sample_dict.get(ground_truth_key, None) 127 | 128 | sample_dict[ENCODER_INPUTS_STR] = ( 129 | "<@TOKENIZER-TYPE=AA>" 130 | f"<@TOKENIZER-TYPE=AA@MAX-LEN={target_max_seq_length}>{target_sequence}" 131 | f"<@TOKENIZER-TYPE=SMILES@MAX-LEN={drug_max_seq_length}>{drug_sequence}" 132 | "" 133 | ) 134 | tokenizer_op( 135 | sample_dict, 136 | key_in=ENCODER_INPUTS_STR, 137 | key_out_tokens_ids=ENCODER_INPUTS_TOKENS, 138 | key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK, 139 | max_seq_len=encoder_input_max_seq_len, 140 | key_out_scalars=ENCODER_INPUTS_SCALARS, 141 | ) 142 | 143 | sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor( 144 | sample_dict[ENCODER_INPUTS_TOKENS], device=device 145 | ) 146 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor( 147 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK], 148 | device=device, 149 | ) 150 | 151 | if ground_truth_value is not None: 152 | ground_truth_value = (ground_truth_value - norm_y_mean) / norm_y_std 153 | pad_id = tokenizer_op.get_token_id("") 154 | ignore_token_value = -100 155 | sample_dict[LABELS_STR] = ( 156 | f"<@TOKENIZER-TYPE=SCALARS_LITERALS>{ground_truth_value}<@TOKENIZER-TYPE=AA>" 157 | + "".join([""] * (encoder_input_max_seq_len - 1)) 158 | ) 159 | 160 | tokenizer_op( 161 | sample_dict, 162 | key_in=LABELS_STR, 163 | key_out_tokens_ids=LABELS_TOKENS, 164 | key_out_attention_mask=LABELS_ATTENTION_MASK, 165 | max_seq_len=encoder_input_max_seq_len, 166 | key_out_scalars=LABELS_SCALARS, 167 | validate_ends_with_eos=False, 168 | ) 169 | 170 | sample_dict[LABELS_TOKENS] = torch.tensor( 171 | sample_dict[LABELS_TOKENS], device=device 172 | ) 173 | sample_dict[LABELS_ATTENTION_MASK] = torch.tensor( 174 | sample_dict[LABELS_ATTENTION_MASK], device=device 175 | ) 176 | # replace pad_id with -100 to 177 | pad_id_tns = torch.tensor(pad_id) 178 | sample_dict[LABELS_TOKENS][ 179 | (sample_dict[LABELS_TOKENS][..., None] == pad_id_tns).any(-1).nonzero() 180 | ] = ignore_token_value 181 | 182 | sample_dict[LABELS_SCALARS_VALUES] = sample_dict[LABELS_SCALARS_VALUES].to( 183 | device=device 184 | ) 185 | sample_dict[LABELS_SCALARS_VALID_MASK] = sample_dict[ 186 | LABELS_SCALARS_VALID_MASK 187 | ].to(device=device) 188 | 189 | return sample_dict 190 | 191 | @staticmethod 192 | def process_model_output( 193 | batch_dict: dict, 194 | *, 195 | scalars_preds_key: str = SCALARS_PREDICTION_HEAD_LOGITS, 196 | scalars_preds_processed_key: str = "model.out.dti_bindingdb_kd", 197 | norm_y_mean: float, 198 | norm_y_std: float, 199 | ) -> dict: 200 | """ 201 | :param norm_y_mean: Used to normalize the values. Metrics will still be calculated with the original values for a fair evaluation. 202 | Default value means - no normalization 203 | :param norm_y_std: Used to normalize the values. Metrics will still be calculated with the original values for a fair evaluation. 
204 | Default value means - no normalization 205 | """ 206 | scalars_preds = batch_dict[scalars_preds_key] 207 | 208 | batch_dict[scalars_preds_processed_key] = ( 209 | scalars_preds[:, 0] * norm_y_std + norm_y_mean 210 | ) 211 | 212 | return batch_dict 213 | -------------------------------------------------------------------------------- /mammal/examples/scrna_cell_type/pl_data_module.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections.abc import Callable 3 | from pathlib import Path 4 | from typing import Any 5 | 6 | import anndata 7 | import numpy as np 8 | import pandas as pd 9 | import pytorch_lightning as pl 10 | import scanpy as sc 11 | from anndata_op import OpReadAnnData 12 | from fuse.data.datasets.dataset_default import DatasetDefault 13 | from fuse.data.pipelines.pipeline_default import PipelineDefault 14 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 15 | from fuse.data.utils.collates import CollateDefault 16 | from sklearn.model_selection import train_test_split 17 | from torch.utils.data.dataloader import DataLoader 18 | 19 | from mammal.keys import * # noqa 20 | 21 | 22 | class CellTypeDataModule(pl.LightningDataModule): 23 | skip_keys = [ 24 | "scrna.gene_names", 25 | "scrna.scrna", 26 | ] 27 | 28 | def __init__( 29 | self, 30 | *, 31 | data_path: str, 32 | batch_size: int, 33 | tokenizer_op: ModularTokenizerOp, 34 | data_preprocessing: Callable, 35 | train_dl_kwargs: dict, 36 | valid_dl_kwargs: dict, 37 | label_name="cell_type", 38 | input_max_seq_length: int = 500, 39 | encoder_input_max_seq_len: int = 512, 40 | labels_max_seq_len: int = 20, 41 | seed: int = 42, 42 | stratify_by_label=False, 43 | ) -> None: 44 | """_summary_ 45 | Args: 46 | data_path (str): path to the raw data, if not exist, will download the data to the given path. 
47 | batch_size (int): batch size 48 | tokenizer_op (ModularTokenizerOp): tokenizer op 49 | encoder_inputs_max_seq_len: max tokenizer sequence length for the encoder inputs, 50 | labels_max_seq_len: max tokenizer sequence length for the labels, 51 | train_dl_kwargs (dict): train dataloader constructor parameters 52 | valid_dl_kwargs (dict): validation dataloader constructor parameters 53 | seed (int): random seed 54 | """ 55 | super().__init__() 56 | self.data_path = data_path 57 | self.tokenizer_op = tokenizer_op 58 | self.input_max_seq_length = input_max_seq_length 59 | self.encoder_input_max_seq_len = encoder_input_max_seq_len 60 | self.labels_max_seq_len = labels_max_seq_len 61 | self.batch_size = batch_size 62 | self.train_dl_kwargs = train_dl_kwargs 63 | self.valid_dl_kwargs = valid_dl_kwargs 64 | self.label_name = label_name 65 | self.seed = seed 66 | self.data_preprocessing = data_preprocessing 67 | self.pad_token_id = self.tokenizer_op.get_token_id("") 68 | self.ds_dict: dict[str, Any] = {} 69 | if stratify_by_label: 70 | self.stratify_by = self.label_name 71 | else: 72 | self.stratify_by = None 73 | 74 | def setup(self, stage: str) -> None: 75 | self.ds_dict = load_datasets( 76 | data_path=self.data_path, stratify_by=self.stratify_by 77 | ) 78 | 79 | task_pipeline = [ 80 | ( 81 | # Prepare the input string(s) in modular tokenizer input format 82 | self.data_preprocessing, 83 | dict( 84 | sequence_key="scrna.scrna", 85 | label_key="data.label", 86 | tokenizer_op=self.tokenizer_op, 87 | input_max_seq_length=self.input_max_seq_length, 88 | encoder_input_max_seq_len=self.encoder_input_max_seq_len, 89 | labels_max_seq_len=self.labels_max_seq_len, 90 | ), 91 | ), 92 | ] 93 | 94 | for ds in self.ds_dict.values(): 95 | ds.dynamic_pipeline.extend(task_pipeline) 96 | 97 | def train_dataloader(self) -> DataLoader: 98 | train_loader = DataLoader( 99 | dataset=self.ds_dict["train"], 100 | batch_size=self.batch_size, 101 | collate_fn=self.collate_fn(), 102 | shuffle=True, 103 | **self.train_dl_kwargs, 104 | ) 105 | return train_loader 106 | 107 | def val_dataloader(self) -> DataLoader: 108 | val_loader = DataLoader( 109 | self.ds_dict["valid"], 110 | batch_size=self.batch_size, 111 | collate_fn=self.collate_fn(), 112 | **self.valid_dl_kwargs, 113 | ) 114 | 115 | return val_loader 116 | 117 | def test_dataloader(self) -> DataLoader: 118 | test_loader = DataLoader( 119 | self.ds_dict["test"], 120 | batch_size=self.batch_size, 121 | collate_fn=self.collate_fn(), 122 | **self.valid_dl_kwargs, 123 | ) 124 | 125 | return test_loader 126 | 127 | def predict_dataloader(self) -> DataLoader: 128 | return self.test_dataloader() 129 | 130 | def collate_fn(self): 131 | return CollateDefault(skip_keys=self.skip_keys) 132 | 133 | 134 | def anndata_train_test_split( 135 | h5ad_file, test_size=0.1, random_state=42, stratify_by=None 136 | ): 137 | 138 | if stratify_by is not None: 139 | stratify = h5ad_file.obs[stratify_by] 140 | else: 141 | stratify = None 142 | 143 | train_ids, valid_ids = train_test_split( 144 | h5ad_file.obs, 145 | test_size=test_size, 146 | random_state=random_state, 147 | stratify=stratify, 148 | ) 149 | train_adata = h5ad_file[train_ids.index] 150 | validata_adata = h5ad_file[valid_ids.index] 151 | return train_adata, validata_adata 152 | 153 | 154 | def load_datasets( 155 | data_path: str | Path = "data", stratify_by=None 156 | ) -> dict[str, DatasetDefault]: 157 | 158 | data_path = Path(data_path) 159 | if not data_path.is_absolute(): 160 | data_path = Path(__file__).parent / data_path 
161 | anndata_object = anndata.read_h5ad(data_path) 162 | anndata_dict = {} 163 | anndata_dict["all_data"] = anndata_object 164 | anndata_dict["all_train"], anndata_dict["test"] = anndata_train_test_split( 165 | anndata_dict["all_data"], 166 | test_size=0.1, 167 | stratify_by=stratify_by, 168 | ) 169 | anndata_dict["train"], anndata_dict["valid"] = anndata_train_test_split( 170 | anndata_dict["all_train"], 171 | test_size=0.1 / (1.0 - 0.1), 172 | random_state=2024, 173 | stratify_by=stratify_by, 174 | ) 175 | 176 | ds_dict = {} 177 | for set_name in ["train", "valid", "test"]: 178 | input_anndata = anndata_dict[set_name] 179 | size = input_anndata.shape[0] 180 | print(f"{set_name} set size is {size}") 181 | 182 | dynamic_pipeline = PipelineDefault( 183 | "cell_type", 184 | [ 185 | (OpReadAnnData(input_anndata), {"prefix": "scrna"}), 186 | ], 187 | ) 188 | 189 | ds = DatasetDefault(sample_ids=size, dynamic_pipeline=dynamic_pipeline) 190 | ds.create() 191 | ds_dict[set_name] = ds 192 | 193 | return ds_dict 194 | 195 | 196 | def load_cell_type_mapping( 197 | mapping_key="cell_type", 198 | mapping_value="cell_type_ontology_term_id", 199 | cell_type_mapping="cell_type_mapping.csv", 200 | ): 201 | """ 202 | Load metadata_extra_mapping.csv from the given dataset metadata folder, 203 | and return the values of a requested key and value columns as a dictionary. 204 | 205 | This is used to convert the names from the ones in the input anndata to the 206 | ones that are known to the tokenizer. 207 | """ 208 | cell_type_mapping_file_path = Path(__file__).parent / cell_type_mapping 209 | 210 | if not os.path.exists(cell_type_mapping_file_path): 211 | raise FileNotFoundError(str(cell_type_mapping_file_path) + "is not found") 212 | else: 213 | mapping_df = pd.read_csv(cell_type_mapping_file_path, index_col=False) 214 | cell_type_mapping = dict( 215 | zip( 216 | mapping_df[mapping_key], 217 | mapping_df[mapping_value], 218 | ) 219 | ) 220 | return cell_type_mapping 221 | 222 | 223 | def preprocess_ann_data( 224 | anndata_object: anndata.AnnData, 225 | min_genes: int = 200, 226 | num_bins: int = 10, 227 | cell_type: str = "cell_type", 228 | ): 229 | """run preprocessing steps on anndata object 230 | assumes that the anndata object has a standard structure with counts per cell X gene, and cell type annotations in obs[cell_type]. 231 | 232 | steps include: 233 | - translate cell types to ontology term ids 234 | - filter out cells with less than 200 genes expressed 235 | - normalize expression data sum to 1 236 | - transform counts via log1p in base 2 237 | - digitize expression data into bins 238 | 239 | Args: 240 | ann_data_object (anndata.AnnData): input object. will be overwritten 241 | """ 242 | cell_type_mapper = load_cell_type_mapping() 243 | 244 | # place labels (cell types) in the "label" observations 245 | anndata_object.obs["label"] = [ 246 | cell_type_mapper[cell] for cell in anndata_object.obs[cell_type] 247 | ] 248 | # filter out cells with shallow reads 249 | sc.pp.filter_cells(anndata_object, min_genes=min_genes) 250 | # normalize depth and 251 | sc.pp.normalize_total(anndata_object, target_sum=1.0) 252 | # change to log1p space, which is approximately in the scale of 0 to 10 (log_2(1001)~=10) 253 | sc.pp.log1p(anndata_object, base=2) 254 | 255 | # split range to bins - more or less 0,2,3,..10 if using about 10 bins 256 | # the +1 is intended to create num_bin bins, as `linespace` creates the bin ends starting at the lowest and ending at the highest values. 
257 | bins = np.linspace( 258 | anndata_object.X.data.min(), anndata_object.X.max(), num=num_bins + 1 259 | ) 260 | # Note that we change the reading values into their bins number in the main matrix. 261 | anndata_object.X.data = np.digitize(anndata_object.X.data, bins) 262 | 263 | return anndata_object 264 | -------------------------------------------------------------------------------- /mammal/examples/scrna_cell_type/task.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 7 | 8 | from mammal.examples.scrna_cell_type.pl_data_module import ( 9 | CellTypeDataModule, 10 | ) 11 | from mammal.keys import ( 12 | CLS_PRED, 13 | DECODER_INPUTS_ATTENTION_MASK, 14 | DECODER_INPUTS_STR, 15 | DECODER_INPUTS_TOKENS, 16 | ENCODER_INPUTS_ATTENTION_MASK, 17 | ENCODER_INPUTS_STR, 18 | ENCODER_INPUTS_TOKENS, 19 | LABELS_ATTENTION_MASK, 20 | LABELS_STR, 21 | LABELS_TOKENS, 22 | SCORES, 23 | ) 24 | from mammal.metrics import classification_metrics 25 | from mammal.task import ( 26 | MammalTask, 27 | MetricBase, 28 | ) 29 | 30 | ALL_CLASS_LABELS = [ 31 | "[CL:0000794]", 32 | "[CL:0001062]", 33 | "[CL:0000939]", 34 | "[CL:0000792]", 35 | "[CL:0000236]", 36 | "[CL:0001204]", 37 | "[CL:0001054]", 38 | "[CL:0000451]", 39 | "[CL:0000895]", 40 | "[CL:0000049]", 41 | "[CL:0000546]", 42 | ] 43 | 44 | 45 | class CellTypeTask(MammalTask): 46 | def __init__( 47 | self, 48 | *, 49 | tokenizer_op: ModularTokenizerOp, 50 | data_module_kwargs: dict, 51 | logger: Any | None = None, 52 | ) -> None: 53 | super().__init__( 54 | name="cell_type", 55 | logger=logger, 56 | tokenizer_op=tokenizer_op, 57 | ) 58 | self._data_module_kwargs = data_module_kwargs 59 | 60 | self.preds_key = CLS_PRED 61 | self.scores_key = SCORES 62 | self.labels_key = LABELS_TOKENS 63 | 64 | def data_module(self) -> pl.LightningDataModule: 65 | return CellTypeDataModule( 66 | tokenizer_op=self._tokenizer_op, 67 | data_preprocessing=self.data_preprocessing, 68 | stratify_by_label=True, 69 | **self._data_module_kwargs, 70 | ) 71 | 72 | def train_metrics(self) -> dict[str, MetricBase]: 73 | metrics = super().train_metrics() 74 | metrics.update( 75 | # TODO: update this 76 | classification_metrics( 77 | self.name(), 78 | class_position=1, 79 | tokenizer_op=self._tokenizer_op, 80 | class_tokens=ALL_CLASS_LABELS, 81 | ) 82 | ) 83 | 84 | return metrics 85 | 86 | def validation_metrics(self) -> dict[str, MetricBase]: 87 | validation_metrics = super().validation_metrics() 88 | validation_metrics.update( 89 | classification_metrics( 90 | self.name(), 91 | class_position=1, 92 | tokenizer_op=self._tokenizer_op, 93 | class_tokens=ALL_CLASS_LABELS, 94 | ) 95 | ) 96 | return validation_metrics 97 | 98 | @staticmethod 99 | def data_preprocessing( 100 | sample_dict: dict, 101 | *, 102 | sequence_key: str, 103 | label_key: int | None = None, 104 | # drug_max_seq_length: int = 1250, 105 | input_max_seq_length: int | None = 1260, 106 | encoder_input_max_seq_len: int | None = 1260, 107 | labels_max_seq_len: int | None = 4, 108 | tokenizer_op: ModularTokenizerOp, 109 | ) -> dict: 110 | """process a sample into the format expected by the model 111 | 112 | Args: 113 | sample_dict (dict): dictionary with the sample data 114 | sequence_key (str): key in the dictionary with the sequence 115 | tokenizer_op (ModularTokenizerOp): the tokenizer 116 | label_key (int | 
None, optional): key for the label. Defaults to None. 117 | input_max_seq_length (int | None, optional): sequence is truncated if longer than this. Defaults to 1260. 118 | encoder_input_max_seq_len (int | None, optional): maximal length of encoder input. Defaults to 1260. 119 | labels_max_seq_len (int | None, optional): maximal length of label sequence. Defaults to 4. 120 | 121 | Returns: 122 | dict: the sample dict with added keys and values: 123 | 124 | Mammal model expects a dictionary with a set of keys to be able to run. This method converts the data into the expected format. 125 | Here is a list of the required fields for an encoder-decoder task: 126 | ENCODER_INPUTS_STR 127 | ENCODER_INPUTS_TOKENS 128 | ENCODER_INPUTS_ATTENTION_MASK 129 | 130 | LABELS_STR 131 | LABELS_TOKENS 132 | LABELS_ATTENTION_MASK 133 | 134 | DECODER_INPUTS_STR 135 | DECODER_INPUTS_TOKENS 136 | DECODER_INPUTS_ATTENTION_MASK 137 | 138 | see MammalTask.data_module for more information about these keys and their use. 139 | 140 | 141 | The three *_str values are constricted here, and then the others are derived from them by the tokenizer_op 142 | """ 143 | scrna = sample_dict[sequence_key] 144 | if label_key: 145 | label = sample_dict.get(label_key, None) 146 | else: 147 | label = None 148 | 149 | # we have a link to the data of the specific cell, as a reference into the AnnData object 150 | # To get the canonical gene names we need to get access to the AnnData object itself. 151 | gene_names = scrna._view_args.parent.var_names.to_numpy() 152 | 153 | # This is where the data is converted to GeneFormer inspired "binned and sorted" 154 | # The binning is done in preprocess_ann_data, on load rather then when training. 155 | 156 | sorted_genes = CellTypeTask.convert_to_double_sorted_geneformer_sequence( 157 | scrna_sample=scrna, gene_names=gene_names 158 | ) 159 | sequence_string = "[" + "][".join(sorted_genes[:input_max_seq_length]) + "]" 160 | 161 | encoder_prompt = f"<@TOKENIZER-TYPE=GENE><{sequence_string}" 162 | sample_dict[ENCODER_INPUTS_STR] = encoder_prompt 163 | 164 | tokenizer_op( 165 | sample_dict=sample_dict, 166 | key_in=ENCODER_INPUTS_STR, 167 | key_out_tokens_ids=ENCODER_INPUTS_TOKENS, 168 | key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK, 169 | max_seq_len=encoder_input_max_seq_len, 170 | ) 171 | sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor( 172 | sample_dict[ENCODER_INPUTS_TOKENS] 173 | ) 174 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor( 175 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK] 176 | ) 177 | 178 | if label is not None: 179 | pad_id = tokenizer_op.get_token_id("") 180 | ignore_token_value = -100 181 | sample_dict[LABELS_STR] = ( 182 | f"<@TOKENIZER-TYPE=CELL_ATTRIBUTES>[{label}]" 183 | ) 184 | tokenizer_op( 185 | sample_dict=sample_dict, 186 | key_in=LABELS_STR, 187 | key_out_tokens_ids=LABELS_TOKENS, 188 | key_out_attention_mask=LABELS_ATTENTION_MASK, 189 | max_seq_len=labels_max_seq_len, 190 | ) 191 | sample_dict[LABELS_TOKENS] = torch.tensor(sample_dict[LABELS_TOKENS]) 192 | sample_dict[LABELS_ATTENTION_MASK] = torch.tensor( 193 | sample_dict[LABELS_ATTENTION_MASK] 194 | ) 195 | # replace pad_id with -100 to 196 | sample_dict[LABELS_TOKENS][ 197 | (sample_dict[LABELS_TOKENS][..., None] == torch.tensor(pad_id)) 198 | .any(-1) 199 | .nonzero() 200 | ] = ignore_token_value 201 | 202 | sample_dict[DECODER_INPUTS_STR] = ( 203 | f"<@TOKENIZER-TYPE=CELL_ATTRIBUTES><{label}>" 204 | ) 205 | tokenizer_op( 206 | sample_dict=sample_dict, 207 | key_in=DECODER_INPUTS_STR, 208 | 
key_out_tokens_ids=DECODER_INPUTS_TOKENS, 209 | key_out_attention_mask=DECODER_INPUTS_ATTENTION_MASK, 210 | max_seq_len=labels_max_seq_len, 211 | ) 212 | sample_dict[DECODER_INPUTS_TOKENS] = torch.tensor( 213 | sample_dict[DECODER_INPUTS_TOKENS] 214 | ) 215 | sample_dict[DECODER_INPUTS_ATTENTION_MASK] = torch.tensor( 216 | sample_dict[DECODER_INPUTS_ATTENTION_MASK] 217 | ) 218 | 219 | return sample_dict 220 | 221 | @staticmethod 222 | def convert_to_double_sorted_geneformer_sequence(scrna_sample, gene_names): 223 | """convert binned genes to double sorted GeneFormer like format. 224 | The sorting is done first over the binned expression values and then on the gene names 225 | This is achieved by zipping together the minus the bin (so to sort it from large to small) 226 | and the standardized gene name. 227 | sample.data are the non-zero values of the raw, sample.indices are the indexes for these values 228 | 229 | 230 | Args: 231 | sample: Dataframe with gene bins and matching indexes 232 | gene_names (list[str]):list of gene names matching the list above 233 | 234 | Returns: 235 | list[str] - gene names sorted by bin values and then by gene name 236 | """ 237 | return [ 238 | a[1] 239 | for a in sorted(zip(-scrna_sample.data, gene_names[scrna_sample.indices])) 240 | ] 241 | 242 | @staticmethod 243 | def process_model_output( 244 | tokenizer_op: ModularTokenizerOp, 245 | decoder_output: np.ndarray, 246 | decoder_output_scores: np.ndarray, 247 | ) -> dict | None: 248 | ans = None 249 | all_class_label_ids = [ 250 | tokenizer_op.get_token_id(class_label) for class_label in ALL_CLASS_LABELS 251 | ] 252 | classification_position = 1 253 | if decoder_output_scores is not None: 254 | class_scores = decoder_output_scores[classification_position][ 255 | all_class_label_ids 256 | ] 257 | best_match = class_scores.argmax() 258 | non_normalized_score = class_scores[best_match] 259 | normalization_factor = class_scores.sum() 260 | normalized_score = non_normalized_score / ( 261 | normalization_factor + 1e-30 262 | ) # incase non seem to match 263 | ans = { 264 | "cell_type": ALL_CLASS_LABELS[best_match], 265 | "pred": all_class_label_ids[best_match], 266 | "not_normalized_scores": non_normalized_score, 267 | "normalized_scores": normalized_score, 268 | } 269 | 270 | return ans 271 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![arXiv](https://img.shields.io/badge/arXiv-2410.22367-b31b1b.svg)](https://arxiv.org/abs/2410.22367) 2 | [![Open Source](https://badges.frapsoft.com/os/v1/open-source.svg)](https://opensource.org/) 3 | ![PyPI Downloads](https://static.pepy.tech/badge/biomed-multi-alignment) 4 | [![GitHub Clones](https://img.shields.io/badge/dynamic/json?color=success&label=Clone&query=count&url=https://gist.githubusercontent.com/mosheraboh/a19913f8cf752e05e84f0d09d997a403/raw/clone.json&logo=github)](https://github.com/MShawon/github-clone-count-badge) 5 | 6 | 7 | # biomed-multi-alignment 8 | 9 | **Update - MCP is now supported for MAMMAL agent integration visit [here](./mammal_mcp/README.md) for more information.** 10 | 11 | We introduce [**ibm/biomed.omics.bl.sm.ma-ted-458m**](https://arxiv.org/abs/2410.22367). 12 | A biomedical foundation model trained on over 2 billion biological samples across multiple modalities, including proteins, small molecules, and single-cell gene expression data. 
13 | Designed for robust performance, it achieves state-of-the-art results on a variety of tasks across the entire drug discovery pipeline and diverse biomedical domains. 14 | 15 | The model is based on **MAMMAL** (**M**olecular **A**ligned **M**ulti-**M**odal **A**rchitecture and **L**anguage), a flexible, multi-domain architecture with an adaptable task prompt syntax. 16 | The syntax allows for dynamic combinations of tokens and scalars, enabling classification, regression, and generation tasks either within a single domain or with cross-domain entities. 17 | 18 | The model weights are stored at https://huggingface.co/ibm/biomed.omics.bl.sm.ma-ted-458m and the MAMMAL core code together with fine-tuning and inference can be found in this repo. 19 | Learn more by reading our [pre-print](https://arxiv.org/abs/2410.22367). 20 | 21 | ![Alt text](mammal.png) 22 | 23 | 24 | 25 | ## Installation 26 | MAMMAL is tested on Python >= 3.10 and PyTorch >= 2.0 27 | 28 | Follow the next steps to install MAMMAL in a new environment: 29 | 1. Create conda environment: 30 | ``` 31 | conda create -n mammal_env python=3.10 -y 32 | conda activate mammal_env 33 | ``` 34 | 35 | 2. Install PyTorch: (see [here](https://pytorch.org/get-started/locally/)). For example: 36 | ``` 37 | conda install pytorch pytorch-cuda=12.1 -c pytorch -c nvidia 38 | ``` 39 | 40 | 3. Install the package in [editable mode](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs) using `pip`: 41 | ``` 42 | git clone git@github.com:BiomedSciAI/biomed-multi-alignment.git 43 | pip install -e ./biomed-multi-alignment[examples] 44 | ``` 45 | 46 | Another option is to install directly from PyPI: 47 | ``` 48 | pip install biomed-multi-alignment[examples] 49 | ``` 50 | 51 | 52 | 53 | 54 | # Examples 55 | We provide a variety of example tasks, covering one from each domain as well as a multi-domain task. To facilitate easy setup, we've selected tasks with datasets that can be automatically downloaded and come with established data splits. 56 | While these tasks may not necessarily have State-of-the-Art results we can compare to, they offer practical demonstrations of model application. 57 | 58 | Additionally, since the pre-trained model was also trained on a protein-protein interaction task, we demonstrate inference using this task with ibm/biomed.omics.bl.sm.ma-ted-458m. 
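All of the inference examples in this section start from the same two objects: the pre-trained MAMMAL model and its modular tokenizer. As a quick orientation, the shared loading step is shown below — a minimal sketch that simply restates the loading calls and checkpoint name from the full Protein-Protein Interaction example in the next section (no new names or APIs are introduced here):

```python
from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp

from mammal.model import Mammal

# Download (or load from the local cache) the pre-trained MAMMAL backbone
# and switch it to evaluation mode for inference
model = Mammal.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m")
model.eval()

# Load the matching modular tokenizer used to build the task prompts
tokenizer_op = ModularTokenizerOp.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m")
```

The Protein-Protein Interaction example in the next section builds directly on these two objects, adding only prompt construction, tokenization, and generation.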
59 | 60 | ## Protein-Protein Interaction 61 | **This example is also available as a [Google Colab notebook](https://colab.research.google.com/github/BiomedSciAI/biomed-multi-alignment/blob/main/tutorials/begginer_inference.ipynb).** 62 | A simple example of a task already supported by `ibm/biomed.omics.bl.sm.ma-ted-458m`: 63 | ```python 64 | import torch 65 | from fuse.data.tokenizers.modular_tokenizer.op import ModularTokenizerOp 66 | from mammal.model import Mammal 67 | from mammal.keys import * 68 | 69 | # Load Model 70 | model = Mammal.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m") 71 | # Set model to evaluation mode 72 | model.eval() 73 | 74 | # Load Tokenizer 75 | tokenizer_op = ModularTokenizerOp.from_pretrained("ibm/biomed.omics.bl.sm.ma-ted-458m") 76 | 77 | # Prepare Input Prompt 78 | protein_calmodulin = "MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMISELDQDGFIDKEDLHDGDGKISFEEFLNLVNKEMTADVDGDGQVNYEEFVTMMTSK" 79 | protein_calcineurin = "MSSKLLLAGLDIERVLAEKNFYKEWDTWIIEAMNVGDEEVDRIKEFKEDEIFEEAKTLGTAEMQEYKKQKLEEAIEGAFDIFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIRQMWDQNGDWDRIKELKFGEIKKLSAKDTRGTIFIKVFENLGTGVDSEYEDVSKYMLKHQ" 80 | 81 | # Create and load sample 82 | sample_dict = dict() 83 | # Formatting prompt to match pre-training syntax 84 | sample_dict[ENCODER_INPUTS_STR] = f"<@TOKENIZER-TYPE=AA>{protein_calmodulin}{protein_calcineurin}" 85 | 86 | # Tokenize 87 | tokenizer_op( 88 | sample_dict=sample_dict, 89 | key_in=ENCODER_INPUTS_STR, 90 | key_out_tokens_ids=ENCODER_INPUTS_TOKENS, 91 | key_out_attention_mask=ENCODER_INPUTS_ATTENTION_MASK, 92 | ) 93 | sample_dict[ENCODER_INPUTS_TOKENS] = torch.tensor(sample_dict[ENCODER_INPUTS_TOKENS]) 94 | sample_dict[ENCODER_INPUTS_ATTENTION_MASK] = torch.tensor(sample_dict[ENCODER_INPUTS_ATTENTION_MASK]) 95 | 96 | # Generate Prediction 97 | batch_dict = model.generate( 98 | [sample_dict], 99 | output_scores=True, 100 | return_dict_in_generate=True, 101 | max_new_tokens=5, 102 | ) 103 | 104 | # Get output 105 | generated_output = tokenizer_op._tokenizer.decode(batch_dict[CLS_PRED][0]) 106 | print(f"{generated_output=}") 107 | ``` 108 | 109 | ## Protein solubility prediction 110 | Protein solubility is a critical factor in both pharmaceutical research and production processes, as it can significantly impact the quality and function of a protein. 111 | This is an example of finetuning `ibm/biomed.omics.bl.sm.ma-ted-458m` for protein solubility prediction (binary classification) based solely on the amino acid sequence. 112 | The benchmark is defined in: https://academic.oup.com/bioinformatics/article/34/15/2605/4938490 113 | Data was retrieved from: https://zenodo.org/records/1162886 114 | 115 | 116 | ### Finetune 117 | To finetune from pre-trained MAMMAL, run the following command: 118 | ``` 119 | python mammal/main_finetune.py --config-name config.yaml --config-path examples/protein_solubility 120 | ``` 121 | ### Inference 122 | To run inference, run the following command: 123 | ``` 124 | python mammal/examples/protein_solubility/main_infer.py 125 | ``` 126 | ### Evaluation 127 | To run the evaluation, use the following command: 128 | ``` 129 | python mammal/main_finetune.py --config-name config.yaml --config-path examples/protein_solubility evaluate=True model.pretrained_kwargs.pretrained_model_name_or_path=/best_epoch.ckpt 130 | ``` 131 | 132 | ## Drug carcinogenicity prediction 133 | A binary classification [TDC task](https://tdcommons.ai/single_pred_tasks/tox/#carcinogens). Given a drug SMILES string, predict whether it can cause cancer.
134 | > A carcinogen is any substance, radionuclide, or radiation that promotes carcinogenesis, the formation of cancer. This may be due to the ability to damage the genome or to the disruption of cellular metabolic processes. 135 | 136 | ### Finetune 137 | To finetune from pre-trained MAMMAL, run the following command: 138 | ``` 139 | python mammal/main_finetune.py --config-name config.yaml --config-path examples/carcinogenicity 140 | ``` 141 | ### Inference 142 | To run inference, run the following command: 143 | ``` 144 | # python mammal/examples/carcinogenicity/main_infer.py 145 | python mammal/examples/carcinogenicity/main_infer.py ./carcinogenicity_finetune "CC(CCl)OC(C)CCl" 146 | ``` 147 | 148 | ## Drug Target Interaction 149 | Accurate prediction of drug-target binding affinity is essential in the early stages of drug discovery. 150 | This is an example of finetuning `ibm/biomed.omics.bl.sm.ma-ted-458m` for this task. 151 | The task is to predict binding affinity as pKd, the negative logarithm of the dissociation constant, which reflects the strength of the interaction between a small molecule (drug) and a protein (target). 152 | The expected inputs for the model are the amino acid sequence of the target and the SMILES representation of the drug. 153 | The benchmark is defined at: https://tdcommons.ai/multi_pred_tasks/dti/ 154 | We also harmonize the affinity values using `data.harmonize_affinities(mode = 'max_affinity')` and transform them to log scale. 155 | By default, we use the Drug+Target cold split provided by tdcommons. 156 | 157 | ### Finetune 158 | To finetune from pre-trained MAMMAL, run the following command: 159 | ``` 160 | python mammal/main_finetune.py --config-name config.yaml --config-path examples/dti_bindingdb_kd 161 | ``` 162 | ### Inference 163 | To run inference, run the following command: 164 | ``` 165 | python mammal/examples/dti_bindingdb_kd/main_infer.py 166 | ``` 167 | The values used here should match those specified in the finetuning configuration file (config.yaml). 168 | ### Evaluation 169 | To run the evaluation, run the following command: 170 | ``` 171 | python mammal/main_finetune.py --config-name config.yaml --config-path examples/dti_bindingdb_kd evaluate=True model.pretrained_kwargs.pretrained_model_name_or_path=/best_epoch.ckpt 172 | ``` 173 | 174 | # Modular Tokenizer 175 | Since many of the tasks on which **ibm/biomed.omics.bl.sm.ma-ted-458m** is trained use different modalities (amino acid sequences, SMILES, gene expressions, etc.), we implemented a modular tokenizer that can combine multiple tokenizers, mapping their dictionaries to a consistent ID space (https://github.com/BiomedSciAI/fuse-med-ml/tree/master/fuse/data/tokenizers/modular_tokenizer). 176 | 177 | # Tutorials 178 | If you are interested in a specific guide / tutorial, feel free to [open an issue](https://github.com/BiomedSciAI/biomed-multi-alignment/issues/new). 179 | 180 | ### Advanced 181 | * Create a new Mammal task. [[link](./tutorials/advanced_create_new_task.ipynb)] 182 | 183 | 184 | # Citations 185 | If you find our work useful for your research, we ask you to cite the relevant papers: 186 | ``` 187 | @misc{shoshan2024mammalmolecularaligned, 188 | title={MAMMAL -- Molecular Aligned Multi-Modal Architecture and Language}, 189 | author={Yoel Shoshan and Moshiko Raboh and Michal Ozery-Flato and Vadim Ratner and Alex Golts and Jeffrey K. Weber and Ella Barkan and Simona Rabinovici-Cohen and Sagi Polaczek and Ido Amos and Ben Shapira and Liam Hazan and Matan Ninio and Sivan Ravid and Michael M.
Danziger and Joseph A. Morrone and Parthasarathy Suryanarayanan and Michal Rosen-Zvi and Efrat Hexter}, 190 | year={2024}, 191 | eprint={2410.22367}, 192 | archivePrefix={arXiv}, 193 | primaryClass={q-bio.QM}, 194 | url={https://arxiv.org/abs/2410.22367}, 195 | } 196 | ``` 197 | --------------------------------------------------------------------------------