├── ContraNovo
│   ├── __init__.py
│   ├── version.py
│   ├── data
│   │   ├── __init__.py
│   │   └── datasets.py
│   ├── denovo
│   │   ├── __init__.py
│   │   ├── db_dataset.py
│   │   ├── db_index.py
│   │   ├── dataloaders.py
│   │   ├── db_dataloader.py
│   │   ├── parser2.py
│   │   ├── evaluate.py
│   │   ├── clipmodel.py
│   │   └── model_runner.py
│   ├── components
│   │   ├── __init__.py
│   │   ├── mixins.py
│   │   ├── feedforward.py
│   │   ├── encoders.py
│   │   └── transformers.py
│   ├── utils.py
│   ├── utils2.py
│   ├── config.yaml
│   ├── masses.py
│   ├── config.py
│   └── ContraNovo.py
├── Model.png
├── README.md
└── environment.yml

/ContraNovo/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/ContraNovo/version.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/ContraNovo/data/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/ContraNovo/denovo/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/Model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BEAM-Labs/ContraNovo/HEAD/Model.png
--------------------------------------------------------------------------------
/ContraNovo/components/__init__.py:
--------------------------------------------------------------------------------
"""Components for building cool models"""
from .transformers import SpectrumEncoder, PeptideDecoder
from .feedforward import FeedForward
from .mixins import ModelMixin
--------------------------------------------------------------------------------
/ContraNovo/components/mixins.py:
--------------------------------------------------------------------------------
"""Useful mixins for model classes"""
import pandas as pd


class ModelMixin:
    """Add some useful methods for models."""

    def __init__(self):
        """Initialize the ModelMixin"""
        self._history = []

    @property
    def history(self):
        """The training history of a model."""
        return pd.DataFrame(self._history)

    @property
    def n_parameters(self):
        """The number of learnable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
--------------------------------------------------------------------------------
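The mixin assumes it is combined with a `torch.nn.Module`. A minimal sketch (the `TinyModel` class is hypothetical, not part of the repository):

```python
import torch

from ContraNovo.components import ModelMixin


class TinyModel(ModelMixin, torch.nn.Module):
    """A hypothetical model that gains `history` and `n_parameters`."""

    def __init__(self):
        # ModelMixin.__init__ does not call super().__init__(), so initialize
        # both bases explicitly before registering any submodules.
        torch.nn.Module.__init__(self)
        ModelMixin.__init__(self)
        self.linear = torch.nn.Linear(10, 2)


model = TinyModel()
print(model.n_parameters)  # 22 = 10 * 2 weights + 2 biases
print(model.history)       # an empty pandas DataFrame until rows are appended
```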
/README.md:
--------------------------------------------------------------------------------
## ContraNovo: A Contrastive Learning Approach to Enhance De Novo Peptide Sequencing

This repository contains inference code for **A Contrastive Learning Approach to Enhance De Novo Peptide Sequencing (ContraNovo)**.

![Model](./Model.png)

#### Reproduction Steps

- Get the source code and create the conda environment.

```
git clone git@github.com:BEAM-Labs/ContraNovo.git
cd ContraNovo
conda env create -f environment.yml
conda activate ContraNovo
```

- Download the ContraNovo checkpoint and Bacillus.10k.mgf from Google Drive.

  ContraNovo.ckpt: https://drive.google.com/file/d/1knNUqSwPf98j388Ds2E6bG8tAXx8voWR/view?usp=drive_link

  Bacillus.10k.mgf: https://drive.google.com/file/d/1HqfCETZLV9ZB-byU0pqNNRXbaPbTAceT/view?usp=drive_link

- Run the ContraNovo evaluation on Bacillus.10k.mgf.

```
python -m ContraNovo.ContraNovo --mode=eval --peak_path=./ContraNovo/bacillus.10k.mgf --model=./ContraNovo/ContraNovo.ckpt
```
--------------------------------------------------------------------------------
/ContraNovo/utils.py:
--------------------------------------------------------------------------------
"""Small utility functions"""
import os
import platform
import re
from typing import Tuple

import psutil
import torch


def n_workers() -> int:
    """
    Get the number of workers to use for data loading.

    This is the maximum number of CPUs allowed for the process, scaled for the
    number of GPUs being used.

    On Windows and macOS, we only use the main process. See:
    https://discuss.pytorch.org/t/errors-when-using-num-workers-0-in-dataloader/97564/4
    https://github.com/pytorch/pytorch/issues/70344

    Returns
    -------
    int
        The number of workers.
    """
    # Windows or macOS: no multiprocessing.
    if platform.system() in ["Windows", "Darwin"]:
        return 0
    # Linux: scale the number of workers by the number of GPUs (if present).
    try:
        n_cpu = len(psutil.Process().cpu_affinity())
    except AttributeError:
        n_cpu = os.cpu_count()
    return (
        n_cpu // n_gpu if (n_gpu := torch.cuda.device_count()) > 1 else n_cpu
    )


def split_version(version: str) -> Tuple[str, str, str]:
    """
    Split the version into its semantic versioning components.

    Parameters
    ----------
    version : str
        The version number.

    Returns
    -------
    major : str
        The major release.
    minor : str
        The minor release.
    patch : str
        The patch release.
    """
    version_regex = re.compile(r"(\d+)\.(\d+)\.*(\d*)(?:.dev\d+.+)?")
    return tuple(g for g in version_regex.match(version).groups())
--------------------------------------------------------------------------------
/ContraNovo/components/feedforward.py:
--------------------------------------------------------------------------------
"""A flexible feed-forward neural network."""
import torch
import numpy as np


class FeedForward(torch.nn.Module):
    """Create a feed-forward neural net with leaky ReLU activations

    Parameters
    ----------
    in_dim : int
        The input dimensionality.
    out_dim : int
        The output dimensionality.
    layers : int or tuple of int.
        If an int, layer sizes are linearly interpolated between the input and
        output dimensions using this number of layers. Otherwise, each element
        specifies the size of a layer.
    dropout : float, optional
        If greater than zero, add dropout layers with the specified
        probability.
    append : torch.nn.Module or None, optional
        A final layer to append, such as a sigmoid or tanh.
    """

    def __init__(self, in_dim, out_dim, layers, dropout=0, append=None):
        """Initialize a FeedForward network"""
        super().__init__()
        try:
            sizes = np.array([in_dim] + list(layers) + [out_dim])
        except TypeError:
            sizes = np.ceil(np.linspace(in_dim, out_dim, int(layers) + 1))

        sizes = sizes.astype(int)
        stack = []
        for idx in range(len(sizes) - 1):
            stack.append(torch.nn.Linear(sizes[idx], sizes[idx + 1]))
            if idx < len(sizes) - 2:
                stack.append(torch.nn.LeakyReLU())
                if dropout:
                    stack.append(torch.nn.Dropout(dropout))

        if append is not None:
            stack.append(append)

        self.layers = torch.nn.Sequential(*stack)

    def forward(self, X):
        """The forward pass"""
        return self.layers(X)
--------------------------------------------------------------------------------
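For orientation, a small sketch of both construction modes (shapes chosen arbitrarily):

```python
import torch

from ContraNovo.components import FeedForward

# An int for `layers` interpolates sizes: 512 -> 342 -> 172 -> 2.
net = FeedForward(in_dim=512, out_dim=2, layers=3, dropout=0.1)
print(net(torch.randn(8, 512)).shape)  # torch.Size([8, 2])

# A tuple pins each hidden size explicitly, with a final sigmoid appended.
clf = FeedForward(64, 1, layers=(32, 16), append=torch.nn.Sigmoid())
```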
/ContraNovo/utils2.py:
--------------------------------------------------------------------------------
"""Common utility functions"""
import pandas as pd
from tensorboard.backend.event_processing.event_accumulator import (
    EventAccumulator,
)


def read_tensorboard_scalars(path):
    """Read scalars from Tensorboard logs.

    Parameters
    ----------
    path : str
        The path of the scalar log file.

    Returns
    -------
    pandas.DataFrame
        A dataframe containing the scalar values.
    """
    event = EventAccumulator(path)
    event.Reload()
    data = []
    for tag in event.Tags()["scalars"]:
        tag_df = pd.DataFrame(
            event.Scalars(tag), columns=["wall_time", "step", "value"]
        )
        tag_df["tag"] = tag
        data.append(tag_df)

    return pd.concat(data)


def listify(obj):
    """Turn an object into a list, but don't split strings."""
    try:
        assert not isinstance(obj, str)
        iter(obj)
    except (AssertionError, TypeError):
        obj = [obj]

    return list(obj)


# For Parameter Checking ------------------------------------------------------
def check_int(integer, name):
    """Verify that an object is an integer, or coercible to one.

    Parameters
    ----------
    integer : int
        The integer to check.
    name : str
        The name to print in the error message if it fails.
    """
    if isinstance(integer, int):
        return integer

    # Else if it is a float:
    coerced = int(integer)
    if coerced != integer:
        raise ValueError(f"'{name}' must be an integer.")

    return coerced


def check_positive_int(integer, name):
    """Verify that an object is an integer and positive.

    Parameters
    ----------
    integer : int
        The integer to check.
    name : str
        The name to print in the error message if it fails.
    """
    try:
        integer = check_int(integer, name)
        assert integer > 0
    except (ValueError, AssertionError):
        raise ValueError(f"'{name}' must be a positive integer.")

    return integer
--------------------------------------------------------------------------------
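A quick illustration of the parameter helpers (illustrative values only):

```python
from ContraNovo.utils2 import check_positive_int, listify

print(listify("peptide"))  # ['peptide'] -- strings are wrapped, not iterated
print(listify((1, 2, 3)))  # [1, 2, 3]
print(check_positive_int(5.0, "n_peaks"))  # 5 -- integral floats are coerced
# check_positive_int(0, "n_peaks") would raise:
# ValueError: 'n_peaks' must be a positive integer.
```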
/ContraNovo/config.yaml:
--------------------------------------------------------------------------------
###
# ContraNovo configuration.
# Blank entries are interpreted as "None".
###

# Random seed to ensure reproducible results.
random_seed: 200
# random_seed: -1 # Random

# Spectrum processing options.
n_peaks: 300
min_mz: 50.5
max_mz: 4500.0
min_intensity: 0.0
remove_precursor_tol: 2.0 # Da
max_charge: 10
precursor_mass_tol: 50 # ppm
isotope_error_range: [0, 1]

# Model architecture options.
dim_model: 512
n_head: 8
dim_feedforward: 1024
n_layers: 9
dropout: 0.18
dim_intensity:
custom_encoder:
max_length: 100
residues:
  "G": 57.021464
  "A": 71.037114
  "S": 87.032028
  "P": 97.052764
  "V": 99.068414
  "T": 101.047670
  "C+57.021": 160.030649 # 103.009185 + 57.021464
  "L": 113.084064
  "I": 113.084064
  "N": 114.042927
  "D": 115.026943
  "Q": 128.058578
  "K": 128.094963
  "E": 129.042593
  "M": 131.040485
  "H": 137.058912
  "F": 147.068414
  "R": 156.101111
  "Y": 163.063329
  "W": 186.079313
  # Amino acid modifications.
  "M+15.995": 147.035400 # Met oxidation:   131.040485 + 15.994915
  "N+0.984": 115.026943  # Asn deamidation: 114.042927 + 0.984016
  "Q+0.984": 129.042594  # Gln deamidation: 128.058578 + 0.984016
  # N-terminal modifications.
  "+42.011": 42.010565   # Acetylation
  "+43.006": 43.005814   # Carbamylation
  "-17.027": -17.026549  # NH3 loss
  "+43.006-17.027": 25.980265 # Carbamylation and NH3 loss
n_log: 1
tb_summarywriter:

# Neptune logger.
enable_neptune: True
neptune_project: "DeNovo/clip"
neptune_api_token:
tags: ["9-speice", "bacillus", "Lr = 0.0002 dp 0.15,0.4"]
n_nodes: 1
train_from_resume: False


# Use epochs instead of iters.
warmup_iters:
max_iters:

max_epochs: 150
warm_up_epochs: 1
learning_rate: 0.0004
weight_decay: 1e-5
gradient_clip_val: 1.5
gradient_clip_algorithm: "norm"
accumulate_grad_batches: 1
sync_batchnorm: False
SWA: False

# Training/inference options.
train_batch_size: 64
predict_batch_size: 512
n_beams: 5

logger:

num_sanity_val_steps: 0

train_from_scratch: True

save_model: True
model_save_folder_path: "./clipcasa"
save_weights_only: True
every_n_train_steps: 2500
--------------------------------------------------------------------------------
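`ContraNovo.py` reads this file directly with PyYAML; a minimal sketch of that round trip (run from the repository root):

```python
import yaml

# Blank YAML entries (e.g. warmup_iters) come back as None.
with open("ContraNovo/config.yaml") as f_in:
    config = yaml.safe_load(f_in)

print(config["dim_model"])      # 512
print(config["warmup_iters"])   # None
print(config["residues"]["G"])  # 57.021464
```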
/ContraNovo/masses.py:
--------------------------------------------------------------------------------
"""Amino acid masses and other useful mass spectrometry calculations"""
import re


class PeptideMass:
    """A simple class for calculating peptide masses

    Parameters
    ----------
    residues : Dict or str {"massivekb", "canonical"}, optional
        The amino acid dictionary and their masses. By default this is only
        the 20 canonical amino acids, with cysteine carbamidomethylated. If
        "massivekb", this dictionary will include the modifications found in
        MassIVE-KB. Additionally, a dictionary can be used to specify a custom
        collection of amino acids and masses.
    """

    canonical = {
        "G": 57.021463735,
        "A": 71.037113805,
        "S": 87.032028435,
        "P": 97.052763875,
        "V": 99.068413945,
        "T": 101.047678505,
        "C+57.021": 103.009184505 + 57.02146,
        "L": 113.084064015,
        "I": 113.084064015,
        "N": 114.042927470,
        "D": 115.026943065,
        "Q": 128.058577540,
        "K": 128.094963050,
        "E": 129.042593135,
        "M": 131.040484645,
        "H": 137.058911875,
        "F": 147.068413945,
        # "U": 150.953633405,
        "R": 156.101111050,
        "Y": 163.063328575,
        "W": 186.079312980,
        # "O": 237.147726925,
    }

    # Modifications found in MassIVE-KB
    massivekb = {
        # N-terminal mods:
        "+42.011": 42.010565,  # Acetylation
        "+43.006": 43.005814,  # Carbamylation
        "-17.027": -17.026549,  # NH3 loss
        "+43.006-17.027": (43.006814 - 17.026549),
        # AA mods:
        "M+15.995": canonical["M"] + 15.994915,  # Met Oxidation
        "N+0.984": canonical["N"] + 0.984016,  # Asn Deamidation
        "Q+0.984": canonical["Q"] + 0.984016,  # Gln Deamidation
    }

    # Constants
    hydrogen = 1.007825035
    oxygen = 15.99491463
    h2o = 2 * hydrogen + oxygen
    proton = 1.00727646688

    def __init__(self, residues="canonical"):
        """Initialize the PeptideMass object"""
        if residues == "canonical":
            self.masses = self.canonical
        elif residues == "massivekb":
            self.masses = self.canonical
            self.masses.update(self.massivekb)
        else:
            self.masses = residues

    def __len__(self):
        """Return the length of the residue dictionary"""
        return len(self.masses)

    def mass(self, seq, charge=None):
        """Calculate a peptide's mass or m/z.

        Parameters
        ----------
        seq : list or str
            The peptide sequence, using tokens defined in ``self.residues``.
        charge : int, optional
            The charge used to compute m/z. Otherwise the neutral peptide
            mass is calculated.

        Returns
        -------
        float
            The computed mass or m/z.
        """
        if isinstance(seq, str):
            seq = re.split(r"(?<=.)(?=[A-Z])", seq)

        calc_mass = sum([self.masses[aa] for aa in seq]) + self.h2o
        if charge is not None:
            calc_mass = (calc_mass / charge) + self.proton

        return calc_mass
--------------------------------------------------------------------------------
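A short usage sketch; the expected values are approximate monoisotopic masses:

```python
from ContraNovo.masses import PeptideMass

pm = PeptideMass("canonical")
print(pm.mass("PEPTIDE"))            # neutral mass, ~799.36 Da
print(pm.mass("PEPTIDE", charge=2))  # m/z of the 2+ ion, ~400.69
```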
/ContraNovo/config.py:
--------------------------------------------------------------------------------
"""Parse the YAML configuration."""
import logging
from pathlib import Path
from typing import Optional, Dict, Callable, Tuple, Union

import yaml
import torch

from . import utils

logger = logging.getLogger("ContraNovo")


class Config:
    """The ContraNovo configuration, merging user options over the defaults."""

    _default_config = Path(__file__).parent / "config.yaml"
    _config_types = dict(
        random_seed=int,
        n_peaks=int,
        min_mz=float,
        max_mz=float,
        min_intensity=float,
        remove_precursor_tol=float,
        max_charge=int,
        precursor_mass_tol=float,
        isotope_error_range=lambda min_max: (int(min_max[0]), int(min_max[1])),
        min_peptide_len=int,
        dim_model=int,
        n_head=int,
        dim_feedforward=int,
        n_layers=int,
        dropout=float,
        dim_intensity=int,
        max_length=int,
        n_log=int,
        warmup_iters=int,
        max_iters=int,
        learning_rate=float,
        weight_decay=float,
        train_batch_size=int,
        predict_batch_size=int,
        n_beams=int,
        top_match=int,
        max_epochs=int,
        num_sanity_val_steps=int,
        train_from_scratch=bool,
        save_model=bool,
        model_save_folder_path=str,
        save_weights_only=bool,
        every_n_train_steps=int,
        no_gpu=bool,
    )

    def __init__(self, config_file: Optional[str] = None):
        """Initialize a Config object."""
        self.file = str(config_file) if config_file is not None else "default"
        with self._default_config.open() as f_in:
            self._params = yaml.safe_load(f_in)

        if config_file is None:
            self._user_config = {}
        else:
            with Path(config_file).open() as f_in:
                self._user_config = yaml.safe_load(f_in)

        # Validate:
        for key, val in self._config_types.items():
            self.validate_param(key, val)

        # Add extra configuration options and scale by the number of GPUs.
        n_gpus = 0 if self["no_gpu"] else torch.cuda.device_count()
        self._params["n_workers"] = utils.n_workers()
        if n_gpus > 1:
            self._params["train_batch_size"] = (
                self["train_batch_size"] // n_gpus
            )

    def __getitem__(self, param: str) -> Union[int, bool, str, Tuple, Dict]:
        """Retrieve a parameter"""
        return self._params[param]

    def __getattr__(self, param: str) -> Union[int, bool, str, Tuple, Dict]:
        """Retrieve a parameter"""
        return self._params[param]

    def validate_param(self, param: str, param_type: Callable):
        """Verify a parameter is the correct type.

        Parameters
        ----------
        param : str
            The ContraNovo parameter
        param_type : Callable
            The expected callable type of the parameter.
        """
        try:
            param_val = self._user_config.get(param, self._params[param])
            if param == "residues":
                residues = {
                    str(aa): float(mass) for aa, mass in param_val.items()
                }
                self._params["residues"] = residues
            elif param_val is not None:
                self._params[param] = param_type(param_val)
        except (TypeError, ValueError) as err:
            logger.error(
                "Incorrect type for configuration value %s: %s", param, err
            )
            raise TypeError(
                f"Incorrect type for configuration value {param}: {err}"
            )

    def items(self) -> Tuple[str, ...]:
        """Return the parameters"""
        return self._params.items()
--------------------------------------------------------------------------------
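A sketch of the intended access pattern (the file name is hypothetical). Note a pitfall visible in the code above: `_config_types` validates keys such as `min_peptide_len`, `top_match`, and `no_gpu` that the shipped `config.yaml` does not define, so `validate_param`'s `self._params[param]` lookup raises a `KeyError` unless those keys are added to the YAML:

```python
from ContraNovo.config import Config

# Defaults come from ContraNovo/config.yaml; a user file only lists overrides.
# As noted above, this requires every validated key to exist in the defaults.
config = Config("my_overrides.yaml")
print(config["dim_model"])  # item-style access via __getitem__
print(config.dim_model)     # attribute-style access via __getattr__
```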
/environment.yml:
--------------------------------------------------------------------------------
name: ContraNovo
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - ca-certificates=2023.05.30=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.10=h7f8727e_0
  - pip=23.2.1=py39h06a4308_0
  - python=3.9.17=h955ad1f_0
  - readline=8.2=h5eee18b_0
  - setuptools=68.0.0=py39h06a4308_0
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - wheel=0.38.4=py39h06a4308_0
  - xz=5.4.2=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
      - absl-py==1.4.0
      - aiohttp==3.8.5
      - aiosignal==1.3.1
      - appdirs==1.4.4
      - arrow==1.2.3
      - async-timeout==4.0.2
      - attrs==23.1.0
      - boto3==1.28.33
      - botocore==1.31.33
      - bravado==11.0.3
      - bravado-core==6.1.0
      - cachetools==5.3.1
      - certifi==2023.7.22
      - cffi==1.15.1
      - charset-normalizer==3.2.0
      - click==8.1.6
      - contourpy==1.1.0
      - cryptography==41.0.3
      - cycler==0.11.0
      - deprecated==1.2.14
      - einops==0.6.1
      - fastobo==0.12.2
      - filelock==3.13.1
      - fonttools==4.42.0
      - fqdn==1.5.1
      - frozenlist==1.4.0
      - fsspec==2023.6.0
      - future==0.18.3
      - gitdb==4.0.10
      - gitpython==3.1.32
      - google-auth==2.22.0
      - google-auth-oauthlib==1.0.0
      - grpcio==1.56.2
      - h5py==3.9.0
      - idna==3.4
      - importlib-metadata==6.8.0
      - importlib-resources==6.0.1
      - isoduration==20.11.0
      - jinja2==3.1.2
      - jmespath==1.0.1
      - joblib==1.3.1
      - jsonpointer==2.4
      - jsonref==1.1.0
      - jsonschema==4.19.0
      - jsonschema-specifications==2023.7.1
      - kiwisolver==1.4.4
      - lark==1.1.7
      - lightning-utilities==0.9.0
      - llvmlite==0.40.1
      - lmdb==1.4.1
      - lxml==4.9.3
      - markdown==3.4.4
      - markupsafe==2.1.3
      - matplotlib==3.7.2
      - monotonic==1.6
      - mpmath==1.3.0
      - msgpack==1.0.5
      - multidict==6.0.4
      - natsort==8.4.0
      - neptune==1.6.1
      - networkx==3.2.1
      - numba==0.57.1
      - numpy==1.24.4
      - nvidia-cublas-cu12==12.1.3.1
      - nvidia-cuda-cupti-cu12==12.1.105
      - nvidia-cuda-nvrtc-cu12==12.1.105
      - nvidia-cuda-runtime-cu12==12.1.105
      - nvidia-cudnn-cu12==8.9.2.26
      - nvidia-cufft-cu12==11.0.2.54
      - nvidia-curand-cu12==10.3.2.106
      - nvidia-cusolver-cu12==11.4.5.107
      - nvidia-cusparse-cu12==12.1.0.106
      - nvidia-nccl-cu12==2.18.1
      - nvidia-nvjitlink-cu12==12.3.101
      - nvidia-nvtx-cu12==12.1.105
      - oauthlib==3.2.2
      - packaging==23.1
      - pandas==2.0.3
      - pillow==10.0.0
      - protobuf==4.23.4
      - psutil==5.9.5
      - pyasn1==0.5.0
      - pyasn1-modules==0.3.0
      - pycparser==2.21
      - pygithub==2.1.1
      - pyjwt==2.8.0
      - pynacl==1.5.0
      - pyparsing==3.0.9
      - pyteomics==4.6
      - python-dateutil==2.8.2
      - pytorch-lightning==1.9.5
      - pytz==2023.3
      - pyyaml==6.0.1
      - referencing==0.30.2
      - requests==2.31.0
      - requests-oauthlib==1.3.1
      - rfc3339-validator==0.1.4
      - rfc3987==1.3.8
      - rpds-py==0.9.2
      - rsa==4.9
      - s3transfer==0.6.2
      - scikit-learn==1.3.0
      - scipy==1.11.1
      - simplejson==3.19.1
      - six==1.16.0
      - smmap==5.0.0
      - spectrum-utils==0.4.2
      - swagger-spec-validator==3.0.3
      - sympy==1.12
      - tensorboard==2.13.0
      - tensorboard-data-server==0.7.1
      - threadpoolctl==3.2.0
      - tqdm==4.65.1
      - triton==2.1.0
      - typing-extensions==4.7.1
      - tzdata==2023.3
      - uri-template==1.3.0
      - urllib3==1.26.16
      - webcolors==1.13
      - websocket-client==1.6.2
      - werkzeug==2.3.6
      - wrapt==1.15.0
      - yarl==1.9.2
      - zipp==3.16.2
--------------------------------------------------------------------------------
/ContraNovo/denovo/db_dataset.py:
--------------------------------------------------------------------------------
import torch
from typing import Optional, Tuple
from .db_index import DB_Index
import numpy as np
import spectrum_utils.spectrum as sus
from torch.utils.data import Dataset


def cumsum(it):
    total = 0
    for x in it:
        total += x
        yield total


class DbDataset(Dataset):
    '''
    Read, write, and manage multiple DB files, and process the (peak) data
    from those files.
    '''

    def __init__(self, db_indexs, n_peaks: int = 150,
                 min_mz: float = 140.0,
                 max_mz: float = 2500.0,
                 min_intensity: float = 0.01,
                 remove_precursor_tol: float = 2.0,
                 random_state: Optional[int] = None):
        super().__init__()
        self.n_peaks = n_peaks
        self.min_mz = min_mz
        self.max_mz = max_mz
        self.min_intensity = min_intensity
        self.remove_precursor_tol = remove_precursor_tol
        self.rng = np.random.default_rng(random_state)
        self._indexs = db_indexs

    def __len__(self):
        return self.n_spectra

    def __getitem__(self, idx):
        for i, each_offset in enumerate(self.offset):
            if idx < each_offset:
                if i == 0:
                    new_idx = idx
                else:
                    new_idx = idx - self.offset[i - 1]
                mz_array, int_array, precursor_mz, precursor_charge, peptide = self.indexs[i][new_idx]
                spectrum = self._process_peaks(np.array(mz_array), np.array(int_array), precursor_mz, precursor_charge)

                return spectrum, precursor_mz, precursor_charge, peptide.replace("pyro-", "-17.027")

    def _process_peaks(
        self,
        mz_array: np.ndarray,
        int_array: np.ndarray,
        precursor_mz: float,
        precursor_charge: int,
    ) -> torch.Tensor:
        """
        Preprocess the spectrum by removing noise peaks and scaling the peak
        intensities.

        Parameters
        ----------
        mz_array : numpy.ndarray of shape (n_peaks,)
            The spectrum peak m/z values.
        int_array : numpy.ndarray of shape (n_peaks,)
            The spectrum peak intensity values.
        precursor_mz : float
            The precursor m/z.
        precursor_charge : int
            The precursor charge.

        Returns
        -------
        torch.Tensor of shape (n_peaks, 2)
            A tensor of the spectrum with the m/z and intensity peak values.
        """
        spectrum = sus.MsmsSpectrum(
            "",
            precursor_mz,
            precursor_charge,
            mz_array.astype(np.float64),
            int_array.astype(np.float32),
        )
        try:
            spectrum.set_mz_range(self.min_mz, self.max_mz)
            if len(spectrum.mz) == 0:
                raise ValueError
            spectrum.remove_precursor_peak(self.remove_precursor_tol, "Da")
            if len(spectrum.mz) == 0:
                raise ValueError
            spectrum.filter_intensity(self.min_intensity, self.n_peaks)
            if len(spectrum.mz) == 0:
                raise ValueError
            spectrum.scale_intensity("root", 1)
            intensities = spectrum.intensity / np.linalg.norm(
                spectrum.intensity
            )
            return torch.tensor(np.array([spectrum.mz, intensities])).T.float()
        except ValueError:
            # Replace invalid spectra by a dummy spectrum.
            return torch.tensor([[0, 1]]).float()

    @property
    def offset(self):
        sizes_list = []
        for each in self.indexs:
            sizes_list.append(each.n_spectra)
        return list(cumsum(sizes_list))

    @property
    def n_spectra(self) -> int:
        """The total number of spectra."""
        total = 0
        for each in self.indexs:
            total += each.n_spectra
        return total

    @property
    def indexs(self):
        """The underlying SpectrumIndex."""
        return self._indexs

    @property
    def rng(self):
        """The NumPy random number generator."""
        return self._rng

    @rng.setter
    def rng(self, seed):
        """Set the NumPy random number generator."""
        self._rng = np.random.default_rng(seed)
--------------------------------------------------------------------------------
/ContraNovo/denovo/db_index.py:
--------------------------------------------------------------------------------
import lmdb
import numpy as np
import logging
from pathlib import Path
from .parser2 import MzmlParser, MzxmlParser, MgfParser
import os
import pickle

LOGGER = logging.getLogger(__name__)


def listify(obj):
    """Turn an object into a list, but don't split strings."""
    try:
        assert not isinstance(obj, str)
        iter(obj)
    except (AssertionError, TypeError):
        obj = [obj]

    return list(obj)


class DB_Index:
    '''
    Read, store, and manage (read and write) a single LMDB file.
    '''

    def __init__(self, db_path, filenames=None, ms_level=2, valid_charge=None, annotated=True, lock=True):
        '''
        self.n_spectra is the current index to write;
        e.g., if self.n_spectra = 0, we update(0, spectra).
        '''
        # Can't currently overwrite if there is already a DB file in the path.
        index_path = Path(db_path)
        self.db_path = index_path
        self.lock = lock

        self.filenames = filenames
        self.ms_level = ms_level
        self.valid_charge = valid_charge
        self.annotated = bool(annotated)
        is_exist = index_path.exists()

        self._init_db()

        # Create a new DB.
        if not is_exist:
            txn = self.env.begin(write=True)
            txn.put("ms_level".encode(), str(self.ms_level).encode())
            txn.put("n_spectra".encode(), str(0).encode())
            txn.put("n_peaks".encode(), str(0).encode())
            self.n_spectra = 0
            # txn.put("annotated".encode(), str(int(self.annotated)).encode())
            # annotated_read = bool(txn.get("annotated".encode()).decode())
            txn.commit()
        else:
            txn = self.env.begin()
            self.n_spectra = int(txn.get("n_spectra".encode()).decode())
            ms_level_read = int(txn.get("ms_level".encode()).decode())
            # annotated_read = bool(txn.get("annotated".encode()).decode())
            try:
                assert ms_level_read == self.ms_level
                # assert annotated_read == self.annotated
            except AssertionError:
                raise ValueError(f"{self.db_path} already exists, but its ms_level/annotated is inconsistent with the input parameters!")
        if filenames is not None:
            filenames = listify(filenames)
            LOGGER.info("Reading %i files...", len(filenames))
            for ms_file in filenames:
                self.add_file(ms_file)

    def write_to_db(self, parser, n):
        # Write all spectra in the parser to the LMDB.
        txn = self.env.begin(write=True)
        assert n == len(parser.precursor_charge)
        for i in range(len(parser.precursor_charge)):
            # Record layout: precursor m/z, precursor charge, number of peaks,
            # m/z_1, ..., m/z_n, i_1, ..., i_n (remember to round the charge).
            collection_data = {"precursor_mz": parser.precursor_mz[i], "precursor_charge": parser.precursor_charge[i],
                               "mz_array": parser.mz_arrays[i], "intensity_array": parser.intensity_arrays[i]}

            collection_data["pep"] = parser.annotations[i]

            buffers = pickle.dumps(collection_data)

            txn.put(str(self.n_spectra).encode(), buffers)
            self.n_spectra += 1
        txn.put("n_spectra".encode(), str(self.n_spectra).encode())
        txn.commit()

    def __getitem__(self, idx):
        txn = self.env.begin()
        buffer = txn.get(str(idx).encode())
        data = pickle.loads(buffer)
        pep = data["pep"]
        out = (data["mz_array"], data["intensity_array"], data["precursor_mz"], data["precursor_charge"], pep)
        return out

    def __len__(self):
        return self.n_spectra

    def add_file(self, a_file):
        a_file = Path(a_file)
        parser = self._get_parser(a_file)
        parser.read()
        n_spectra_in_file = parser.n_spectra
        self.write_to_db(parser, n_spectra_in_file)

    def _init_db(self):
        # Open read-only when no write lock is requested.
        read = not self.lock
        self.env = lmdb.open(str(self.db_path), map_size=4099511627776, subdir=False, readonly=read, lock=self.lock)

    def _get_parser(self, ms_data_file):
        # To allow change of annotations.
        kw_args = dict(ms_level=self.ms_level, valid_charge=self.valid_charge, annotationsLabel=self.annotated)
        if ms_data_file.suffix.lower() == ".mzml":
            return MzmlParser(ms_data_file, **kw_args)

        if ms_data_file.suffix.lower() == ".mzxml":
            return MzxmlParser(ms_data_file, **kw_args)

        if ms_data_file.suffix.lower() == ".mgf":
            return MgfParser(ms_data_file, **kw_args)

        raise ValueError("Only mzML, mzXML, and MGF files are supported.")
--------------------------------------------------------------------------------
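Putting the two classes together, a minimal sketch (the file paths are hypothetical):

```python
from ContraNovo.denovo.db_dataset import DbDataset
from ContraNovo.denovo.db_index import DB_Index

# Build (or reopen) an LMDB-backed index from an annotated MGF file.
index = DB_Index("spectra.lmdb", filenames=["bacillus.10k.mgf"],
                 ms_level=2, annotated=True)

# DbDataset concatenates one or more indexes behind a single Dataset API.
dataset = DbDataset([index], n_peaks=300)
spectrum, precursor_mz, precursor_charge, peptide = dataset[0]
print(spectrum.shape)  # (n_kept_peaks, 2): m/z and normalized intensity
```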
/ContraNovo/components/encoders.py:
--------------------------------------------------------------------------------
"""Simple encoders for input into Transformers and the like."""
import torch
import einops
import numpy as np


class MassEncoder(torch.nn.Module):
    """Encode mass values using sine and cosine waves.

    Parameters
    ----------
    dim_model : int
        The number of features to output.
    min_wavelength : float
        The minimum wavelength to use.
    max_wavelength : float
        The maximum wavelength to use.
    """

    def __init__(self, dim_model, min_wavelength=0.001, max_wavelength=10000):
        """Initialize the MassEncoder"""
        super().__init__()

        n_sin = int(dim_model / 2)
        n_cos = dim_model - n_sin

        if min_wavelength:
            base = min_wavelength / (2 * np.pi)
            scale = max_wavelength / min_wavelength
        else:
            base = 1
            scale = max_wavelength / (2 * np.pi)

        sin_term = base * scale ** (
            torch.arange(0, n_sin).float() / (n_sin - 1)
        )
        cos_term = base * scale ** (
            torch.arange(0, n_cos).float() / (n_cos - 1)
        )

        self.register_buffer("sin_term", sin_term)
        self.register_buffer("cos_term", cos_term)

    def forward(self, X):
        """Encode m/z values.

        Parameters
        ----------
        X : torch.Tensor of shape (n_masses)
            The masses to embed.

        Returns
        -------
        torch.Tensor of shape (n_masses, dim_model)
            The encoded features for the mass spectra.
        """
        sin_mz = torch.sin(X / self.sin_term)
        cos_mz = torch.cos(X / self.cos_term)
        return torch.cat([sin_mz, cos_mz], axis=-1)


class PeakEncoder(MassEncoder):
    """Encode m/z values in a mass spectrum using sine and cosine waves.

    Parameters
    ----------
    dim_model : int
        The number of features to output.
    dim_intensity : int, optional
        The number of features to use for intensity. The remaining features
        will be used to encode the m/z values.
    min_wavelength : float, optional
        The minimum wavelength to use.
    max_wavelength : float, optional
        The maximum wavelength to use.
    """

    def __init__(
        self,
        dim_model,
        dim_intensity=None,
        min_wavelength=0.001,
        max_wavelength=10000,
    ):
        """Initialize the PeakEncoder"""
        self.dim_intensity = dim_intensity
        self.dim_mz = dim_model
        if self.dim_intensity is not None:
            self.dim_mz -= self.dim_intensity

        super().__init__(
            dim_model=self.dim_mz,
            min_wavelength=min_wavelength,
            max_wavelength=max_wavelength,
        )

        if self.dim_intensity is None:
            self.int_encoder = torch.nn.Linear(1, dim_model, bias=False)
        else:
            self.int_encoder = MassEncoder(
                dim_model=dim_intensity,
                min_wavelength=0,
                max_wavelength=1,
            )

    def forward(self, X):
        """Encode m/z values and intensities.

        Note that we expect intensities to fall within the interval [0, 1].

        Parameters
        ----------
        X : torch.Tensor of shape (n_spectra, n_peaks, 2)
            The spectra to embed. Axis 0 represents a mass spectrum, axis 1
            contains the peaks in the mass spectrum, and axis 2 is essentially
            a 2-tuple specifying the m/z-intensity pair for each peak. These
            should be zero-padded, such that all of the spectra in the batch
            are the same length.

        Returns
        -------
        torch.Tensor of shape (n_spectra, n_peaks, dim_model)
            The encoded features for the mass spectra.
        """
        m_over_z = X[:, :, [0]]
        encoded = super().forward(m_over_z)
        intensity = self.int_encoder(X[:, :, [1]])
        if self.dim_intensity is None:
            return encoded + intensity

        return torch.cat([encoded, intensity], dim=2)


class PositionalEncoder(torch.nn.Module):
    """The positional encoder for sequences.

    Parameters
    ----------
    dim_model : int
        The number of features to output.
    """

    def __init__(self, dim_model, max_wavelength=10000):
        """Initialize the PositionalEncoder"""
        super().__init__()

        n_sin = int(dim_model / 2)
        n_cos = dim_model - n_sin
        scale = max_wavelength / (2 * np.pi)

        sin_term = scale ** (torch.arange(0, n_sin).float() / (n_sin - 1))
        cos_term = scale ** (torch.arange(0, n_cos).float() / (n_cos - 1))
        self.register_buffer("sin_term", sin_term)
        self.register_buffer("cos_term", cos_term)

    def forward(self, X):
        """Encode positions in a sequence.

        Parameters
        ----------
        X : torch.Tensor of shape (batch_size, n_sequence, n_features)
            The first dimension should be the batch size (i.e. each is one
            peptide) and the second dimension should be the sequence (i.e.
            each should be an amino acid representation).

        Returns
        -------
        torch.Tensor of shape (batch_size, n_sequence, n_features)
            The encoded features for the mass spectra.
        """
        pos = torch.arange(X.shape[1]).type_as(self.sin_term)
        pos = einops.repeat(pos, "n -> b n", b=X.shape[0])
        sin_in = einops.repeat(pos, "b n -> b n f", f=len(self.sin_term))
        cos_in = einops.repeat(pos, "b n -> b n f", f=len(self.cos_term))

        sin_pos = torch.sin(sin_in / self.sin_term)
        cos_pos = torch.cos(cos_in / self.cos_term)
        encoded = torch.cat([sin_pos, cos_pos], axis=2)
        return encoded + X
--------------------------------------------------------------------------------
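A shape-level sanity check of the peak encoder (batch size and peak count are arbitrary):

```python
import torch

from ContraNovo.components.encoders import PeakEncoder

encoder = PeakEncoder(dim_model=512)
spectra = torch.rand(8, 300, 2)  # 8 spectra, 300 padded (m/z, intensity) peaks
print(encoder(spectra).shape)    # torch.Size([8, 300, 512])
```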
/ContraNovo/ContraNovo.py:
--------------------------------------------------------------------------------
"""The command line entry point for ContraNovo."""
import datetime
import functools
import logging
import os
import re
import shutil
import sys
import warnings
from typing import Optional, Tuple

warnings.filterwarnings("ignore", category=DeprecationWarning)

import appdirs
import click
import github
import requests
import torch
import tqdm
import yaml
from pytorch_lightning.lite import LightningLite

from . import utils
from .denovo import model_runner


logger = logging.getLogger("ContraNovo")


@click.command()
@click.option(
    "--mode",
    required=True,
    default="denovo",
    help="\b\nThe mode in which to run ContraNovo:\n"
    '- "denovo" will predict peptide sequences for\nunknown MS/MS spectra.\n'
    '- "train" will train a model (from scratch or by\ncontinuing training a '
    "previously trained model).\n"
    '- "eval" will evaluate the performance of a\ntrained model using '
    "previously acquired spectrum\nannotations.",
    type=click.Choice(["denovo", "train", "eval"]),
)
@click.option(
    "--model",
    help="The file name of the model weights (.ckpt file).",
    type=click.Path(exists=True, dir_okay=False),
)
@click.option(
    "--peak_path",
    required=True,
    help="The file path with peak files for predicting peptide sequences or "
    "training ContraNovo.",
)
@click.option(
    "--peak_path_val",
    help="The file path with peak files to be used as validation data during "
    "training.",
)
@click.option(
    "--peak_path_test",
    help="The file path with peak files to be used as testing data during "
    "training.",
)
@click.option(
    "--config",
    help="The file name of the configuration file with custom options. If not "
    "specified, a default configuration will be used.",
    type=click.Path(exists=True, dir_okay=False),
)
@click.option(
    "--output",
    help="The base output file name to store logging (extension: .log) and "
    "(optionally) prediction results (extension: .csv).",
    type=click.Path(dir_okay=False),
)
def main(
    mode: str,
    model: Optional[str],
    peak_path: str,
    peak_path_val: Optional[str],
    peak_path_test: Optional[str],
    config: Optional[str],
    output: Optional[str],
):

    if output is None:
        output = os.path.join(
            os.getcwd(),
            f"ContraNovo_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}",
        )
    else:
        output = os.path.splitext(os.path.abspath(output))[0]

    # Configure logging.
    logging.captureWarnings(True)
    root = logging.getLogger()
    root.setLevel(logging.DEBUG)
    log_formatter = logging.Formatter(
        "{asctime} {levelname} [{name}/{processName}] {module}.{funcName} : "
        "{message}",
        style="{",
    )
    console_handler = logging.StreamHandler(sys.stderr)
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(log_formatter)
    root.addHandler(console_handler)
    file_handler = logging.FileHandler(f"{output}.log")
    file_handler.setFormatter(log_formatter)
    root.addHandler(file_handler)
    # Disable dependency non-critical log messages.
    logging.getLogger("depthcharge").setLevel(logging.INFO)
    logging.getLogger("github").setLevel(logging.WARNING)
    logging.getLogger("h5py").setLevel(logging.WARNING)
    logging.getLogger("numba").setLevel(logging.WARNING)
    logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)
    logging.getLogger("torch").setLevel(logging.WARNING)
    logging.getLogger("urllib3").setLevel(logging.WARNING)
    # Read parameters from the config file.
    if config is None:
        config = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "config.yaml"
        )
    config_fn = config
    with open(config) as f_in:
        config = yaml.safe_load(f_in)
    # Ensure that the config values have the correct type.
    config_types = dict(
        random_seed=int,
        n_peaks=int,
        min_mz=float,
        max_mz=float,
        min_intensity=float,
        remove_precursor_tol=float,
        max_charge=int,
        precursor_mass_tol=float,
        isotope_error_range=lambda min_max: (int(min_max[0]), int(min_max[1])),
        dim_model=int,
        n_head=int,
        dim_feedforward=int,
        n_layers=int,
        dropout=float,
        dim_intensity=int,
        max_length=int,
        n_log=int,
        warmup_iters=int,
        max_iters=int,
        learning_rate=float,
        weight_decay=float,
        train_batch_size=int,
        predict_batch_size=int,
        n_beams=int,
        max_epochs=int,
        num_sanity_val_steps=int,
        train_from_scratch=bool,
        save_model=bool,
        model_save_folder_path=str,
        save_weights_only=bool,
        every_n_train_steps=int,
    )
    for k, t in config_types.items():
        try:
            if config[k] is not None:
                config[k] = t(config[k])
        except (TypeError, ValueError) as e:
            logger.error("Incorrect type for configuration value %s: %s", k, e)
            raise TypeError(f"Incorrect type for configuration value {k}: {e}")
    config["residues"] = {
        str(aa): float(mass) for aa, mass in config["residues"].items()
    }
    # Add extra configuration options and scale by the number of GPUs.
    n_gpus = torch.cuda.device_count()
    config["n_workers"] = utils.n_workers()
    if n_gpus > 1:
        config["train_batch_size"] = config["train_batch_size"] // n_gpus

    import random
    if config["random_seed"] == -1:
        config["random_seed"] = random.randint(1, 9999)
    LightningLite.seed_everything(seed=config["random_seed"], workers=True)

    # Log the active configuration.
    logger.debug("mode = %s", mode)
    logger.debug("model = %s", model)
    logger.debug("peak_path = %s", peak_path)
    logger.debug("peak_path_val = %s", peak_path_val)
    logger.debug("peak_path_test = %s", peak_path_test)
    logger.debug("config = %s", config_fn)
    logger.debug("output = %s", output)
    for key, value in config.items():
        logger.debug("%s = %s", str(key), str(value))
    # Run ContraNovo in the specified mode.
    if mode == "denovo":
        logger.info("Predict peptide sequences with ContraNovo.")
        writer = None
        # writer.set_metadata(
        #     config, peak_path=peak_path, model=model, config_filename=config_fn
        # )
        model_runner.predict(peak_path, model, config, writer)
        # Guard the save: the writer is currently disabled (None) above.
        if writer is not None:
            writer.save()
    elif mode == "eval":
        logger.info("Evaluate a trained ContraNovo model.")
        model_runner.evaluate(peak_path, model, config)
    elif mode == "train":
        logger.info("Train the ContraNovo model.")
        model_runner.train(peak_path, peak_path_val, peak_path_test, model, config)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
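Beyond the eval invocation shown in the README, the options above imply the following usage (all paths are placeholders):

```
# Train a model, with separate validation and test peak files.
python -m ContraNovo.ContraNovo --mode=train --peak_path=train.mgf --peak_path_val=valid.mgf --peak_path_test=test.mgf

# De novo sequencing of unannotated spectra with a trained checkpoint.
python -m ContraNovo.ContraNovo --mode=denovo --peak_path=spectra.mgf --model=ContraNovo.ckpt --output=results
```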
/ContraNovo/denovo/dataloaders.py:
--------------------------------------------------------------------------------
"""Data loaders for the de novo sequencing task."""
import functools
import os
from typing import List, Optional, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
from depthcharge.data import AnnotatedSpectrumIndex

from ..data.datasets import AnnotatedSpectrumDataset, SpectrumDataset


class DeNovoDataModule(pl.LightningDataModule):
    """
    Data loader to prepare MS/MS spectra for a Spec2Pep predictor.

    Parameters
    ----------
    train_index : Optional[AnnotatedSpectrumIndex]
        The spectrum index file corresponding to the training data.
    valid_index : Optional[AnnotatedSpectrumIndex]
        The spectrum index file corresponding to the validation data.
    test_index : Optional[AnnotatedSpectrumIndex]
        The spectrum index file corresponding to the testing data.
    batch_size : int
        The batch size to use for training and evaluating.
    n_peaks : Optional[int]
        The number of top-n most intense peaks to keep in each spectrum. `None`
        retains all peaks.
    min_mz : float
        The minimum m/z to include. The default is 140 m/z, in order to exclude
        TMT and iTRAQ reporter ions.
    max_mz : float
        The maximum m/z to include.
    min_intensity : float
        Remove peaks whose intensity is below `min_intensity` percentage of the
        base peak intensity.
    remove_precursor_tol : float
        Remove peaks within the given mass tolerance in Dalton around the
        precursor mass.
    n_workers : int, optional
        The number of workers to use for data loading. By default, the number
        of available CPU cores on the current machine is used.
    random_state : Optional[int]
        The NumPy random state. ``None`` leaves mass spectra in the order they
        were parsed.
    """

    def __init__(
        self,
        train_index: Optional[AnnotatedSpectrumIndex] = None,
        valid_index: Optional[AnnotatedSpectrumIndex] = None,
        test_index: Optional[AnnotatedSpectrumIndex] = None,
        batch_size: int = 128,
        n_peaks: Optional[int] = 150,
        min_mz: float = 50.0,
        max_mz: float = 2500.0,
        min_intensity: float = 0.01,
        remove_precursor_tol: float = 2.0,
        n_workers: Optional[int] = None,
        random_state: Optional[int] = None,
    ):
        super().__init__()
        self.train_index = train_index
        self.valid_index = valid_index
        self.test_index = test_index
        self.batch_size = batch_size
        self.n_peaks = n_peaks
        self.min_mz = min_mz
        self.max_mz = max_mz
        self.min_intensity = min_intensity
        self.remove_precursor_tol = remove_precursor_tol
        self.n_workers = n_workers if n_workers is not None else os.cpu_count()
        self.rng = np.random.default_rng(random_state)
        self.train_dataset = None
        self.valid_dataset = None
        self.test_dataset = None

    def setup(self, stage: str = None, annotated: bool = True) -> None:
        """
        Set up the PyTorch Datasets.

        Parameters
        ----------
        stage : str {"fit", "validate", "test"}
            The stage indicating which Datasets to prepare. All are prepared by
            default.
        annotated : bool
            True if peptide sequence annotations are available for the test
            data.
        """
        if stage in (None, "fit", "validate"):
            make_dataset = functools.partial(
                AnnotatedSpectrumDataset,
                n_peaks=self.n_peaks,
                min_mz=self.min_mz,
                max_mz=self.max_mz,
                min_intensity=self.min_intensity,
                remove_precursor_tol=self.remove_precursor_tol,
            )
            if self.train_index is not None:
                self.train_dataset = make_dataset(
                    self.train_index,
                    random_state=self.rng,
                )
            if self.valid_index is not None:
                self.valid_dataset = make_dataset(self.valid_index)
        if stage in (None, "test"):
            make_dataset = functools.partial(
                AnnotatedSpectrumDataset if annotated else SpectrumDataset,
                n_peaks=self.n_peaks,
                min_mz=self.min_mz,
                max_mz=self.max_mz,
                min_intensity=self.min_intensity,
                remove_precursor_tol=self.remove_precursor_tol,
            )
            if self.test_index is not None:
                self.test_dataset = make_dataset(self.test_index)

    def _make_loader(
        self, dataset: torch.utils.data.Dataset
    ) -> torch.utils.data.DataLoader:
        """
        Create a PyTorch DataLoader.

        Parameters
        ----------
        dataset : torch.utils.data.Dataset
            A PyTorch Dataset.

        Returns
        -------
        torch.utils.data.DataLoader
            A PyTorch DataLoader.
136 | """ 137 | return torch.utils.data.DataLoader( 138 | dataset, 139 | batch_size=self.batch_size, 140 | collate_fn=prepare_batch, 141 | pin_memory=True, 142 | num_workers=self.n_workers, 143 | ) 144 | 145 | def train_dataloader(self) -> torch.utils.data.DataLoader: 146 | """Get the training DataLoader.""" 147 | return self._make_loader(self.train_dataset) 148 | 149 | def val_dataloader(self) -> torch.utils.data.DataLoader: 150 | """Get the validation DataLoader.""" 151 | return self._make_loader(self.valid_dataset) 152 | 153 | def test_dataloader(self) -> torch.utils.data.DataLoader: 154 | """Get the test DataLoader.""" 155 | return self._make_loader(self.test_dataset) 156 | 157 | def predict_dataloader(self) -> torch.utils.data.DataLoader: 158 | """Get the predict DataLoader.""" 159 | return self._make_loader(self.test_dataset) 160 | 161 | 162 | def prepare_batch( 163 | batch: List[Tuple[torch.Tensor, float, int, str]] 164 | ) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]: 165 | """ 166 | Collate MS/MS spectra into a batch. 167 | 168 | The MS/MS spectra will be padded so that they fit nicely as a tensor. 169 | However, the padded elements are ignored during the subsequent steps. 170 | 171 | Parameters 172 | ---------- 173 | batch : List[Tuple[torch.Tensor, float, int, str]] 174 | A batch of data from an AnnotatedSpectrumDataset, consisting of for each 175 | spectrum (i) a tensor with the m/z and intensity peak values, (ii), the 176 | precursor m/z, (iii) the precursor charge, (iv) the spectrum identifier. 177 | 178 | Returns 179 | ------- 180 | spectra : torch.Tensor of shape (batch_size, n_peaks, 2) 181 | The padded mass spectra tensor with the m/z and intensity peak values 182 | for each spectrum. 183 | precursors : torch.Tensor of shape (batch_size, 3) 184 | A tensor with the precursor neutral mass, precursor charge, and 185 | precursor m/z. 186 | spectrum_ids : np.ndarray 187 | The spectrum identifiers (during de novo sequencing) or peptide 188 | sequences (during training). 189 | """ 190 | spectra, precursor_mzs, precursor_charges, spectrum_ids = list(zip(*batch)) 191 | spectra = torch.nn.utils.rnn.pad_sequence(spectra, batch_first=True) 192 | precursor_mzs = torch.tensor(precursor_mzs) 193 | precursor_charges = torch.tensor(precursor_charges) 194 | precursor_masses = (precursor_mzs - 1.007276) * precursor_charges 195 | precursors = torch.vstack( 196 | [precursor_masses, precursor_charges, precursor_mzs] 197 | ).T.float() 198 | return spectra, precursors, np.asarray(spectrum_ids) 199 | -------------------------------------------------------------------------------- /ContraNovo/data/datasets.py: -------------------------------------------------------------------------------- 1 | """A PyTorch Dataset class for annotated spectra.""" 2 | from typing import Optional, Tuple 3 | 4 | import depthcharge 5 | import numpy as np 6 | import spectrum_utils.spectrum as sus 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class SpectrumDataset(Dataset): 12 | """ 13 | Parse and retrieve collections of MS/MS spectra. 14 | 15 | Parameters 16 | ---------- 17 | spectrum_index : depthcharge.data.SpectrumIndex 18 | The MS/MS spectra to use as a dataset. 19 | n_peaks : Optional[int] 20 | The number of top-n most intense peaks to keep in each spectrum. `None` 21 | retains all peaks. 22 | min_mz : float 23 | The minimum m/z to include. The default is 140 m/z, in order to exclude 24 | TMT and iTRAQ reporter ions. 25 | max_mz : float 26 | The maximum m/z to include. 
27 | min_intensity : float 28 | Remove peaks whose intensity is below `min_intensity` percentage of the 29 | base peak intensity. 30 | remove_precursor_tol : float 31 | Remove peaks within the given mass tolerance in Dalton around the 32 | precursor mass. 33 | random_state : Optional[int] 34 | The NumPy random state. ``None`` leaves mass spectra in the order they 35 | were parsed. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | spectrum_index: depthcharge.data.SpectrumIndex, 41 | n_peaks: int = 150, 42 | min_mz: float = 140.0, 43 | max_mz: float = 2500.0, 44 | min_intensity: float = 0.01, 45 | remove_precursor_tol: float = 2.0, 46 | random_state: Optional[int] = None, 47 | ): 48 | """Initialize a SpectrumDataset""" 49 | super().__init__() 50 | self.n_peaks = n_peaks 51 | self.min_mz = min_mz 52 | self.max_mz = max_mz 53 | self.min_intensity = min_intensity 54 | self.remove_precursor_tol = remove_precursor_tol 55 | self.rng = np.random.default_rng(random_state) 56 | self._index = spectrum_index 57 | 58 | def __len__(self) -> int: 59 | """The number of spectra.""" 60 | return self.n_spectra 61 | 62 | def __getitem__( 63 | self, idx 64 | ) -> Tuple[torch.Tensor, float, int, Tuple[str, str]]: 65 | """ 66 | Return the MS/MS spectrum with the given index. 67 | 68 | Parameters 69 | ---------- 70 | idx : int 71 | The index of the spectrum to return. 72 | 73 | Returns 74 | ------- 75 | spectrum : torch.Tensor of shape (n_peaks, 2) 76 | A tensor of the spectrum with the m/z and intensity peak values. 77 | precursor_mz : float 78 | The precursor m/z. 79 | precursor_charge : int 80 | The precursor charge. 81 | spectrum_id: Tuple[str, str] 82 | The unique spectrum identifier, formed by its original peak file and 83 | identifier (index or scan number) therein. 84 | """ 85 | mz_array, int_array, precursor_mz, precursor_charge = self.index[idx] 86 | spectrum = self._process_peaks( 87 | mz_array, int_array, precursor_mz, precursor_charge 88 | ) 89 | return ( 90 | spectrum, 91 | precursor_mz, 92 | precursor_charge, 93 | self.get_spectrum_id(idx), 94 | ) 95 | 96 | def get_spectrum_id(self, idx: int) -> Tuple[str, str]: 97 | """ 98 | Return the identifier of the MS/MS spectrum with the given index. 99 | 100 | Parameters 101 | ---------- 102 | idx : int 103 | The index of the MS/MS spectrum within the SpectrumIndex. 104 | 105 | Returns 106 | ------- 107 | ms_data_file : str 108 | The peak file from which the MS/MS spectrum was originally parsed. 109 | identifier : str 110 | The MS/MS spectrum identifier, per PSI recommendations. 111 | """ 112 | with self.index: 113 | return self.index.get_spectrum_id(idx) 114 | 115 | def _process_peaks( 116 | self, 117 | mz_array: np.ndarray, 118 | int_array: np.ndarray, 119 | precursor_mz: float, 120 | precursor_charge: int, 121 | ) -> torch.Tensor: 122 | """ 123 | Preprocess the spectrum by removing noise peaks and scaling the peak 124 | intensities. 125 | 126 | Parameters 127 | ---------- 128 | mz_array : numpy.ndarray of shape (n_peaks,) 129 | The spectrum peak m/z values. 130 | int_array : numpy.ndarray of shape (n_peaks,) 131 | The spectrum peak intensity values. 132 | precursor_mz : float 133 | The precursor m/z. 134 | precursor_charge : int 135 | The precursor charge. 136 | 137 | Returns 138 | ------- 139 | torch.Tensor of shape (n_peaks, 2) 140 | A tensor of the spectrum with the m/z and intensity peak values. 
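
        Notes
        -----
        Processing order: restrict peaks to ``[min_mz, max_mz]``, remove peaks
        around the precursor, keep at most the ``n_peaks`` most intense peaks
        above the intensity threshold, square-root transform the intensities,
        and L2-normalize them. A spectrum that becomes empty at any step is
        replaced by the dummy peak ``[[0, 1]]``.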
141 | """ 142 | spectrum = sus.MsmsSpectrum( 143 | "", 144 | precursor_mz, 145 | precursor_charge, 146 | mz_array.astype(np.float64), 147 | int_array.astype(np.float32), 148 | ) 149 | try: 150 | spectrum.set_mz_range(self.min_mz, self.max_mz) 151 | if len(spectrum.mz) == 0: 152 | raise ValueError 153 | spectrum.remove_precursor_peak(self.remove_precursor_tol, "Da") 154 | if len(spectrum.mz) == 0: 155 | raise ValueError 156 | spectrum.filter_intensity(self.min_intensity, self.n_peaks) 157 | if len(spectrum.mz) == 0: 158 | raise ValueError 159 | spectrum.scale_intensity("root", 1) 160 | intensities = spectrum.intensity / np.linalg.norm( 161 | spectrum.intensity 162 | ) 163 | return torch.tensor(np.array([spectrum.mz, intensities])).T.float() 164 | except ValueError: 165 | # Replace invalid spectra by a dummy spectrum. 166 | return torch.tensor([[0, 1]]).float() 167 | 168 | @property 169 | def n_spectra(self) -> int: 170 | """The total number of spectra.""" 171 | return self.index.n_spectra 172 | 173 | @property 174 | def index(self) -> depthcharge.data.SpectrumIndex: 175 | """The underlying SpectrumIndex.""" 176 | return self._index 177 | 178 | @property 179 | def rng(self): 180 | """The NumPy random number generator.""" 181 | return self._rng 182 | 183 | @rng.setter 184 | def rng(self, seed): 185 | """Set the NumPy random number generator.""" 186 | self._rng = np.random.default_rng(seed) 187 | 188 | 189 | class AnnotatedSpectrumDataset(SpectrumDataset): 190 | """ 191 | Parse and retrieve collections of annotated MS/MS spectra. 192 | 193 | Parameters 194 | ---------- 195 | annotated_spectrum_index : depthcharge.data.SpectrumIndex 196 | The MS/MS spectra to use as a dataset. 197 | n_peaks : Optional[int] 198 | The number of top-n most intense peaks to keep in each spectrum. `None` 199 | retains all peaks. 200 | min_mz : float 201 | The minimum m/z to include. The default is 140 m/z, in order to exclude 202 | TMT and iTRAQ reporter ions. 203 | max_mz : float 204 | The maximum m/z to include. 205 | min_intensity : float 206 | Remove peaks whose intensity is below `min_intensity` percentage of the 207 | base peak intensity. 208 | remove_precursor_tol : float 209 | Remove peaks within the given mass tolerance in Dalton around the 210 | precursor mass. 211 | random_state : Optional[int] 212 | The NumPy random state. ``None`` leaves mass spectra in the order they 213 | were parsed. 214 | """ 215 | 216 | def __init__( 217 | self, 218 | annotated_spectrum_index: depthcharge.data.SpectrumIndex, 219 | n_peaks: int = 150, 220 | min_mz: float = 140.0, 221 | max_mz: float = 2500.0, 222 | min_intensity: float = 0.01, 223 | remove_precursor_tol: float = 2.0, 224 | random_state: Optional[int] = None, 225 | ): 226 | super().__init__( 227 | annotated_spectrum_index, 228 | n_peaks=n_peaks, 229 | min_mz=min_mz, 230 | max_mz=max_mz, 231 | min_intensity=min_intensity, 232 | remove_precursor_tol=remove_precursor_tol, 233 | random_state=random_state, 234 | ) 235 | 236 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, float, int, str]: 237 | """ 238 | Return the annotated MS/MS spectrum with the given index. 239 | 240 | Parameters 241 | ---------- 242 | idx : int 243 | The index of the spectrum to return. 244 | 245 | Returns 246 | ------- 247 | spectrum : torch.Tensor of shape (n_peaks, 2) 248 | A tensor of the spectrum with the m/z and intensity peak values. 249 | precursor_mz : float 250 | The precursor m/z. 251 | precursor_charge : int 252 | The precursor charge. 
253 | annotation : str 254 | The peptide annotation of the spectrum. 255 | """ 256 | ( 257 | mz_array, 258 | int_array, 259 | precursor_mz, 260 | precursor_charge, 261 | peptide, 262 | ) = self.index[idx] 263 | spectrum = self._process_peaks( 264 | mz_array, int_array, precursor_mz, precursor_charge 265 | ) 266 | return spectrum, precursor_mz, precursor_charge, peptide 267 | -------------------------------------------------------------------------------- /ContraNovo/denovo/db_dataloader.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | from typing import List, Optional, Tuple 4 | from torch.utils.data import RandomSampler 5 | import numpy as np 6 | import pytorch_lightning as pl 7 | import torch 8 | from .db_index import DB_Index 9 | from .db_dataset import DbDataset 10 | class DeNovoDataModule(pl.LightningDataModule): 11 | def __init__( 12 | self, 13 | train_index= None, # list 14 | valid_index= None, #list 15 | test_index= None, # list 16 | batch_size: int = 128, 17 | n_peaks: Optional[int] = 150, 18 | min_mz: float = 50.0, 19 | max_mz: float = 2500.0, 20 | min_intensity: float = 0.01, 21 | remove_precursor_tol: float = 2.0, 22 | n_workers: Optional[int] = None, 23 | random_state: Optional[int] = None, 24 | train_filenames = None, 25 | val_filenames = None, 26 | test_filenames = None, 27 | train_index_path = None, 28 | val_index_path = None, 29 | test_index_path = None, 30 | annotated = True, 31 | valid_charge = None , 32 | ms_level = 2, 33 | mode = "fit" 34 | 35 | ): 36 | super().__init__() 37 | self.annotated = annotated 38 | self.valid_charge = valid_charge 39 | self.ms_level = ms_level 40 | self.train_index = train_index 41 | self.valid_index = valid_index 42 | self.test_index = test_index 43 | self.batch_size = batch_size 44 | self.n_peaks = n_peaks 45 | self.min_mz = min_mz 46 | self.max_mz = max_mz 47 | self.min_intensity = min_intensity 48 | self.remove_precursor_tol = remove_precursor_tol 49 | self.n_workers = n_workers if n_workers is not None else os.cpu_count() 50 | self.rng = np.random.default_rng(random_state) 51 | self.train_dataset = None 52 | self.valid_dataset = None 53 | self.test_dataset = None 54 | self.train_filenames = train_filenames #either a list or None, need to examine 55 | self.val_filenames = val_filenames 56 | self.test_filenames = test_filenames 57 | self.train_index_path = train_index_path # always a list, one or more values in the list 58 | self.val_index_path = val_index_path 59 | self.test_index_path = test_index_path 60 | self.mode = mode 61 | def setup(self, stage=None): 62 | if stage in (None, "fit", "validate"): 63 | make_dataset = functools.partial( 64 | DbDataset, 65 | n_peaks=self.n_peaks, 66 | min_mz=self.min_mz, 67 | max_mz=self.max_mz, 68 | min_intensity=self.min_intensity, 69 | remove_precursor_tol=self.remove_precursor_tol, 70 | ) 71 | self.train_index = [] 72 | for each in self.train_index_path: 73 | self.train_index.append(DB_Index(each, None, self.ms_level, self.valid_charge, self.annotated, lock=False)) 74 | self.train_dataset = make_dataset(self.train_index, random_state=self.rng ) 75 | 76 | print("index_len:", len(self.train_index[0])) 77 | 78 | self.valid_index = [] 79 | for each in self.val_index_path: 80 | self.valid_index.append(DB_Index(each, None, self.ms_level, self.valid_charge, self.annotated, lock=False)) 81 | 82 | self.valid_dataset = make_dataset(self.valid_index, random_state= self.rng) 83 | 84 | self.test_index = [] 85 | for each in 
self.test_index_path:
86 |                 self.test_index.append(DB_Index(each, None, self.ms_level, self.valid_charge, self.annotated, lock=False))
87 |             self.test_dataset = make_dataset(self.test_index)
88 | 
89 | 
90 | 
91 |         elif stage == "test":
92 |             make_dataset = functools.partial(
93 |                 DbDataset,
94 |                 n_peaks=self.n_peaks,
95 |                 min_mz=self.min_mz,
96 |                 max_mz=self.max_mz,
97 |                 min_intensity=self.min_intensity,
98 |                 remove_precursor_tol=self.remove_precursor_tol,
99 |             )
100 |             self.test_index = []
101 |             for each in self.test_index_path:
102 |                 self.test_index.append(DB_Index(each, None, self.ms_level, self.valid_charge, self.annotated, lock=False))
103 |             self.test_dataset = make_dataset(self.test_index)
104 | 
105 | 
106 |     def prepare_data(self) -> None:
107 |         # Rule: if the db_index file is None, the index is created from the
108 |         # given filenames; otherwise the filenames are ignored.
109 |         # Hence, whenever any filename is provided, prepare_data must be called.
110 |         # The mode distinguishes between training/validation/test.
111 |         print("Preparing data...")
112 | 
113 | 
114 |         if self.train_index is None and self.mode == "fit":  # Prepare the training index.
115 | 
116 |             '''
117 |             try:
118 |                 assert self.train_filenames != None
119 |             except:
120 |                 raise ValueError("No training file provided ")
121 |             '''
122 |             if self.train_filenames is None:
123 |                 lock = False
124 |             else:
125 |                 lock = True
126 |             for each in self.train_index_path:
127 |                 DB_Index(each, self.train_filenames, self.ms_level, self.valid_charge, self.annotated, lock=lock)
128 |         if self.valid_index is None and self.mode == "fit":  # Prepare the validation index.
129 | 
130 |             '''
131 |             try:
132 |                 assert self.val_filenames != None
133 |             except:
134 |                 raise ValueError("No validation file provided ")
135 |             '''
136 |             if self.val_filenames is None:
137 |                 lock = False
138 |             else:
139 |                 lock = True
140 |             for each in self.val_index_path:
141 |                 DB_Index(each, self.val_filenames, self.ms_level, self.valid_charge, self.annotated, lock=lock)
142 |         if self.test_index is None:
143 |             '''
144 |             try:
145 |                 assert self.test_filenames != None
146 |             except:
147 |                 raise ValueError("No testing file provided ")
148 | 
149 |             '''
150 |             if self.test_filenames is None:
151 |                 lock = False
152 |             else:
153 |                 lock = True
154 |             for each in self.test_index_path:
155 |                 DB_Index(each, self.test_filenames, self.ms_level, self.valid_charge, self.annotated, lock=lock)
156 |         if self.train_index is not None:  # To be changed: add a checker for existence.
157 |             pass
158 | 
159 |     def _make_loader(
160 |         self, dataset: torch.utils.data.Dataset, sampler=None
161 |     ) -> torch.utils.data.DataLoader:
162 |         return torch.utils.data.DataLoader(
163 |             dataset,
164 |             batch_size=self.batch_size,
165 |             collate_fn=prepare_batch,
166 |             pin_memory=True,
167 |             num_workers=self.n_workers,
168 |             sampler=sampler,
169 |         )
170 | 
171 |     def train_dataloader(self) -> torch.utils.data.DataLoader:
172 |         """Get the training DataLoader."""
173 | 
174 | 
175 |         assert self.train_dataset is not None
176 |         M = 898183  # Number of samples drawn per epoch, with replacement (previously 498183).
177 |         sampler = RandomSampler(self.train_dataset, replacement=True, num_samples=M)
178 |         return self._make_loader(self.train_dataset, sampler=sampler)
179 | 
180 |     def val_dataloader(self) -> torch.utils.data.DataLoader:
181 |         """Get the validation DataLoader."""
182 |         if self.mode == "fit":
183 |             return [self._make_loader(self.valid_dataset), self._make_loader(self.test_dataset)]
184 |         return self._make_loader(self.valid_dataset)
185 | 
186 |     def test_dataloader(self) -> torch.utils.data.DataLoader:
187 |         """Get the test DataLoader."""
188 |         return
self._make_loader(self.test_dataset) 189 | 190 | def predict_dataloader(self) -> torch.utils.data.DataLoader: 191 | """Get the predict DataLoader.""" 192 | return self._make_loader(self.test_dataset) 193 | 194 | def prepare_batch( 195 | batch: List[Tuple[torch.Tensor, float, int, str]] 196 | ) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]: 197 | """ 198 | Collate MS/MS spectra into a batch. 199 | 200 | The MS/MS spectra will be padded so that they fit nicely as a tensor. 201 | However, the padded elements are ignored during the subsequent steps. 202 | 203 | Parameters 204 | ---------- 205 | batch : List[Tuple[torch.Tensor, float, int, str]] 206 | A batch of data from an AnnotatedSpectrumDataset, consisting of for each 207 | spectrum (i) a tensor with the m/z and intensity peak values, (ii), the 208 | precursor m/z, (iii) the precursor charge, (iv) the spectrum identifier. 209 | 210 | Returns 211 | ------- 212 | spectra : torch.Tensor of shape (batch_size, n_peaks, 2) 213 | The padded mass spectra tensor with the m/z and intensity peak values 214 | for each spectrum. 215 | precursors : torch.Tensor of shape (batch_size, 3) 216 | A tensor with the precursor neutral mass, precursor charge, and 217 | precursor m/z. 218 | spectrum_ids : np.ndarray 219 | The spectrum identifiers (during de novo sequencing) or peptide 220 | sequences (during training). 221 | """ 222 | spectra, precursor_mzs, precursor_charges, spectrum_ids = list(zip(*batch)) 223 | spectra = torch.nn.utils.rnn.pad_sequence(spectra, batch_first=True) 224 | precursor_mzs = torch.tensor(precursor_mzs) 225 | precursor_charges = torch.tensor(precursor_charges) 226 | precursor_masses = (precursor_mzs - 1.007276) * precursor_charges 227 | precursors = torch.vstack( 228 | [precursor_masses, precursor_charges, precursor_mzs] 229 | ).T.float() 230 | return spectra, precursors, np.asarray(spectrum_ids) -------------------------------------------------------------------------------- /ContraNovo/denovo/parser2.py: -------------------------------------------------------------------------------- 1 | """Mass spectrometry data parsers""" 2 | import logging 3 | from pathlib import Path 4 | from abc import ABC, abstractmethod 5 | 6 | import numpy as np 7 | from tqdm.auto import tqdm 8 | from pyteomics.mzml import MzML 9 | from pyteomics.mzxml import MzXML 10 | from pyteomics.mgf import MGF 11 | 12 | 13 | LOGGER = logging.getLogger(__name__) 14 | 15 | 16 | class BaseParser(ABC): 17 | """A base parser class to inherit from. 18 | 19 | Parameters 20 | ---------- 21 | ms_data_file : str or Path 22 | The mzML file to parse. 23 | ms_level : int 24 | The MS level of the spectra to parse. 25 | valid_charge : Iterable[int], optional 26 | Only consider spectra with the specified precursor charges. If `None`, 27 | any precursor charge is accepted. 28 | id_type : str, optional 29 | The Hupo-PSI prefix for the spectrum identifier. 
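
    Examples
    --------
    Subclasses implement ``open`` and ``parse_spectrum``. A hypothetical
    minimal in-memory parser, to illustrate the contract:

    >>> import contextlib
    >>> class ListParser(BaseParser):
    ...     def __init__(self, spectra):
    ...         super().__init__("in_memory", ms_level=2)
    ...         self._spectra = spectra
    ...     def open(self):
    ...         return contextlib.nullcontext(self._spectra)
    ...     def parse_spectrum(self, spectrum):
    ...         self.mz_arrays.append(list(spectrum["m/z array"]))
    ...         self.intensity_arrays.append(list(spectrum["intensity array"]))
    ...         self.precursor_mz.append(spectrum["pepmass"])
    ...         self.precursor_charge.append(spectrum["charge"])
    ...         self.scan_id.append(len(self.scan_id))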
30 | """ 31 | 32 | def __init__( 33 | self, 34 | ms_data_file, 35 | ms_level, 36 | valid_charge=None, 37 | id_type="scan", 38 | ): 39 | """Initialize the BaseParser""" 40 | self.path = Path(ms_data_file) 41 | self.ms_level = ms_level 42 | self.valid_charge = None if valid_charge is None else set(valid_charge) 43 | self.id_type = id_type 44 | self.offset = None 45 | self.precursor_mz = [] 46 | self.precursor_charge = [] 47 | self.scan_id = [] 48 | self.mz_arrays = [] 49 | self.intensity_arrays = [] 50 | 51 | @abstractmethod 52 | def open(self): 53 | """Open the file as an iterable""" 54 | pass 55 | 56 | @abstractmethod 57 | def parse_spectrum(self, spectrum): 58 | """Parse a single spectrum 59 | 60 | Parameters 61 | ---------- 62 | spectrum : dict 63 | The dictionary defining the spectrum in a given format. 64 | """ 65 | pass 66 | 67 | def read(self): 68 | """Read the ms data file""" 69 | n_skipped = 0 70 | with self.open() as spectra: 71 | for spectrum in tqdm(spectra, desc=str(self.path), unit="spectra"): 72 | try: 73 | self.parse_spectrum(spectrum) 74 | except (IndexError, KeyError, ValueError): 75 | n_skipped += 1 76 | 77 | if n_skipped: 78 | LOGGER.warning( 79 | "Skipped %d spectra with invalid precursor info", n_skipped 80 | ) 81 | 82 | #self.precursor_mz = np.array(self.precursor_mz, dtype=np.float32) 83 | #self.precursor_charge = np.array(self.precursor_charge,dtype=np.float32,) 84 | 85 | self.scan_id = np.array(self.scan_id) 86 | 87 | # Build the index 88 | sizes = np.array([0] + [len(s) for s in self.mz_arrays]) 89 | self.offset = sizes[:-1].cumsum() 90 | #self.mz_arrays = np.array(self.mz_arrays, dtype = np.float32) 91 | #self.intensity_arrays = np.array(self.intensity_arrays, dtype=np.float32) 92 | #self.mz_arrays = np.concatenate(self.mz_arrays).astype(np.float64) 93 | #self.intensity_arrays = np.concatenate(self.intensity_arrays).astype(np.float32) 94 | 95 | @property 96 | def n_spectra(self): 97 | 98 | """The number of spectra""" 99 | 100 | return self.offset.shape[0] 101 | 102 | @property 103 | def n_peaks(self): 104 | #might not correct 105 | """The number of peaks in the file.""" 106 | mz_arrays_temp = np.concatenate(self.mz_arrays) 107 | return mz_arrays_temp.shape[0] 108 | 109 | 110 | 111 | class MzmlParser(BaseParser): 112 | """Parse mass spectra from an mzML file. 113 | 114 | Parameters 115 | ---------- 116 | ms_data_file : str or Path 117 | The mzML file to parse. 118 | ms_level : int 119 | The MS level of the spectra to parse. 120 | valid_charge : Iterable[int], optional 121 | Only consider spectra with the specified precursor charges. If `None`, 122 | any precursor charge is accepted. 123 | """ 124 | 125 | def __init__(self, ms_data_file, ms_level=2, valid_charge=None): 126 | """Initialize the MzmlParser.""" 127 | super().__init__( 128 | ms_data_file, 129 | ms_level=ms_level, 130 | valid_charge=valid_charge, 131 | ) 132 | 133 | def open(self): 134 | """Open the mzML file for reading""" 135 | return MzML(str(self.path)) 136 | 137 | def parse_spectrum(self, spectrum): 138 | """Parse a single spectrum. 139 | 140 | Parameters 141 | ---------- 142 | spectrum : dict 143 | The dictionary defining the spectrum in mzML format. 
144 | """ 145 | if spectrum["ms level"] != self.ms_level: 146 | return 147 | 148 | if self.ms_level > 1: 149 | precursor = spectrum["precursorList"]["precursor"][0] 150 | precursor_ion = precursor["selectedIonList"]["selectedIon"][0] 151 | precursor_mz = float(precursor_ion["selected ion m/z"]) 152 | if "charge state" in precursor_ion: 153 | precursor_charge = int(precursor_ion["charge state"]) 154 | elif "possible charge state" in precursor_ion: 155 | precursor_charge = int(precursor_ion["possible charge state"]) 156 | else: 157 | precursor_charge = 0 158 | else: 159 | precursor_mz, precursor_charge = None, 0 160 | 161 | if self.valid_charge is None or precursor_charge in self.valid_charge: 162 | self.mz_arrays.append(list(spectrum["m/z array"])) 163 | self.intensity_arrays.append(list(spectrum["intensity array"])) 164 | self.precursor_mz.append(precursor_mz) 165 | self.precursor_charge.append(precursor_charge) 166 | self.scan_id.append(_parse_scan_id(spectrum["id"])) 167 | 168 | 169 | class MzxmlParser(BaseParser): 170 | """Parse mass spectra from an mzXML file. 171 | 172 | Parameters 173 | ---------- 174 | ms_data_file : str or Path 175 | The mzXML file to parse. 176 | ms_level : int 177 | The MS level of the spectra to parse. 178 | valid_charge : Iterable[int], optional 179 | Only consider spectra with the specified precursor charges. If `None`, 180 | any precursor charge is accepted. 181 | """ 182 | 183 | def __init__(self, ms_data_file, ms_level=2, valid_charge=None): 184 | """Initialize the MzxmlParser.""" 185 | super().__init__( 186 | ms_data_file, 187 | ms_level=ms_level, 188 | valid_charge=valid_charge, 189 | ) 190 | 191 | 192 | def open(self): 193 | """Open the mzXML file for reading""" 194 | return MzXML(str(self.path)) 195 | 196 | def parse_spectrum(self, spectrum): 197 | """Parse a single spectrum. 198 | 199 | Parameters 200 | ---------- 201 | spectrum : dict 202 | The dictionary defining the spectrum in mzXML format. 203 | """ 204 | if spectrum["msLevel"] != self.ms_level: 205 | return 206 | 207 | if self.ms_level > 1: 208 | precursor = spectrum["precursorMz"][0] 209 | precursor_mz = float(precursor["precursorMz"]) 210 | precursor_charge = int(precursor.get("precursorCharge", 0)) 211 | else: 212 | precursor_mz, precursor_charge = None, 0 213 | 214 | if self.valid_charge is None or precursor_charge in self.valid_charge: 215 | self.mz_arrays.append(list(spectrum["m/z array"])) 216 | self.intensity_arrays.append(list(spectrum["intensity array"])) 217 | self.precursor_mz.append(precursor_mz) 218 | self.precursor_charge.append(precursor_charge) 219 | self.scan_id.append(_parse_scan_id(spectrum["id"])) 220 | 221 | 222 | class MgfParser(BaseParser): 223 | """Parse mass spectra from an MGF file. 224 | 225 | Parameters 226 | ---------- 227 | ms_data_file : str or Path 228 | The MGF file to parse. 229 | ms_level : int 230 | The MS level of the spectra to parse. 231 | valid_charge : Iterable[int], optional 232 | Only consider spectra with the specified precursor charges. If `None`, 233 | any precursor charge is accepted. 234 | annotations : bool 235 | Include peptide annotations. 
236 | """ 237 | 238 | def __init__( 239 | self, 240 | ms_data_file, 241 | ms_level=2, 242 | valid_charge=None, 243 | annotationsLabel=False, 244 | ): 245 | """Initialize the MgfParser.""" 246 | super().__init__( 247 | ms_data_file, 248 | ms_level=ms_level, 249 | valid_charge=valid_charge, 250 | id_type="index", 251 | ) 252 | self.annotationsLabel = annotationsLabel 253 | self.annotations = [] 254 | self._counter = 0 255 | 256 | def open(self): 257 | """Open the MGF file for reading""" 258 | return MGF(str(self.path)) 259 | 260 | def parse_spectrum(self, spectrum): 261 | """Parse a single spectrum. 262 | 263 | Parameters 264 | ---------- 265 | spectrum : dict 266 | The dictionary defining the spectrum in MGF format. 267 | """ 268 | if self.ms_level > 1: 269 | precursor_mz = float(spectrum["params"]["pepmass"][0]) 270 | precursor_charge = int(spectrum["params"].get("charge", [0])[0]) 271 | else: 272 | precursor_mz, precursor_charge = None, 0 273 | 274 | if self.annotationsLabel: 275 | self.annotations.append(spectrum["params"].get("seq")) 276 | else: 277 | # print(spectrum) 278 | # print((spectrum["params"]["title"])) 279 | self.annotations.append(spectrum["params"]["title"]) 280 | 281 | if self.valid_charge is None or precursor_charge in self.valid_charge: 282 | self.mz_arrays.append(list(spectrum["m/z array"])) 283 | self.intensity_arrays.append(list(spectrum["intensity array"])) 284 | self.precursor_mz.append(precursor_mz) 285 | self.precursor_charge.append(precursor_charge) 286 | self.scan_id.append(self._counter) 287 | 288 | self._counter += 1 289 | 290 | 291 | def _parse_scan_id(scan_str): 292 | """Remove the string prefix from the scan ID. 293 | 294 | Adapted from: 295 | https://github.com/bittremieux/GLEAMS/blob/ 296 | 8831ad6b7a5fc391f8d3b79dec976b51a2279306/gleams/ 297 | ms_io/mzml_io.py#L82-L85 298 | 299 | Parameters 300 | ---------- 301 | scan_str : str 302 | The scan ID string. 303 | 304 | Returns 305 | ------- 306 | int 307 | The scan ID number. 308 | """ 309 | try: 310 | return int(scan_str) 311 | except ValueError: 312 | try: 313 | return int(scan_str[scan_str.find("scan=") + len("scan=") :]) 314 | except ValueError: 315 | pass 316 | 317 | raise ValueError(f"Failed to parse scan number") 318 | -------------------------------------------------------------------------------- /ContraNovo/denovo/evaluate.py: -------------------------------------------------------------------------------- 1 | """Methods to evaluate peptide-spectrum predictions.""" 2 | import re 3 | from typing import Dict, Iterable, List, Tuple 4 | 5 | import numpy as np 6 | from spectrum_utils.utils import mass_diff 7 | 8 | 9 | def aa_match_prefix( 10 | peptide1: List[str], 11 | peptide2: List[str], 12 | aa_dict: Dict[str, float], 13 | cum_mass_threshold: float = 0.5, 14 | ind_mass_threshold: float = 0.1, 15 | ) -> Tuple[np.ndarray, bool]: 16 | """ 17 | Find the matching prefix amino acids between two peptide sequences. 18 | 19 | This is a similar evaluation criterion as used by DeepNovo. 20 | 21 | Parameters 22 | ---------- 23 | peptide1 : List[str] 24 | The first tokenized peptide sequence to be compared. 25 | peptide2 : List[str] 26 | The second tokenized peptide sequence to be compared. 27 | aa_dict : Dict[str, float] 28 | Mapping of amino acid tokens to their mass values. 29 | cum_mass_threshold : float 30 | Mass threshold in Dalton to accept cumulative mass-matching amino acid 31 | sequences. 32 | ind_mass_threshold : float 33 | Mass threshold in Dalton to accept individual mass-matching amino acids. 
34 | 35 | Returns 36 | ------- 37 | aa_matches : np.ndarray of length max(len(peptide1), len(peptide2)) 38 | Boolean flag indicating whether each paired-up amino acid matches across 39 | both peptide sequences. 40 | pep_match : bool 41 | Boolean flag to indicate whether the two peptide sequences fully match. 42 | """ 43 | aa_matches = np.zeros(max(len(peptide1), len(peptide2)), np.bool_) 44 | # Find longest mass-matching prefix. 45 | i1, i2, cum_mass1, cum_mass2 = 0, 0, 0.0, 0.0 46 | while i1 < len(peptide1) and i2 < len(peptide2): 47 | aa_mass1 = aa_dict.get(peptide1[i1], 0) 48 | aa_mass2 = aa_dict.get(peptide2[i2], 0) 49 | if ( 50 | abs(mass_diff(cum_mass1 + aa_mass1, cum_mass2 + aa_mass2, True)) 51 | < cum_mass_threshold 52 | ): 53 | aa_matches[max(i1, i2)] = ( 54 | abs(mass_diff(aa_mass1, aa_mass2, True)) < ind_mass_threshold 55 | ) 56 | i1, i2 = i1 + 1, i2 + 1 57 | cum_mass1, cum_mass2 = cum_mass1 + aa_mass1, cum_mass2 + aa_mass2 58 | elif cum_mass2 + aa_mass2 > cum_mass1 + aa_mass1: 59 | i1, cum_mass1 = i1 + 1, cum_mass1 + aa_mass1 60 | else: 61 | i2, cum_mass2 = i2 + 1, cum_mass2 + aa_mass2 62 | return aa_matches, aa_matches.all() 63 | 64 | 65 | def aa_match_prefix_suffix( 66 | peptide1: List[str], 67 | peptide2: List[str], 68 | aa_dict: Dict[str, float], 69 | cum_mass_threshold: float = 0.5, 70 | ind_mass_threshold: float = 0.1, 71 | ) -> Tuple[np.ndarray, bool]: 72 | """ 73 | Find the matching prefix and suffix amino acids between two peptide 74 | sequences. 75 | 76 | Parameters 77 | ---------- 78 | peptide1 : List[str] 79 | The first tokenized peptide sequence to be compared. 80 | peptide2 : List[str] 81 | The second tokenized peptide sequence to be compared. 82 | aa_dict : Dict[str, float] 83 | Mapping of amino acid tokens to their mass values. 84 | cum_mass_threshold : float 85 | Mass threshold in Dalton to accept cumulative mass-matching amino acid 86 | sequences. 87 | ind_mass_threshold : float 88 | Mass threshold in Dalton to accept individual mass-matching amino acids. 89 | 90 | Returns 91 | ------- 92 | aa_matches : np.ndarray of length max(len(peptide1), len(peptide2)) 93 | Boolean flag indicating whether each paired-up amino acid matches across 94 | both peptide sequences. 95 | pep_match : bool 96 | Boolean flag to indicate whether the two peptide sequences fully match. 97 | """ 98 | # Find longest mass-matching prefix. 99 | aa_matches, pep_match = aa_match_prefix( 100 | peptide1, peptide2, aa_dict, cum_mass_threshold, ind_mass_threshold 101 | ) 102 | # No need to evaluate the suffixes if the sequences already fully match. 103 | if pep_match: 104 | return aa_matches, pep_match 105 | # Find longest mass-matching suffix. 
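    # Mirror the prefix pass from the C-terminus: advance whichever peptide has
    # the smaller cumulative suffix mass, and stop at the first position the
    # prefix pass failed to match (i_stop), so prefix and suffix matches cannot
    # overlap.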
106 | i1, i2 = len(peptide1) - 1, len(peptide2) - 1 107 | i_stop = np.argwhere(~aa_matches)[0] 108 | cum_mass1, cum_mass2 = 0.0, 0.0 109 | while i1 >= i_stop and i2 >= i_stop: 110 | aa_mass1 = aa_dict.get(peptide1[i1], 0) 111 | aa_mass2 = aa_dict.get(peptide2[i2], 0) 112 | if ( 113 | abs(mass_diff(cum_mass1 + aa_mass1, cum_mass2 + aa_mass2, True)) 114 | < cum_mass_threshold 115 | ): 116 | aa_matches[max(i1, i2)] = ( 117 | abs(mass_diff(aa_mass1, aa_mass2, True)) < ind_mass_threshold 118 | ) 119 | i1, i2 = i1 - 1, i2 - 1 120 | cum_mass1, cum_mass2 = cum_mass1 + aa_mass1, cum_mass2 + aa_mass2 121 | elif cum_mass2 + aa_mass2 > cum_mass1 + aa_mass1: 122 | i1, cum_mass1 = i1 - 1, cum_mass1 + aa_mass1 123 | else: 124 | i2, cum_mass2 = i2 - 1, cum_mass2 + aa_mass2 125 | return aa_matches, aa_matches.all() 126 | 127 | 128 | def aa_match( 129 | peptide1: List[str], 130 | peptide2: List[str], 131 | aa_dict: Dict[str, float], 132 | cum_mass_threshold: float = 0.5, 133 | ind_mass_threshold: float = 0.1, 134 | mode: str = "best", 135 | ) -> Tuple[np.ndarray, bool]: 136 | """ 137 | Find the matching amino acids between two peptide sequences. 138 | 139 | Parameters 140 | ---------- 141 | peptide1 : List[str] 142 | The first tokenized peptide sequence to be compared. 143 | peptide2 : List[str] 144 | The second tokenized peptide sequence to be compared. 145 | aa_dict : Dict[str, float] 146 | Mapping of amino acid tokens to their mass values. 147 | cum_mass_threshold : float 148 | Mass threshold in Dalton to accept cumulative mass-matching amino acid 149 | sequences. 150 | ind_mass_threshold : float 151 | Mass threshold in Dalton to accept individual mass-matching amino acids. 152 | mode : {"best", "forward", "backward"} 153 | The direction in which to find matching amino acids. 154 | 155 | Returns 156 | ------- 157 | aa_matches : np.ndarray of length max(len(peptide1), len(peptide2)) 158 | Boolean flag indicating whether each paired-up amino acid matches across 159 | both peptide sequences. 160 | pep_match : bool 161 | Boolean flag to indicate whether the two peptide sequences fully match. 162 | """ 163 | if mode == "best": 164 | return aa_match_prefix_suffix( 165 | peptide1, peptide2, aa_dict, cum_mass_threshold, ind_mass_threshold 166 | ) 167 | elif mode == "forward": 168 | return aa_match_prefix( 169 | peptide1, peptide2, aa_dict, cum_mass_threshold, ind_mass_threshold 170 | ) 171 | elif mode == "backward": 172 | aa_matches, pep_match = aa_match_prefix( 173 | list(reversed(peptide1)), 174 | list(reversed(peptide2)), 175 | aa_dict, 176 | cum_mass_threshold, 177 | ind_mass_threshold, 178 | ) 179 | return aa_matches[::-1], pep_match 180 | else: 181 | raise ValueError("Unknown evaluation mode") 182 | 183 | 184 | def aa_match_batch( 185 | peptides1: Iterable, 186 | peptides2: Iterable, 187 | aa_dict: Dict[str, float], 188 | cum_mass_threshold: float = 0.5, 189 | ind_mass_threshold: float = 0.1, 190 | mode: str = "best", 191 | ) -> Tuple[List[Tuple[np.ndarray, bool]], int, int]: 192 | """ 193 | Find the matching amino acids between multiple pairs of peptide sequences. 194 | 195 | Parameters 196 | ---------- 197 | peptides1 : Iterable 198 | The first list of peptide sequences to be compared. 199 | peptides2 : Iterable 200 | The second list of peptide sequences to be compared. 201 | aa_dict : Dict[str, float] 202 | Mapping of amino acid tokens to their mass values. 203 | cum_mass_threshold : float 204 | Mass threshold in Dalton to accept cumulative mass-matching amino acid 205 | sequences. 
206 | ind_mass_threshold : float 207 | Mass threshold in Dalton to accept individual mass-matching amino acids. 208 | mode : {"best", "forward", "backward"} 209 | The direction in which to find matching amino acids. 210 | 211 | Returns 212 | ------- 213 | aa_matches_batch : List[Tuple[np.ndarray, bool]] 214 | For each pair of peptide sequences: (i) boolean flags indicating whether 215 | each paired-up amino acid matches across both peptide sequences, (ii) 216 | boolean flag to indicate whether the two peptide sequences fully match. 217 | n_aa1: int 218 | Total number of amino acids in the first list of peptide sequences. 219 | n_aa2: int 220 | Total number of amino acids in the second list of peptide sequences. 221 | """ 222 | aa_matches_batch, n_aa1, n_aa2 = [], 0, 0 223 | for peptide1, peptide2 in zip(peptides1, peptides2): 224 | # Split peptides into individual AAs if necessary. 225 | if isinstance(peptide1, str): 226 | peptide1 = re.split(r"(?<=.)(?=[A-Z])", peptide1) 227 | if isinstance(peptide2, str): 228 | peptide2 = re.split(r"(?<=.)(?=[A-Z])", peptide2) 229 | n_aa1, n_aa2 = n_aa1 + len(peptide1), n_aa2 + len(peptide2) 230 | aa_matches_batch.append( 231 | aa_match( 232 | peptide1, 233 | peptide2, 234 | aa_dict, 235 | cum_mass_threshold, 236 | ind_mass_threshold, 237 | mode, 238 | ) 239 | ) 240 | return aa_matches_batch, n_aa1, n_aa2 241 | 242 | 243 | def aa_match_metrics( 244 | aa_matches_batch: List[Tuple[np.ndarray, bool]], 245 | n_aa_true: int, 246 | n_aa_pred: int, 247 | ) -> Tuple[float, float, float]: 248 | """ 249 | Calculate amino acid and peptide-level evaluation metrics. 250 | 251 | Parameters 252 | ---------- 253 | aa_matches_batch : List[Tuple[np.ndarray, bool]] 254 | For each pair of peptide sequences: (i) boolean flags indicating whether 255 | each paired-up amino acid matches across both peptide sequences, (ii) 256 | boolean flag to indicate whether the two peptide sequences fully match. 257 | n_aa_true: int 258 | Total number of amino acids in the true peptide sequences. 259 | n_aa_pred: int 260 | Total number of amino acids in the predicted peptide sequences. 261 | 262 | Returns 263 | ------- 264 | aa_precision: float 265 | The number of correct AA predictions divided by the number of predicted 266 | AAs. 267 | aa_recall: float 268 | The number of correct AA predictions divided by the number of true AAs. 269 | pep_precision: float 270 | The number of correct peptide predictions divided by the number of 271 | peptides. 272 | """ 273 | n_aa_correct = sum( 274 | [aa_matches[0].sum() for aa_matches in aa_matches_batch] 275 | ) 276 | aa_precision = n_aa_correct / (n_aa_pred + 1e-8) 277 | aa_recall = n_aa_correct / (n_aa_true + 1e-8) 278 | pep_precision = sum([aa_matches[1] for aa_matches in aa_matches_batch]) / ( 279 | len(aa_matches_batch) + 1e-8 280 | ) 281 | return aa_precision, aa_recall, pep_precision 282 | 283 | 284 | def aa_precision_recall( 285 | aa_scores_correct: List[float], 286 | aa_scores_all: List[float], 287 | n_aa_total: int, 288 | threshold: float, 289 | ) -> Tuple[float, float]: 290 | """ 291 | Calculate amino acid level precision and recall at a given score threshold. 292 | 293 | Parameters 294 | ---------- 295 | aa_scores_correct : List[float] 296 | Amino acids scores for the correct amino acids predictions. 297 | aa_scores_all : List[float] 298 | Amino acid scores for all amino acids predictions. 299 | n_aa_total : int 300 | The total number of amino acids in the predicted peptide sequences. 
301 | threshold : float 302 | The amino acid score threshold. 303 | 304 | Returns 305 | ------- 306 | aa_precision: float 307 | The number of correct amino acid predictions divided by the number of 308 | predicted amino acids. 309 | aa_recall: float 310 | The number of correct amino acid predictions divided by the total number 311 | of amino acids. 312 | """ 313 | n_aa_correct = sum([score > threshold for score in aa_scores_correct]) 314 | n_aa_predicted = sum([score > threshold for score in aa_scores_all]) 315 | return n_aa_correct / n_aa_predicted, n_aa_correct / n_aa_total 316 | -------------------------------------------------------------------------------- /ContraNovo/denovo/clipmodel.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | # from typing import Tuple,Union 3 | from typing import Any, Dict, List, Optional, Set, Tuple, Union 4 | import numpy as np 5 | from pytorch_lightning.utilities.types import STEP_OUTPUT 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from ..masses import PeptideMass 10 | from .. import utils2 as utils 11 | import itertools 12 | import re 13 | 14 | from ..components.encoders import MassEncoder,PeakEncoder,PositionalEncoder 15 | 16 | class LayerNorm(nn.LayerNorm): 17 | """Subclass torch's LayerNorm to handle fp16.""" 18 | 19 | def forward(self, x: torch.Tensor): 20 | orig_type = x.dtype 21 | ret = super().forward(x.type(torch.float32)) 22 | return ret.type(orig_type) 23 | 24 | 25 | class QuickGELU(nn.Module): 26 | def forward(self, x: torch.Tensor): 27 | return x * torch.sigmoid(1.702 * x) 28 | 29 | 30 | class ResidualAttentionBlock(nn.Module): 31 | def __init__(self, d_model: int, n_head: int, dropout: float): 32 | super().__init__() 33 | 34 | self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout) 35 | self.ln_1 = LayerNorm(d_model) 36 | self.mlp = nn.Sequential(OrderedDict([ 37 | ("c_fc", nn.Linear(d_model, d_model * 4)), 38 | ("gelu", QuickGELU()), 39 | ("c_proj", nn.Linear(d_model * 4, d_model)) 40 | ])) 41 | self.ln_2 = LayerNorm(d_model) 42 | # self.attn_mask = attn_mask 43 | 44 | def attention(self, x: torch.Tensor, key_padding_mask : torch.Tensor = None): 45 | key_padding_mask = key_padding_mask .to(dtype=x.dtype, device=x.device) if key_padding_mask is not None else None 46 | return self.attn(x, x, x, need_weights=False, key_padding_mask =key_padding_mask)[0] 47 | 48 | def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None): 49 | x = x + self.attention(self.ln_1(x),attn_mask) 50 | x = x + self.mlp(self.ln_2(x)) 51 | return x 52 | 53 | class Transformer(nn.Module): 54 | def __init__(self, width: int, layers: int, heads: int, dropout: float): 55 | super().__init__() 56 | self.width = width 57 | self.layers = layers 58 | self.resblocks = nn.ModuleList([ResidualAttentionBlock(width, heads, dropout) for _ in range(layers)]) 59 | 60 | def forward(self, x: torch.Tensor, key_padding_mask : torch.Tensor = None): 61 | for module in self.resblocks: 62 | x = module(x,key_padding_mask) 63 | return x 64 | 65 | 66 | class SpectrumEncoder(nn.Module): 67 | def __init__(self, 68 | embed_dim:int = 512, 69 | #Spectrum: 70 | n_peaks: int = 150, 71 | transformer_width: int = 512, 72 | transformer_heads: int = 8, 73 | transformer_layers: int = 9, 74 | max_charge = 10, 75 | dropout = 0.18): 76 | super().__init__() 77 | 78 | self.transformer_width = transformer_width 79 | self.embed_dim = embed_dim 80 | 81 | self.n_peaks = n_peaks 82 | 83 
| self.mass_encoder = MassEncoder(transformer_width) 84 | 85 | # self.latent_spectrum = torch.nn.Parameter(torch.randn(1, 1, embed_dim)) 86 | 87 | # self.percursors_param = torch.nn.Parameter(torch.randn(transformer_width,transformer_width)) 88 | 89 | layer = torch.nn.TransformerEncoderLayer( 90 | d_model=transformer_width, 91 | nhead=transformer_heads, 92 | dim_feedforward=1024, 93 | batch_first=True, 94 | dropout=dropout, 95 | ) 96 | 97 | self.spectraTransformer = torch.nn.TransformerEncoder( 98 | layer, 99 | num_layers = transformer_layers, 100 | ) 101 | 102 | 103 | # self.spectraTransformer = Transformer( 104 | # width=transformer_width, 105 | # layers=transformer_layers, 106 | # heads=transformer_heads, 107 | # dropout=dropout 108 | # ) 109 | 110 | self.peak_encoder = PeakEncoder( 111 | transformer_width, 112 | dim_intensity=None 113 | ) 114 | 115 | #Precursor Encoder 116 | self.mass_encoder = MassEncoder(self.transformer_width) 117 | self.charge_encoder = torch.nn.Embedding(max_charge, transformer_width) 118 | 119 | 120 | 121 | # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 122 | 123 | def forward(self,spectra,precursors): 124 | '''---------------------------------------------------------------''' 125 | # Transformer Encoder for Peaks and Precursor. 126 | 127 | # masses = self.mass_encoder(precursors[:, None, [0]]) 128 | # charges = self.charge_encoder(precursors[:, 1].int() - 1) 129 | # precursors = masses + charges[:, None, :] 130 | 131 | zeros = ~spectra.sum(dim=2).bool() 132 | 133 | # mask = [ 134 | # # torch.tensor([[False]] * spectra.shape[0]).type_as(zeros), 135 | # # torch.tensor([[False]] * spectra.shape[0]).type_as(zeros), 136 | # zeros, 137 | # ] 138 | # mask = torch.cat(mask, dim=1) 139 | mask = zeros 140 | 141 | peaks = self.peak_encoder(spectra) 142 | 143 | # precursors = torch.matmul(precursors,self.percursors_param) 144 | # peaks = torch.concat([precursors,peaks],dim = 1) 145 | 146 | # Add the spectrum representation to each input: 147 | # latent_spectra = self.latent_spectrum.expand(peaks.shape[0], -1, -1) 148 | 149 | # peaks = torch.cat([latent_spectra, peaks], dim=1) 150 | 151 | # peaks = peaks + precursors 152 | 153 | # Peaks after Transformer Encoder named as pkt 154 | # pkt = peaks.permute(1,0,2) # Do premute because the batch_is_first of SelfAttentionLayer is False 155 | pkt = self.spectraTransformer(peaks,src_key_padding_mask=mask) 156 | # pkt = pkt.permute(1,0,2) #Shape(B_s, Specturm_len, Feature_size) 157 | 158 | return pkt, mask 159 | 160 | 161 | 162 | class PeptideEncoder(nn.Module): 163 | def __init__(self, 164 | #Peptide: 165 | max_len = 100, 166 | residues: Union[Dict[str, float], str] = "canonical", 167 | ptransformer_width: int = 512, 168 | ptransformer_heads: int = 8, 169 | ptransformer_layers: int = 9, 170 | dropout = 0.4): 171 | 172 | super().__init__() 173 | 174 | self.reverse = True 175 | 176 | self.max_len = max_len 177 | self._peptide_mass = PeptideMass(residues=residues) 178 | self._amino_acids = list(self._peptide_mass.masses.keys()) 179 | self._idx2aa = {i + 1: aa for i, aa in enumerate(self._amino_acids)} 180 | self._aa2idx = {aa: i for i, aa in self._idx2aa.items()} 181 | 182 | self.aminoEmbedDim = ptransformer_width - 256 183 | self.aa_encoder = torch.nn.Embedding( 184 | len(self._amino_acids) + 1, 185 | self.aminoEmbedDim, 186 | padding_idx=0, 187 | ) 188 | 189 | #Mass/charge Encoder for precursors (Dim:256 = dim_model // 2) 190 | self.mass_encoder = MassEncoder(256) 191 | self.charge_encoder = 
torch.nn.Embedding(10, 256) 192 | 193 | # MassEncoder for prefix and suffix mass 194 | self.prefixMassEncoder = MassEncoder(128) 195 | self.suffixMassEncoder = MassEncoder(128) 196 | 197 | self.massEncoder = MassEncoder(ptransformer_width) 198 | 199 | self.pos_encoder = PositionalEncoder(ptransformer_width) 200 | 201 | layer = torch.nn.TransformerEncoderLayer( 202 | d_model=ptransformer_width, 203 | nhead=ptransformer_heads, 204 | dim_feedforward=1024, 205 | batch_first=True, 206 | dropout=dropout, 207 | ) 208 | 209 | self.peptideTransformer = torch.nn.TransformerEncoder( 210 | layer, 211 | num_layers = ptransformer_layers, 212 | ) 213 | 214 | # self.peptideTransformer = Transformer( 215 | # width=ptransformer_width, 216 | # layers=ptransformer_layers, 217 | # heads=ptransformer_heads, 218 | # dropout=dropout 219 | # ) 220 | 221 | #Massses Embedding 222 | '''self.linearEmbedding = torch.nn.Linear(1,2)''' 223 | 224 | 225 | def forward(self,sequences,precursors): 226 | 227 | # Transformer Encoder For Peptide Sequence. 228 | # # Mass Encoder For Peptide Sequence. 229 | '''if sequences is not None: 230 | sequences = utils.listify(sequences) 231 | tokens = [self.tokenize(s) for s in sequences] 232 | Masses = [self.deMass(s) for s in sequences] 233 | Masses = torch.nn.utils.rnn.pad_sequence(Masses, batch_first = True) 234 | tokens = torch.nn.utils.rnn.pad_sequence(tokens, batch_first = True) 235 | else: 236 | tokens = torch.tensor([[]]).to(self.device) 237 | Masses = torch.tensor([[]]).to(self.device) 238 | 239 | tgt = self.aa_encoder(tokens) 240 | 241 | Masses = Masses.unsqueeze(2) 242 | Masses = self.linearEmbedding(Masses) 243 | tgt = torch.concat([tgt,Masses],dim=2) 244 | 245 | tgt = self.pos_encoder(tgt)''' 246 | if sequences is not None: 247 | sequences = utils.listify(sequences) 248 | Masses = [self.deMass(sequences[i]) for i in range(len(sequences))] 249 | Masses = torch.nn.utils.rnn.pad_sequence(Masses, batch_first = True) 250 | suffixMasses = [self.get_suffix_mass(sequences[i],precursors[i][0]) for i in range(len(sequences))] 251 | suffixMasses = torch.nn.utils.rnn.pad_sequence(suffixMasses, batch_first = True) 252 | tokens = [self.tokenize(s) for s in sequences] 253 | tokens = torch.nn.utils.rnn.pad_sequence(tokens, batch_first = True) 254 | else: 255 | Masses = torch.tensor([[]]).to(self.device) 256 | suffixMasses = torch.tensor([[]]).to(self.device) 257 | tokens = torch.tensor([[]]).to(self.device) 258 | 259 | 260 | masses = self.mass_encoder(precursors[:, None, [0]]) 261 | charges = self.charge_encoder(precursors[:, 1].int() - 1) 262 | precursors = masses + charges[:, None, :] 263 | 264 | preAndSufPrecursors = torch.tensor([[0]]).to(self.device) 265 | preAndSufPrecursors = self.prefixMassEncoder(preAndSufPrecursors) 266 | preAndSufPrecursors = preAndSufPrecursors.repeat(precursors.shape[0],1) 267 | preAndSufPrecursors = preAndSufPrecursors.unsqueeze(1) 268 | precursors = torch.cat([precursors,preAndSufPrecursors,preAndSufPrecursors],dim=2) 269 | 270 | 271 | 272 | Masses = Masses.unsqueeze(2) 273 | Masses = self.prefixMassEncoder(Masses) 274 | 275 | suffixMasses = suffixMasses.unsqueeze(2) 276 | suffixMasses = self.suffixMassEncoder(suffixMasses) 277 | 278 | Masses = torch.concat([Masses,suffixMasses],dim=2) 279 | 280 | tgt = self.aa_encoder(tokens) 281 | tgt_key_padding_mask = tgt.sum(axis=2) == 0 282 | tgt = torch.concat([tgt,Masses],dim=2) 283 | 284 | tgt = torch.cat([precursors, tgt], dim=1) 285 | 286 | tgt_key_padding_mask = tgt.sum(axis=2) == 0 287 | # Add positional code on 
peptide sequence. 288 | # tgt = self.pos_encoder(tgt) #(n_spectra, len(Peptide), dim_model) 289 | # Peptide input to Transformer. 290 | 291 | # tgt = tgt.permute(1,0,2) 292 | tgt = self.peptideTransformer(tgt,src_key_padding_mask = tgt_key_padding_mask) 293 | # tgt = tgt.permute(1,0,2) #Shape(B_s, Peptide_len, Feature_size) 294 | 295 | return tgt 296 | 297 | 298 | def tokenize(self, sequence, partial=False): 299 | """Transform a peptide sequence into tokens 300 | 301 | Parameters 302 | ---------- 303 | sequence : str 304 | A peptide sequence. 305 | 306 | Returns 307 | ------- 308 | torch.Tensor 309 | The token for each amino acid in the peptide sequence. 310 | """ 311 | if not isinstance(sequence, str): 312 | return sequence # Assume it is already tokenized. 313 | 314 | sequence = sequence.replace("I", "L") 315 | sequence = re.split(r"(?<=.)(?=[A-Z])", sequence) 316 | 317 | if self.reverse: 318 | sequence = list(reversed(sequence)) 319 | 320 | # if not partial: 321 | # sequence += ["$"] 322 | 323 | tokens = [self._aa2idx[aa] for aa in sequence] 324 | tokens = torch.tensor(tokens, device=self.device) 325 | return tokens 326 | 327 | def deMass(self,sequence): 328 | 329 | if not isinstance(sequence, str): 330 | 331 | sequence = [self._idx2aa.get(i.item(), "") for i in sequence] 332 | masses = [self._peptide_mass.masses[aa] for aa in sequence] 333 | masses = list(itertools.accumulate(masses)) 334 | masses = torch.tensor(masses, device = self.device) 335 | 336 | return masses 337 | 338 | sequence = sequence.replace("I", "L") 339 | sequence = re.split(r"(?<=.)(?=[A-Z])", sequence) 340 | 341 | if self.reverse: 342 | sequence = list(reversed(sequence)) 343 | 344 | masses = [self._peptide_mass.masses[aa] for aa in sequence] 345 | 346 | masses = list(itertools.accumulate(masses)) 347 | 348 | masses = torch.tensor(masses, device = self.device) 349 | 350 | return masses 351 | 352 | def get_suffix_mass(self,sequence,premass): 353 | 354 | if not isinstance(sequence, str): 355 | 356 | sequence = [self._idx2aa.get(i.item(), "") for i in sequence] 357 | masses = [self._peptide_mass.masses[aa] for aa in sequence] 358 | masses = list(itertools.accumulate(masses)) 359 | masses = torch.tensor(masses, device = self.device) 360 | masses = premass - masses 361 | return masses 362 | 363 | sequence = sequence.replace("I", "L") 364 | sequence = re.split(r"(?<=.)(?=[A-Z])", sequence) 365 | 366 | if self.reverse: 367 | sequence = list(reversed(sequence)) 368 | 369 | masses = [self._peptide_mass.masses[aa] for aa in sequence] 370 | 371 | masses = list(itertools.accumulate(masses)) 372 | 373 | masses = torch.tensor(masses, device = self.device) 374 | masses = premass - masses 375 | 376 | return masses 377 | 378 | def get_mass(self, sequence): 379 | 380 | if not isinstance(sequence, str): 381 | masses = [self._peptide_mass.masses[aa] for aa in sequence] 382 | masstemp = torch.tensor(masses, device = self.device) 383 | return masstemp 384 | 385 | sequence = sequence.replace("I", "L") 386 | sequence = re.split(r"(?<=.)(?=[A-Z])", sequence) 387 | 388 | if self.reverse: 389 | sequence = list(reversed(sequence)) 390 | 391 | masses = [self._peptide_mass.masses[aa] for aa in sequence] 392 | masstemp = torch.tensor(masses, device = self.device) 393 | # masstemp = torch.cat([masstemp,torch.tensor([0.0]).to(masstemp.device)]) 394 | 395 | return masstemp 396 | 397 | def getAminoAcid(self): 398 | AA_masslist = [self._peptide_mass.masses[self._idx2aa[i]] for i in range(1,28)] 399 | AA_masslist = [0] + AA_masslist 400 | AA_masslist = 
torch.tensor(AA_masslist,device = self.device)
401 |         return AA_masslist
402 | 
403 |     @property
404 |     def device(self):
405 |         """The current device for the model"""
406 |         return next(self.parameters()).device
407 | 
408 | class CLIP(nn.Module):
409 |     def __init__(self,
410 |                  embed_dim:int = 512,
411 |                  #Spectrum:
412 |                  n_peaks: int = 150,
413 |                  transformer_width: int = 512,
414 |                  transformer_heads: int = 8,
415 |                  transformer_layers: int = 9,
416 |                  #Peptide:
417 |                  max_len = 100,
418 |                  residues: Union[Dict[str, float], str] = "canonical",
419 |                  ptransformer_width: int = 512,
420 |                  ptransformer_heads: int = 8,
421 |                  ptransformer_layers: int = 9,
422 |                  max_charge = 10,
423 |                  ):
424 |         super().__init__()
425 | 
426 |         self.spectrumEncoder = SpectrumEncoder(embed_dim = embed_dim,
427 |                                                #Spectrum:
428 |                                                n_peaks = n_peaks,
429 |                                                transformer_width = transformer_width,
430 |                                                transformer_heads = transformer_heads,
431 |                                                transformer_layers = transformer_layers,
432 |                                                max_charge = max_charge,
433 |                                                )
434 |         self.peptideEncoder = PeptideEncoder(max_len = max_len,
435 |                                              residues = residues,
436 |                                              ptransformer_width = ptransformer_width,
437 |                                              ptransformer_heads = ptransformer_heads,
438 |                                              ptransformer_layers = ptransformer_layers)
439 | 
440 |         self.global_peptide = torch.nn.Parameter(torch.randn(1,1,ptransformer_width))
441 |         # Used to compute the global feature vector for the spectrum.
442 |         self.global_spectrum = torch.nn.Parameter(torch.randn(1,1,transformer_width))
443 | 
444 |         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
445 | 
446 |     def encode_spectrum(self,spectra,precursors):
447 | 
448 |         pkt, _ = self.spectrumEncoder(spectra,precursors)  # The encoder also returns the padding mask, which is unused here.
449 | 
450 |         '''-----------------------------------------------------'''
451 |         # Extract the global features for Spectrum and Peptide.
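        # Attention pooling with a learned query: `global_spectrum` (1 x 1 x d)
        # is dotted against every peak embedding, a softmax over the sequence
        # dimension turns the scores into per-peak weights, and their weighted
        # sum yields one global vector per spectrum. `encode_peptide` pools
        # with `global_peptide` in the same way.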
452 | 
453 |         '''Spectrum Global Features:'''
454 |         pkt = torch.transpose(pkt,1,2)
455 |         ratiospkt = torch.matmul(self.global_spectrum, pkt)
456 |         ratiospkt = torch.softmax(ratiospkt,dim = 2)
457 |         pkt = torch.transpose(pkt,1,2)
458 |         pkt = torch.matmul(ratiospkt,pkt)
459 |         pkt = pkt.squeeze(1)
460 | 
461 |         return pkt
462 | 
463 |     def encode_peptide(self,sequences,precursors):
464 | 
465 |         tgt = self.peptideEncoder(sequences,precursors)
466 | 
467 |         '''Peptide Global Features:'''
468 | 
469 |         tgt = torch.transpose(tgt,1,2)
470 |         ratiostgt = torch.matmul(self.global_peptide, tgt)
471 |         ratiostgt = torch.softmax(ratiostgt,dim = 2)
472 |         tgt = torch.transpose(tgt,1,2)
473 |         tgt = torch.matmul(ratiostgt,tgt)
474 |         tgt = tgt.squeeze(1)
475 |         return tgt
476 | 
477 |     def forward(self, spectra:torch.Tensor,
478 |                 precursors: torch.Tensor,
479 |                 sequences):
480 | 
481 |         pkt_features = self.encode_spectrum(spectra=spectra,precursors=precursors)
482 |         tgt_features = self.encode_peptide(sequences=sequences,precursors=precursors)
483 | 
484 |         pkt_features = pkt_features / pkt_features.norm(dim = -1, keepdim = True)
485 |         tgt_features = tgt_features / tgt_features.norm(dim = -1, keepdim = True)
486 | 
487 |         logit_scale = self.logit_scale.exp()
488 |         logits_per_spec = logit_scale * pkt_features @ tgt_features.t()
489 |         logits_per_tgt = logits_per_spec.t()
490 | 
491 |         return logits_per_spec, logits_per_tgt
492 | 
493 |     @property
494 |     def device(self):
495 |         """The current device for the model"""
496 |         return next(self.parameters()).device
497 | 
498 | 
--------------------------------------------------------------------------------
/ContraNovo/denovo/model_runner.py:
--------------------------------------------------------------------------------
1 | """Training and testing functionality for the de novo peptide sequencing
2 | model."""
3 | import glob
4 | import logging
5 | import operator
6 | import os
7 | import tempfile
8 | import uuid
9 | from typing import Any, Dict, Iterable, List, Optional, Union
10 | 
11 | import numpy as np
12 | import pytorch_lightning as pl
13 | import torch
14 | from pytorch_lightning.strategies import DDPStrategy
15 | from pytorch_lightning.profiler import SimpleProfiler
16 | 
17 | from .. import utils
18 | from .db_dataloader import DeNovoDataModule
19 | from .model import Spec2Pep
20 | 
21 | logger = logging.getLogger("ContraNovo")
22 | 
23 | def predict(
24 |     peak_path: str,
25 |     model_filename: str,
26 |     config: Dict[str, Any],
27 |     out_writer=None,
28 | ) -> None:
29 |     """
30 |     Predict peptide sequences with a trained ContraNovo model.
31 | 
32 |     Parameters
33 |     ----------
34 |     peak_path : str
35 |         The path with peak files for predicting peptide sequences.
36 |     model_filename : str
37 |         The file name of the model weights (.ckpt file).
38 |     config : Dict[str, Any]
39 |         The configuration options.
40 |     """
41 |     _execute_existing(peak_path, model_filename, config, False, out_writer)
42 | 
43 | 
44 | def evaluate(peak_path: str, model_filename: str, config: Dict[str,
45 |                                                                Any]) -> None:
46 |     """
47 |     Evaluate peptide sequence predictions from a trained ContraNovo model.
48 | 
49 |     Parameters
50 |     ----------
51 |     peak_path : str
52 |         The path with peak files for predicting peptide sequences.
53 |     model_filename : str
54 |         The file name of the model weights (.ckpt file).
55 |     config : Dict[str, Any]
56 |         The configuration options.
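
    Examples
    --------
    A sketch with hypothetical paths; ``config`` must supply the model and
    data-loading options read below:

    >>> import yaml
    >>> with open("config.yaml") as f:
    ...     config = yaml.safe_load(f)
    >>> evaluate("spectra/", "contranovo.ckpt", config)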
    """
    _execute_existing(peak_path, model_filename, config, True)


def _execute_existing(
    peak_path: str,
    model_filename: str,
    config: Dict[str, Any],
    annotated: bool,
    out_writer=None,
) -> None:
    """
    Predict peptide sequences with a trained ContraNovo model with/without
    evaluation.

    Parameters
    ----------
    peak_path : str
        The path with peak files for predicting peptide sequences.
    model_filename : str
        The file name of the model weights (.ckpt file).
    config : Dict[str, Any]
        The configuration options.
    annotated : bool
        Whether the input peak files are annotated (execute in evaluation mode)
        or not (execute in prediction mode only).
    out_writer : optional
        The writer used to output the predicted peptide sequences, if any.
    """
    # Load the trained model.
    if not os.path.isfile(model_filename):
        logger.error(
            "Could not find the trained model weights at file %s",
            model_filename,
        )
        raise FileNotFoundError("Could not find the trained model weights")
    model = Spec2Pep.load_from_checkpoint(
        model_filename,
        dim_model=config["dim_model"],
        n_head=config["n_head"],
        dim_feedforward=config["dim_feedforward"],
        n_layers=config["n_layers"],
        dropout=config["dropout"],
        dim_intensity=config["dim_intensity"],
        custom_encoder=config["custom_encoder"],
        max_length=config["max_length"],
        residues=config["residues"],
        max_charge=config["max_charge"],
        precursor_mass_tol=config["precursor_mass_tol"],
        isotope_error_range=config["isotope_error_range"],
        n_beams=config["n_beams"],
        n_log=config["n_log"],
        out_writer=out_writer,
    )
    # Read the MS/MS spectra for which to predict peptide sequences.
    if annotated:
        peak_ext = (".mgf", ".h5", ".hdf5")
    else:
        peak_ext = (".mgf", ".mzml", ".mzxml", ".h5", ".hdf5")
    if len(peak_filenames := _get_peak_filenames(peak_path, peak_ext)) == 0:
        logger.error("Could not find peak files from %s", peak_path)
        raise FileNotFoundError("Could not find peak files")
    peak_is_not_index = any(
        os.path.splitext(fn)[1] in (".mgf", ".mzxml", ".mzml")
        for fn in peak_filenames)

    tmp_dir = tempfile.TemporaryDirectory()
    if peak_is_not_index:
        index_path = [os.path.join(tmp_dir.name, f"eval_{uuid.uuid4().hex}")]
    else:
        index_path = peak_filenames
        peak_filenames = None
    logger.info("Peak files still need to be indexed: %s", peak_is_not_index)

    # SpectrumIdx = AnnotatedSpectrumIndex if annotated else SpectrumIndex
    valid_charge = np.arange(1, config["max_charge"] + 1)
    dataloader_params = dict(
        batch_size=config["predict_batch_size"],
        n_peaks=config["n_peaks"],
        min_mz=config["min_mz"],
        max_mz=config["max_mz"],
        min_intensity=config["min_intensity"],
        remove_precursor_tol=config["remove_precursor_tol"],
        n_workers=config["n_workers"],
        train_filenames=None,
        val_filenames=None,
        test_filenames=peak_filenames,
        # The index path is always a list: either a list containing one index
        # path file or a list containing multiple db files.
        train_index_path=None,
        val_index_path=None,
        test_index_path=index_path,
        annotated=annotated,
        valid_charge=valid_charge,
        mode="test"
    )
    # Initialize the data loader.
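    # NOTE: annotation, not original source -- the three calls below follow
    # the standard PyTorch Lightning DataModule lifecycle: prepare_data()
    # performs the one-off, single-process work (building the on-disk
    # spectrum index), setup(stage="test") instantiates the per-stage
    # datasets, and test_dataloader() wraps the test dataset in a torch
    # DataLoader that yields batches of spectra and precursor information.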
    dataModule = DeNovoDataModule(**dataloader_params)
    dataModule.prepare_data()
    dataModule.setup(stage="test")
    test_dataloader = dataModule.test_dataloader()

    # Create the Trainer object.
    trainer = pl.Trainer(
        enable_model_summary=True,
        accelerator="auto",
        auto_select_gpus=True,
        devices=_get_devices(),
        logger=config["logger"],
        max_epochs=config["max_epochs"],
        num_sanity_val_steps=config["num_sanity_val_steps"],
        strategy=_get_strategy(),
    )
    # Run the model with/without validation.
    run_trainer = trainer.validate if annotated else trainer.predict
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    logger.info("Number of model parameters: %d", pytorch_total_params)
    run_trainer(model, test_dataloader)
    # Clean up temporary files.
    tmp_dir.cleanup()


def train(
    peak_path: str,
    peak_path_val: str,
    peak_path_test: str,
    model_filename: str,
    config: Dict[str, Any],
) -> None:
    """
    Train a ContraNovo model.

    The model can be trained from scratch or by continuing training an
    existing model.

    Parameters
    ----------
    peak_path : str
        The path with peak files to be used as training data.
    peak_path_val : str
        The path with peak files to be used as validation data.
    peak_path_test : str
        The path with peak files to be used as testing data.
    model_filename : str
        The file name of the model weights (.ckpt file).
    config : Dict[str, Any]
        The configuration options.
    """
    # Read the MS/MS spectra to use for training and validation.
    ext = (".mgf", ".h5", ".hdf5")
    logger.debug("Entering model_runner train")

    if len(train_filenames := _get_peak_filenames(peak_path, ext)) == 0:
        logger.error("Could not find training peak files from %s", peak_path)
        raise FileNotFoundError("Could not find training peak files")
    train_is_not_index = any(
        os.path.splitext(fn)[1] in (".mgf", ".mzxml", ".mzml")
        for fn in train_filenames
    )
    '''
    train_is_index = any([
        os.path.splitext(fn)[1] in (".h5", ".hdf5") for fn in train_filenames
    ])
    if train_is_index and len(train_filenames) > 1:
        logger.error("Multiple training HDF5 spectrum indexes specified")
        raise ValueError("Multiple training HDF5 spectrum indexes specified")
    '''
    if (peak_path_val is None
            or len(val_filenames := _get_peak_filenames(peak_path_val, ext))
            == 0):
        logger.error("Could not find validation peak files from %s",
                     peak_path_val)
        raise FileNotFoundError("Could not find validation peak files")
    val_is_not_index = any(
        os.path.splitext(fn)[1] in (".mgf", ".mzxml", ".mzml")
        for fn in val_filenames)
    '''
    val_is_index = any(
        [os.path.splitext(fn)[1] in (".h5", ".hdf5") for fn in val_filenames])
    if val_is_index and len(val_filenames) > 1:
        logger.error("Multiple validation HDF5 spectrum indexes specified")
        raise ValueError("Multiple validation HDF5 spectrum indexes specified")
    '''
    if (peak_path_test is None
            or len(test_filenames := _get_peak_filenames(peak_path_test, ext))
            == 0):
        logger.error("Could not find testing peak files from %s",
                     peak_path_test)
        raise FileNotFoundError("Could not find testing peak files")
    test_is_not_index = any(
        os.path.splitext(fn)[1] in (".mgf", ".mzxml", ".mzml")
        for fn in test_filenames)
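    # NOTE: annotation, not original source -- the *_is_not_index flags above
    # are a purely extension-based heuristic: .mgf/.mzML/.mzXML files are raw
    # peak lists that still need indexing, whereas .h5/.hdf5 files are taken
    # to be prebuilt spectrum indexes and are used as-is.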
    '''
    test_is_index = any(
        [os.path.splitext(fn)[1] in (".h5", ".hdf5") for fn in test_filenames])
    if test_is_index and len(test_filenames) > 1:
        logger.error("Multiple testing HDF5 spectrum indexes specified")
        raise ValueError("Multiple testing HDF5 spectrum indexes specified")
    '''
    class MyDirectory:
        """A stand-in for a TemporaryDirectory that points at a fixed path."""

        def __init__(self, sdir=None):
            self.name = sdir

        def cleanup(self):
            """No-op, so the tmp_dir.cleanup() call at the end of train()
            remains valid."""
            pass

    # NOTE: hard-coded index dump directory; replace with a suitable local
    # path, or switch back to the TemporaryDirectory below.
    tmp_dir = MyDirectory("/mnt/petrelfs/jinzhi/NATdump/")

    # tmp_dir = tempfile.TemporaryDirectory()
    '''
    if train_is_index:
        train_idx_fn, train_filenames = train_filenames[0], None
    else:
        train_idx_fn = os.path.join(tmp_dir.name, f"Train_{uuid.uuid4().hex}.hdf5")
    '''

    if train_is_not_index:
        train_index_path = [os.path.join(tmp_dir.name, f"Train_{uuid.uuid4().hex}")]
    else:
        train_index_path = train_filenames
        train_filenames = None

    if val_is_not_index:
        val_index_path = [os.path.join(tmp_dir.name, f"valid_{uuid.uuid4().hex}")]
    else:
        val_index_path = val_filenames
        val_filenames = None
    if test_is_not_index:
        test_index_path = [os.path.join(tmp_dir.name, f"test_{uuid.uuid4().hex}")]
    else:
        test_index_path = test_filenames
        test_filenames = None

    valid_charge = np.arange(1, config["max_charge"] + 1)
    '''
    train_index = AnnotatedSpectrumIndex(train_idx_fn,
                                         train_filenames,
                                         valid_charge=valid_charge)
    if val_is_index:
        val_idx_fn, val_filenames = val_filenames[0], None
    else:
        val_idx_fn = os.path.join(tmp_dir.name, f"Valid_{uuid.uuid4().hex}.hdf5")
    val_index = AnnotatedSpectrumIndex(val_idx_fn,
                                       val_filenames,
                                       valid_charge=valid_charge)
    if test_is_index:
        test_idx_fn, test_filenames = test_filenames[0], None
    else:
        test_idx_fn = os.path.join(tmp_dir.name, f"Test_{uuid.uuid4().hex}.hdf5")
    test_index = AnnotatedSpectrumIndex(test_idx_fn,
                                        test_filenames,
                                        valid_charge=valid_charge)
    '''
    # Initialize the data loaders.
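    # NOTE: annotation, not original source -- at this point each split is
    # described by two complementary variables: *_filenames lists raw peak
    # files that still have to be indexed (with *_index_path naming where the
    # new index will be written), while a prebuilt index is passed through
    # *_index_path directly and *_filenames is set to None. The DataModule
    # below receives both and decides per split which branch applies.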
    dataloader_params = dict(
        batch_size=config["train_batch_size"],
        n_peaks=config["n_peaks"],
        min_mz=config["min_mz"],
        max_mz=config["max_mz"],
        min_intensity=config["min_intensity"],
        remove_precursor_tol=config["remove_precursor_tol"],
        n_workers=config["n_workers"],
        train_filenames=train_filenames,
        val_filenames=val_filenames,
        test_filenames=test_filenames,
        # The index path is always a list: either a list containing one index
        # path file or a list containing multiple db files.
        train_index_path=train_index_path,
        val_index_path=val_index_path,
        test_index_path=test_index_path,
        annotated=True,
        valid_charge=valid_charge,
        mode="fit"
    )
    dataModule = DeNovoDataModule(**dataloader_params)
    dataModule.prepare_data()
    dataModule.setup()
    train_dataloader = dataModule.train_dataloader()
    # train_loader = DeNovoDataModule(train_index=train_index,
    #                                 **dataloader_params)
    # train_loader.setup()
    # train_dataloader = train_loader.train_dataloader()

    # val_loader = DeNovoDataModule(valid_index=val_index, **dataloader_params)
    # val_loader.setup()

    # test_loader = DeNovoDataModule(valid_index=test_index, **dataloader_params)
    # test_loader.setup()

    # Set warmup_iters & max_iters
    # Author: Sheng Xu
    # Date: 20230202
    config["warmup_iters"] = int(
        len(train_dataloader)
        / (torch.cuda.device_count() * config["accumulate_grad_batches"])
    ) * config["warm_up_epochs"]
    config["max_iters"] = int(
        len(train_dataloader)
        / (torch.cuda.device_count() * config["accumulate_grad_batches"])
    ) * int(config["max_epochs"])

    # Initialize the model.
    ctc_params = dict(model_path=None,  # to change
                      alpha=0, beta=0,
                      cutoff_top_n=100,
                      cutoff_prob=1.0,
                      beam_width=config["n_beams"],
                      num_processes=4,
                      log_probs_input=False)
    model_params = dict(
        dim_model=config["dim_model"],
        n_head=config["n_head"],
        dim_feedforward=config["dim_feedforward"],
        n_layers=config["n_layers"],
        dropout=config["dropout"],
        dim_intensity=config["dim_intensity"],
        custom_encoder=config["custom_encoder"],
        max_length=config["max_length"],
        residues=config["residues"],
        max_charge=config["max_charge"],
        precursor_mass_tol=config["precursor_mass_tol"],
        isotope_error_range=config["isotope_error_range"],
        n_beams=config["n_beams"],
        n_log=config["n_log"],
        tb_summarywriter=config["tb_summarywriter"],
        warmup_iters=config["warmup_iters"],
        max_iters=config["max_iters"],
        lr=config["learning_rate"],
        weight_decay=config["weight_decay"],
        ctc_dic=ctc_params
    )
    if config["train_from_scratch"]:
        model = Spec2Pep(**model_params)
    else:
        logger.info("Training from checkpoint...")
        model_filename = config["load_file_name"]
        if not os.path.isfile(model_filename):
            logger.error(
                "Could not find the model weights at file %s to continue "
                "training",
                model_filename,
            )
            raise FileNotFoundError(
                "Could not find the model weights to continue training")
        model = Spec2Pep.load_from_checkpoint(model_filename, **model_params)
    # Create the Trainer object and (optionally) a checkpoint callback to
    # periodically save the model.
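    # NOTE: annotation, not original source -- a worked example of the
    # warmup_iters / max_iters arithmetic above, with hypothetical numbers:
    # with len(train_dataloader) = 10,000 per-device batches, 4 GPUs, and
    # accumulate_grad_batches = 2, one epoch performs
    # int(10000 / (4 * 2)) = 1250 optimizer steps, so warm_up_epochs = 1
    # gives warmup_iters = 1250 and max_epochs = 30 gives max_iters = 37500.
    # These counts are passed to Spec2Pep as warmup_iters / max_iters
    # alongside the learning rate.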
    if config["save_model"]:
        callbacks = [
            pl.callbacks.ModelCheckpoint(
                dirpath=config["model_save_folder_path"],
                save_top_k=-1,
                save_weights_only=False,
                every_n_train_steps=config["every_n_train_steps"],
            )
        ]
    else:
        callbacks = []

    # The validation index file name is used to label the Neptune run.
    (path, filename) = os.path.split(val_index_path[0])
    import time

    if config["SWA"]:
        callbacks.append(pl.callbacks.StochasticWeightAveraging(swa_lrs=1e-2))

    if config["enable_neptune"]:
        callbacks.append(
            pl.callbacks.LearningRateMonitor(logging_interval='epoch'))
        neptune_logger = pl.loggers.NeptuneLogger(
            project=config["neptune_project"],
            api_token=config["neptune_api_token"],
            log_model_checkpoints=False,
            custom_run_id=filename + str(time.time()),
            name=filename + str(time.time()),
            tags=config["tags"]
        )

        neptune_logger.log_hyperparams({
            "train_batch_size": config["train_batch_size"],
            "n_cards": torch.cuda.device_count(),
            "random_seed": config["random_seed"],
            "train_filename": peak_path,
            "val_filename": peak_path_val,
            "test_filename": peak_path_test,
            "gradient_clip_val": config["gradient_clip_val"],
            "accumulate_grad_batches": config["accumulate_grad_batches"],
            "sync_batchnorm": config["sync_batchnorm"],
            "SWA": config["SWA"],
            "gradient_clip_algorithm": config["gradient_clip_algorithm"]
        })
    logger.info("Number of available devices: %d", torch.cuda.device_count())
    trainer = pl.Trainer(
        # reload_dataloaders_every_n_epochs=1,
        enable_model_summary=True,
        accelerator="auto",
        auto_select_gpus=True,
        callbacks=callbacks,
        devices=_get_devices(),
        num_nodes=config["n_nodes"],
        logger=neptune_logger if config["enable_neptune"] else None,
        max_epochs=config["max_epochs"],
        num_sanity_val_steps=config["num_sanity_val_steps"],
        strategy=_get_strategy(),
        gradient_clip_val=config["gradient_clip_val"],
        gradient_clip_algorithm=config["gradient_clip_algorithm"],
        accumulate_grad_batches=config["accumulate_grad_batches"],
        sync_batchnorm=config["sync_batchnorm"],
    )
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    logger.info("Number of model parameters: %d", pytorch_total_params)
    # Train the model.
    if config["train_from_resume"] and not config["train_from_scratch"]:
        trainer.fit(model, datamodule=dataModule,
                    ckpt_path=config['load_file_name'])
    else:
        trainer.fit(model, datamodule=dataModule)
    # Clean up temporary files.
    tmp_dir.cleanup()


def _get_peak_filenames(
        path: str, supported_ext: Iterable[str] = (".mgf", )) -> List[str]:
    """
    Get all matching peak file names from the path pattern.

    Performs cross-platform path expansion akin to the Unix shell (glob, expand
    user, expand vars).

    Parameters
    ----------
    path : str
        The path pattern.
    supported_ext : Iterable[str]
        Extensions of supported peak file formats. Default: MGF.

    Returns
    -------
    List[str]
        The peak file names matching the path pattern.
    """
    path = os.path.expanduser(path)
    path = os.path.expandvars(path)

    return [
        fn for fn in glob.glob(path, recursive=True)
        # NOTE: the extension filter below is disabled in the original code,
        # so supported_ext is currently unused.
        # if os.path.splitext(fn.lower())[1] in supported_ext
    ]


def _get_strategy() -> Optional[DDPStrategy]:
    """
    Get the strategy for the Trainer.

    The DDP strategy works best when multiple GPUs are used. It can work for
    CPU-only, but definitely fails using MPS (the Apple Silicon chip) due to
    Gloo.

    Returns
    -------
    Optional[DDPStrategy]
        The strategy parameter for the Trainer.
    """
    if torch.cuda.device_count() > 1:
        return DDPStrategy(find_unused_parameters=False, static_graph=True)

    return None


def _get_devices() -> Union[int, str]:
    """
    Get the number of GPUs/CPUs for the Trainer to use.

    Returns
    -------
    Union[int, str]
        The number of GPUs/CPUs to use, or "auto" to let PyTorch Lightning
        determine the appropriate number of devices.
    """
    if any(
            operator.attrgetter(device + ".is_available")(torch)()
            for device in ["cuda", "backends.mps"]):
        return -1
    elif not (n_workers := utils.n_workers()):
        return "auto"
    else:
        return n_workers

--------------------------------------------------------------------------------
/ContraNovo/components/transformers.py:
--------------------------------------------------------------------------------
"""Base Transformer models for working with mass spectra and peptides"""
import re

import torch

from .encoders import MassEncoder, PeakEncoder, PositionalEncoder
from ..masses import PeptideMass
from .. import utils2 as utils
import numpy as np
import itertools


class SpectrumEncoder(torch.nn.Module):
    """A Transformer encoder for input mass spectra.

    Parameters
    ----------
    dim_model : int, optional
        The latent dimensionality to represent peaks in the mass spectrum.
    n_head : int, optional
        The number of attention heads in each layer. ``dim_model`` must be
        divisible by ``n_head``.
    dim_feedforward : int, optional
        The dimensionality of the fully connected layers in the Transformer
        layers of the model.
    n_layers : int, optional
        The number of Transformer layers.
    dropout : float, optional
        The dropout probability for all layers.
    peak_encoder : bool, optional
        Use positional encodings for the m/z values of each peak.
    dim_intensity : int or None, optional
        The number of features to use for encoding peak intensity.
        The remaining (``dim_model - dim_intensity``) are reserved for
        encoding the m/z value.
    """

    def __init__(
        self,
        dim_model=128,
        n_head=8,
        dim_feedforward=1024,
        n_layers=1,
        dropout=0,
        peak_encoder=True,
        dim_intensity=None,
    ):
        """Initialize a SpectrumEncoder"""
        super().__init__()

        # self.latent_spectrum = torch.nn.Parameter(torch.randn(1, 1, dim_model))
        # self.spectrum_matrix = torch.nn.Parameter(torch.randn(dim_model, dim_model))

        # dim_intensity = 128
        self.zeroPeaks_intensity = torch.nn.Parameter(torch.randn(1, 1, 1))
        self.allPeaks_intensity = torch.nn.Parameter(torch.randn(1, 1, 1))

        if peak_encoder:
            self.peak_encoder = PeakEncoder(
                dim_model,
                dim_intensity=dim_intensity,
            )
        else:
            self.peak_encoder = torch.nn.Linear(2, dim_model)

        # The Transformer layers:
        layer = torch.nn.TransformerEncoderLayer(
            d_model=dim_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            dropout=dropout,
        )

        self.transformer_encoder = torch.nn.TransformerEncoder(
            layer,
            num_layers=n_layers,
        )

        # Precursor Encoder
        # self.mass_encoder = MassEncoder(dim_model=256)
        # self.charge_encoder = torch.nn.Embedding(10, dim_model)

    def forward(self, spectra, precursors):
        """The forward pass.

        Parameters
        ----------
        spectra : torch.Tensor of shape (n_spectra, n_peaks, 2)
            The spectra to embed. Axis 0 represents a mass spectrum, axis 1
            contains the peaks in the mass spectrum, and axis 2 is essentially
            a 2-tuple specifying the m/z-intensity pair for each peak. These
            should be zero-padded, such that all of the spectra in the batch
            are the same length.
        precursors : torch.Tensor of shape (n_spectra, 2)
            The measured precursor mass (axis 0) and charge (axis 1) of each
            mass spectrum.

        Returns
        -------
        latent : torch.Tensor of shape (n_spectra, n_peaks + 2, dim_model)
            The latent representations for the two prepended precursor rows
            and each of the spectrum's peaks.
        mem_mask : torch.Tensor
            The memory mask specifying which elements were padding in X.
        """
        # Add precursor-derived rows to the encoder input:
        # masses = self.mass_encoder(precursors[:, None, [0]])
        # charges = self.charge_encoder(precursors[:, 1].int() - 1)
        # precursors = masses + charges[:, None, :]

        zeroMass = torch.zeros([precursors.shape[0], 1, 1]).to(self.device)
        precursorMass = precursors[:, None, [0]]
        zeroPeaksIntensities = self.zeroPeaks_intensity.expand(precursors.shape[0], -1, -1)
        allPeaksIntensities = self.allPeaks_intensity.expand(precursors.shape[0], -1, -1)
        zeros = torch.cat([zeroMass, zeroPeaksIntensities], dim=2)
        alls = torch.cat([precursorMass, allPeaksIntensities], dim=2)
        starts = torch.cat([zeros, alls], dim=1)
        spectra = torch.cat([starts, spectra], dim=1)

        zeros = ~spectra.sum(dim=2).bool()
        mask = zeros
        # mask = [
        #     # add precursors into encoder
        #     # torch.tensor([[False]] * spectra.shape[0]).type_as(zeros),
        #     # torch.tensor([[False]] * spectra.shape[0]).type_as(zeros),
        #     zeros,
        # ]
        # mask = torch.cat(mask, dim=1)
        peaks = self.peak_encoder(spectra)

        # precursors = torch.matmul(precursors, self.spectrum_matrix)
        # peaks = torch.concat([precursors, peaks], dim=1)

        # Add the spectrum representation to each input:
        # latent_spectra = self.latent_spectrum.expand(peaks.shape[0], -1, -1)
        # peaks = torch.cat([latent_spectra, peaks], dim=1)
        return self.transformer_encoder(peaks, src_key_padding_mask=mask), mask

    @property
    def device(self):
        """The current device for the model"""
        return next(self.parameters()).device


class _PeptideTransformer(torch.nn.Module):
    """A transformer base class for peptide sequences.

    Parameters
    ----------
    dim_model : int
        The latent dimensionality to represent the amino acids in a peptide
        sequence.
    pos_encoder : bool
        Use positional encodings for the amino acid sequence.
    residues : Dict or str {"massivekb", "canonical"}, optional
        The amino acid dictionary and their masses. By default this is only
        the 20 canonical amino acids, with cysteine carbamidomethylated. If
        "massivekb", this dictionary will include the modifications found in
        MassIVE-KB. Additionally, a dictionary can be used to specify a custom
        collection of amino acids and masses.
    max_charge : int
        The maximum charge to embed.
    """

    def __init__(
        self,
        dim_model,
        pos_encoder,
        residues,
        max_charge,
    ):
        super().__init__()
        self.reverse = False
        self._peptide_mass = PeptideMass(residues=residues)
        self._amino_acids = list(self._peptide_mass.masses.keys()) + ["$"]
        self._idx2aa = {i + 1: aa for i, aa in enumerate(self._amino_acids)}
        self._aa2idx = {aa: i for i, aa in self._idx2aa.items()}

        if pos_encoder:
            self.pos_encoder = PositionalEncoder(dim_model)
        else:
            self.pos_encoder = torch.nn.Identity()

        self.charge_encoder = torch.nn.Embedding(max_charge, dim_model)
        self.aa_encoder = torch.nn.Embedding(
            len(self._amino_acids) + 1,
            dim_model,
            padding_idx=0,
        )

    def tokenize(self, sequence, partial=False):
        """Transform a peptide sequence into tokens.

        Parameters
        ----------
        sequence : str
            A peptide sequence.
        partial : bool, optional
            If True, do not append the stop token ("$") to the tokenized
            sequence.

        Returns
        -------
        torch.Tensor
            The token for each amino acid in the peptide sequence.
        """
        if not isinstance(sequence, str):
            return sequence  # Assume it is already tokenized.

        sequence = sequence.replace("I", "L")
        sequence = re.split(r"(?<=.)(?=[A-Z])", sequence)

        if self.reverse:
            sequence = list(reversed(sequence))

        if not partial:
            sequence += ["$"]

        tokens = [self._aa2idx[aa] for aa in sequence]
        tokens = torch.tensor(tokens, device=self.device)
        return tokens

    def deMass(self, sequence):
        """Compute the cumulative (prefix) mass at each position of a peptide
        sequence, given either a string or a token tensor."""
        if not isinstance(sequence, str):
            sequence = [self._idx2aa.get(i.item(), "") for i in sequence]
            masses = [self._peptide_mass.masses[aa] for aa in sequence]
            # if len(sequence) > 1:
            #     print(sequence, masses)
            masses = list(itertools.accumulate(masses))
            masses = torch.tensor(masses, device=self.device)

            return masses

        sequence = sequence.replace("I", "L")
        sequence = re.split(r"(?<=.)(?=[A-Z])", sequence)

        if self.reverse:
            sequence = list(reversed(sequence))

        masses = [self._peptide_mass.masses[aa] for aa in sequence]
        masses = list(itertools.accumulate(masses))
        masses.append(0.0)

        masses = torch.tensor(masses, device=self.device)

        return masses

    def get_suffix_mass(self, sequence, premass):
        """Compute the remaining (suffix) mass at each position, i.e. the
        precursor mass minus the cumulative prefix mass."""
        if not isinstance(sequence, str):
            sequence = [self._idx2aa.get(i.item(), "") for i in sequence]
            masses = [self._peptide_mass.masses[aa] for aa in sequence]
            masses = list(itertools.accumulate(masses))
            masses = torch.tensor(masses, device=self.device)
            masses = premass - masses
            # print(sequence, masses)
            return masses

        sequence = sequence.replace("I", "L")
        sequence = re.split(r"(?<=.)(?=[A-Z])", sequence)

        if self.reverse:
            sequence = list(reversed(sequence))

        masses = [self._peptide_mass.masses[aa] for aa in sequence]

        masses = list(itertools.accumulate(masses))
        masses.append(premass)

        masses = torch.tensor(masses, device=self.device)
        masses = premass - masses

        return masses

    def get_mass(self, sequence):
        """Look up the individual residue masses for a peptide sequence."""
        if not isinstance(sequence, str):
            masses = [self._peptide_mass.masses[aa] for aa in sequence]
            masstemp = torch.tensor(masses, device=self.device)
            return masstemp

        sequence = sequence.replace("I", "L")
        sequence = re.split(r"(?<=.)(?=[A-Z])", sequence)

        if self.reverse:
            sequence = list(reversed(sequence))

        masses = [self._peptide_mass.masses[aa] for aa in sequence]
        masstemp = torch.tensor(masses, device=self.device)
        masstemp = torch.cat([masstemp, torch.tensor([0.0]).to(masstemp.device)])

        return masstemp

    def getAminoAcid(self):
        """Return the residue mass for each of the tokens 1-27, prefixed with
        a mass of 0 for the padding token at index 0."""
        AA_masslist = [self._peptide_mass.masses[self._idx2aa[i]] for i in range(1, 28)]
        AA_masslist = [0] + AA_masslist
        AA_masslist = torch.tensor(AA_masslist, device=self.device)
        return AA_masslist

    def detokenize(self, tokens):
        """Transform tokens back into a peptide sequence.

        Parameters
        ----------
        tokens : torch.Tensor of shape (n_amino_acids,)
            The token for each amino acid in the peptide sequence.

        Returns
        -------
        list of str
            The amino acids in the peptide sequence.
        """
        sequence = [self._idx2aa.get(i.item(), "") for i in tokens]
        if "$" in sequence:
            idx = sequence.index("$")
            sequence = sequence[: idx + 1]

        if self.reverse:
            sequence = list(reversed(sequence))

        return sequence

    @property
    def vocab_size(self):
        """Return the number of amino acids"""
        return len(self._aa2idx)

    @property
    def device(self):
        """The current device for the model"""
        return next(self.parameters()).device


class PeptideEncoder(_PeptideTransformer):
    """A transformer encoder for peptide sequences.

    Parameters
    ----------
    dim_model : int
        The latent dimensionality to represent the amino acids in a peptide
        sequence.
    n_head : int, optional
        The number of attention heads in each layer. ``dim_model`` must be
        divisible by ``n_head``.
    dim_feedforward : int, optional
        The dimensionality of the fully connected layers in the Transformer
        layers of the model.
    n_layers : int, optional
        The number of Transformer layers.
    dropout : float, optional
        The dropout probability for all layers.
    pos_encoder : bool, optional
        Use positional encodings for the amino acid sequence.
    residues : Dict or str {"massivekb", "canonical"}, optional
        The amino acid dictionary and their masses. By default this is only
        the 20 canonical amino acids, with cysteine carbamidomethylated. If
        "massivekb", this dictionary will include the modifications found in
        MassIVE-KB. Additionally, a dictionary can be used to specify a custom
        collection of amino acids and masses.
    max_charge : int, optional
        The maximum charge state for peptide sequences.
    """

    def __init__(
        self,
        dim_model=128,
        n_head=8,
        dim_feedforward=1024,
        n_layers=1,
        dropout=0,
        pos_encoder=True,
        residues="canonical",
        max_charge=5,
    ):
        """Initialize a PeptideEncoder"""
        super().__init__(
            dim_model=dim_model,
            pos_encoder=pos_encoder,
            residues=residues,
            max_charge=max_charge,
        )

        # The Transformer layers:
        layer = torch.nn.TransformerEncoderLayer(
            d_model=dim_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            dropout=dropout,
        )

        self.transformer_encoder = torch.nn.TransformerEncoder(
            layer,
            num_layers=n_layers,
        )

    def forward(self, sequences, charges):
        """Encode a collection of peptide sequences.

        Parameters
        ----------
        sequences : list of str or list of torch.Tensor of length batch_size
            The peptide sequences to encode. Optionally, these may be the
            token indices instead of a string.
        charges : torch.Tensor of size (batch_size,)
            The charge state of the peptide.

        Returns
        -------
        latent : torch.Tensor of shape (n_sequences, len_sequence + 1, dim_model)
            The latent representations for the prepended charge state and
            each amino acid in the sequences.
        mem_mask : torch.Tensor
            The memory mask specifying which elements were padding in X.
        """
        sequences = utils.listify(sequences)
        tokens = [self.tokenize(s) for s in sequences]
        tokens = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True)
        encoded = self.aa_encoder(tokens)

        # Encode charges
        charges = self.charge_encoder(charges - 1)[:, None]
        encoded = torch.cat([charges, encoded], dim=1)

        # Create mask
        mask = ~encoded.sum(dim=2).bool()

        # Add positional encodings
        encoded = self.pos_encoder(encoded)

        # Run through the model:
        latent = self.transformer_encoder(encoded, src_key_padding_mask=mask)
        return latent, mask


class PeptideDecoder(_PeptideTransformer):
    """A transformer decoder for peptide sequences.

    Parameters
    ----------
    dim_model : int, optional
        The latent dimensionality to represent peaks in the mass spectrum.
    n_head : int, optional
        The number of attention heads in each layer. ``dim_model`` must be
        divisible by ``n_head``.
    dim_feedforward : int, optional
        The dimensionality of the fully connected layers in the Transformer
        layers of the model.
    n_layers : int, optional
        The number of Transformer layers.
    dropout : float, optional
        The dropout probability for all layers.
    pos_encoder : bool, optional
        Use positional encodings for the amino acid sequence.
    reverse : bool, optional
        Sequence peptides from C-terminus to N-terminus.
    residues : Dict or str {"massivekb", "canonical"}, optional
        The amino acid dictionary and their masses. By default this is only
        the 20 canonical amino acids, with cysteine carbamidomethylated. If
        "massivekb", this dictionary will include the modifications found in
        MassIVE-KB. Additionally, a dictionary can be used to specify a custom
        collection of amino acids and masses.
    """

    def __init__(
        self,
        dim_model=128,
        n_head=8,
        dim_feedforward=1024,
        n_layers=1,
        dropout=0,
        pos_encoder=True,
        reverse=True,
        residues="canonical",
        max_charge=5,
    ):
        """Initialize a PeptideDecoder"""
        super().__init__(
            dim_model=dim_model,
            pos_encoder=pos_encoder,
            residues=residues,
            max_charge=max_charge,
        )
        self.reverse = reverse

        self.aaDim = dim_model - 256
        # Additional model components

        # Mass/charge encoder for precursors (dim 256 = dim_model // 2)
        self.mass_encoder = MassEncoder(256)
        self.charge_encoder = torch.nn.Embedding(max_charge, 256)

        # Mass encoders for the prefix and suffix masses
        self.prefixMassEncoder = MassEncoder(128)
        self.suffixMassEncoder = MassEncoder(128)

        self.aa_encoder = torch.nn.Embedding(
            len(self._amino_acids) + 1,
            self.aaDim,
            padding_idx=0,
        )

        layer = torch.nn.TransformerDecoderLayer(
            d_model=dim_model,
            nhead=n_head,
            dim_feedforward=dim_feedforward,
            batch_first=True,
            dropout=dropout,
        )

        self.transformer_decoder = torch.nn.TransformerDecoder(
            layer,
            num_layers=n_layers,
        )

        # self.startvector = torch.nn.Parameter(torch.randn(1, 1, dim_model))

        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.final_aa_Encoder = torch.nn.Embedding(
            len(self._amino_acids) + 1,
            self.aaDim,
            padding_idx=0,
        )

        finalLinears = []
        xin = dim_model
        for xout in [512, 1024, 1024]:
            finalLinears.append(torch.nn.Linear(xin, xout))
            finalLinears.append(torch.nn.PReLU())
            xin = xout
        finalLinears.append(torch.nn.Linear(xin, 512))

        self.finalLinears = torch.nn.Sequential(*finalLinears)
        # self.final_linear = torch.nn.Linear(dim_model, dim_model)

        self.final_mass_encoder = MassEncoder(256)
        # self.finalMassLinearLayer = torch.nn.Sequential(torch.nn.Linear(1, 256), torch.nn.PReLU(), torch.nn.Linear(256, 256))
        self.finalCharMass = torch.nn.Parameter(torch.randn(1))

        # self.final = torch.nn.Linear(dim_model, len(self._amino_acids) + 1)

    def forward(self, sequences, precursors, memory, memory_key_padding_mask):
        """Predict the next amino acid for a collection of sequences.

        Parameters
        ----------
        sequences : list of str or list of torch.Tensor
            The partial peptide sequences for which to predict the next
            amino acid. Optionally, these may be the token indices instead
            of a string.
        precursors : torch.Tensor of size (batch_size, 2)
            The measured precursor mass (axis 0) and charge (axis 1) of each
            tandem mass spectrum.
        memory : torch.Tensor of shape (batch_size, n_peaks, dim_model)
            The representations from a ``TransformerEncoder``, such as a
            ``SpectrumEncoder``.
        memory_key_padding_mask : torch.Tensor of shape (batch_size, n_peaks)
            The mask that indicates which elements of ``memory`` are padding.

        Returns
        -------
        scores : torch.Tensor of size (batch_size, len_sequence, n_amino_acids)
            The Softmax-transformed output of the final scoring layer, i.e.
            the probability of each amino acid for the prediction.
        tokens : torch.Tensor of size (batch_size, len_sequence)
            The input padded tokens.

        """
        # Prepare sequences
        '''if sequences is not None:
            sequences = utils.listify(sequences)
            tokens = [self.tokenize(s) for s in sequences]
            tokens = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True)
        else:
            tokens = torch.tensor([[]]).to(self.device)'''

        if sequences is not None:
            # print(sequences[0], precursors[0])
            sequences = utils.listify(sequences)
            Masses = [self.deMass(s) for s in sequences]
            Masses = torch.nn.utils.rnn.pad_sequence(Masses, batch_first=True)
            suffixMasses = [self.get_suffix_mass(sequences[i], precursors[i][0])
                            for i in range(len(sequences))]
            suffixMasses = torch.nn.utils.rnn.pad_sequence(suffixMasses,
                                                           batch_first=True)
            tokens = [self.tokenize(s) for s in sequences]
            tokens = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True)
        else:
            Masses = torch.tensor([[]]).to(self.device)
            suffixMasses = torch.tensor([[]]).to(self.device)
            tokens = torch.tensor([[]]).to(self.device)

        # Prepare mass and charge
        masses = self.mass_encoder(precursors[:, None, [0]])
        charges = self.charge_encoder(precursors[:, 1].int() - 1)
        precursors = masses + charges[:, None, :]

        preAndSufPrecursors = torch.tensor([[0]]).to(self.device)
        preAndSufPrecursors = self.prefixMassEncoder(preAndSufPrecursors)
        preAndSufPrecursors = preAndSufPrecursors.repeat(precursors.shape[0], 1)
        preAndSufPrecursors = preAndSufPrecursors.unsqueeze(1)
        precursors = torch.cat([precursors, preAndSufPrecursors, preAndSufPrecursors], dim=2)

        Masses = Masses.unsqueeze(2)
        Masses = self.prefixMassEncoder(Masses)

        suffixMasses = suffixMasses.unsqueeze(2)
        suffixMasses = self.suffixMassEncoder(suffixMasses)

        Masses = torch.concat([Masses, suffixMasses], dim=2)

        tgt = self.aa_encoder(tokens.to(torch.long))
        tgt_key_padding_mask = tgt.sum(axis=2) == 0

        tgtTemp = torch.concat([tgt, Masses], dim=2)

        # startVector = self.startvector.expand(tgtTemp.shape[0], -1, -1)
        # tgtTemp = self.aa_encoder(tokens)

        # Feed through model:
        if sequences is None:
            tgt = precursors
        else:
            tgt = torch.cat([precursors, tgtTemp], dim=1)

        tgt_key_padding_mask = tgt.sum(axis=2) == 0
        tgt = self.pos_encoder(tgt)
        tgt_mask = generate_tgt_mask(tgt.shape[1]).type_as(precursors)
        preds = self.transformer_decoder(
            tgt=tgt,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask.to(self.device),
        )

        aa_masses = self.getAminoAcid()
        # torch.arange(0, 29) replaces the deprecated, end-inclusive
        # torch.range(0, 28); both yield the 29 token indices 0..28.
        aa_idx = torch.arange(0, 29).to(torch.long).to(self.device)
        aa_masses = torch.concat([aa_masses, self.finalCharMass], dim=0)
        aa_masses = aa_masses.unsqueeze(1)

        # Mass encoder for the cosine-style similarity between the amino acid
        # matrix (including the $ final character) and the decoder output.
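        # NOTE: a shape walk-through of the scoring below (annotation, not
        # original source; assumes dim_model = 512, so aaDim = 256):
        #   aa_masses: (28,) from getAminoAcid(), + finalCharMass -> (29, 1)
        #   final_mass_encoder(aa_masses):             (29, 256)
        #   final_aa_Encoder(aa_idx), aa_idx = 0..28:  (29, aaDim) = (29, 256)
        #   concat -> (29, 512) -> finalLinears ->     (29, 512)
        # Each vocabulary entry is thus embedded from its mass plus a learned
        # identity embedding, and the decoder states are scored against this
        # matrix CLIP-style: logit_scale * preds @ final_matrix.t() gives
        # (batch, len_sequence, 29) scores over the token vocabulary.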
        # aa_masses = torch.cat([aa_masses, self.finalCharMass], dim=0)
        aa_masses = self.final_mass_encoder(aa_masses)

        aa_idx = self.final_aa_Encoder(aa_idx)
        final_matrix = torch.concat([aa_masses, aa_idx], dim=-1)
        final_matrix = self.finalLinears(final_matrix)
        preds = self.logit_scale * preds @ final_matrix.t()

        # return preds, tokens
        return torch.softmax(preds, dim=2), tokens

        # return torch.softmax(self.final(preds), dim=2), tokens


def generate_tgt_mask(sz):
    """Generate a square mask for the sequence. The masked positions
    are filled with float('-inf'). Unmasked positions are filled with
    float(0.0).

    This function is a slight modification of the version in the PyTorch
    repository.

    Parameters
    ----------
    sz : int
        The length of the target sequence.
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    return mask
--------------------------------------------------------------------------------
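A quick sanity check of ``generate_tgt_mask`` (an illustrative snippet, not part of the repository): each row i of the returned mask lets position i attend to positions j <= i and blocks later positions with -inf, which is what enforces the left-to-right, autoregressive decoding above.

    import torch

    # Reproduce the mask construction for a length-3 target sequence.
    sz = 3
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = (
        mask.float()
        .masked_fill(mask == 0, float("-inf"))
        .masked_fill(mask == 1, float(0.0))
    )
    print(mask)
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])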