├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── config ├── deconvolute_config.json ├── finetune_config.json └── preprocess_finetune_config.json ├── environment.yml ├── img └── introduction_methylbert.png ├── pyproject.toml ├── requirements.txt ├── src └── methylbert │ ├── __init__.py │ ├── cli.py │ ├── config.py │ ├── data │ ├── __init__.py │ ├── bam.py │ ├── dataset.py │ ├── finetune_data_generate.py │ ├── genome.py │ └── vocab.py │ ├── deconvolute.py │ ├── function.py │ ├── network.py │ ├── trainer.py │ └── utils.py ├── test ├── data │ ├── config.json │ ├── dmrs.csv │ └── processed │ │ ├── dmrs.csv │ │ ├── test_seq.csv │ │ └── train_seq.csv ├── test_deconvolute.py ├── test_finetune.py └── test_finetune_preprocess.py └── tutorials ├── 01_Data_Preparation.md ├── 02_Preprocessing_training_data.ipynb ├── 03_Preprocessing_bulk_data.ipynb ├── 04_Fine-tuning_MethylBERT_model.ipynb └── 05_tumour_deconvolution.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | *__pycache__/ 3 | methylbert/data/__pycache__/ 4 | #*.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | test/data/* 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM mambaorg/micromamba:cuda11.8.0-ubuntu20.04 2 | 3 | RUN micromamba install -n base -y -c conda-forge pip python==3.11 pip freetype-py 4 | 5 | RUN micromamba clean --all --yes 6 | 7 | #if you need to run pip install in the same environment, uncomment the following lines 8 | 9 | ARG MAMBA_DOCKERFILE_ACTIVATE=1 10 | 11 | RUN mkdir src/ 12 | RUN mkdir src/methylbert/ 13 | COPY src/methylbert/ src/methylbert/ 14 | COPY pyproject.toml . 15 | RUN pip install . 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yunhee Jeong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MethylBERT: A Transformer-based model for read-level DNA methylation pattern identification and tumour deconvolution 2 | [![DOI](https://zenodo.org/badge/559284606.svg)](https://doi.org/10.5281/zenodo.14025051) 3 | 4 | ![methylbert_scheme](https://github.com/CompEpigen/methylbert/blob/main/img/introduction_methylbert.png) 5 | _The figure was generated using [biorender](https://www.biorender.com/)_ 6 | 7 | BERT model to classify read-level DNA methylation data into tumour/normal and perform tumour deconvolution. 8 | _MethylBERT_ is implemented using [pytorch](https://pytorch.org/) and [transformers](https://huggingface.co/docs/transformers/index) 🤗. 
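
For a quick orientation before the detailed sections below, the following sketch shows roughly how the Python API fits together; it mirrors what the command-line tool does internally (see `src/methylbert/cli.py`). It is a minimal sketch: file paths, batch sizes and training steps are placeholders, and the full workflow is described in the tutorials.
```
from torch.utils.data import DataLoader
import pandas as pd

from methylbert.data.vocab import MethylVocab
from methylbert.data.dataset import MethylBertFinetuneDataset
from methylbert.trainer import MethylBertFinetuneTrainer
from methylbert.deconvolute import deconvolute

# Tokeniser for 3-mer DNA/methylation sequences
tokenizer = MethylVocab(k=3)

# Read-level data produced by `methylbert preprocess_finetune`
train_dataset = MethylBertFinetuneDataset("data/train_seq.csv", tokenizer, seq_len=150)
test_dataset = MethylBertFinetuneDataset("data/test_seq.csv", tokenizer, seq_len=150)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Fine-tune, starting from a pre-trained 12-layer MethylBERT model
trainer = MethylBertFinetuneTrainer(len(tokenizer), save_path="model/bert.model/",
                                    train_dataloader=train_loader,
                                    test_dataloader=test_loader)
trainer.load("hanyangii/methylbert_hg19_12l")
trainer.train(600)

# Tumour deconvolution of a processed bulk sample
bulk_dataset = MethylBertFinetuneDataset("data/bulk_seq.csv", tokenizer, seq_len=150)
bulk_loader = DataLoader(bulk_dataset, batch_size=64, shuffle=False)
deconvolute(trainer=trainer, data_loader=bulk_loader,
            df_train=pd.read_csv("data/train_seq.csv", sep="\t"),
            tokenizer=tokenizer, output_path="res/",
            n_grid=10000, adjustment=True)
```
The same three steps are also available from the command line, as shown in the Quick start section below.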
9 | 
10 | ## Paper
11 | The _MethylBERT_ paper is now published in [__Nature Communications__](https://www.nature.com/articles/s41467-025-55920-z#article-info)!
12 | 
13 | __MethylBERT enables read-level DNA methylation pattern identification and tumour deconvolution using a Transformer-based model__
14 | 
15 | Yunhee Jeong, Clarissa Gerhäuser, Guido Sauter, Thorsten Schlomm, Karl Rohr and Pavlo Lutsik
16 | 
17 | ## Installation
18 | _MethylBERT_ runs most stably with __Python=3.11__.
19 | 
20 | ### Pip Installation
21 | _MethylBERT_ is available as a [Python package](https://pypi.org/project/methylbert/).
22 | ```
23 | conda create -n methylbert -c conda-forge python=3.11 cudatoolkit==11.8 pip freetype-py
24 | conda activate methylbert
25 | pip install methylbert
26 | ```
27 | 
28 | ### Manual Installation
29 | You can set up your conda environment with the `environment.yml` file by
30 | running `conda env create --file environment.yml` or instead:
31 | ```
32 | conda create -n methylbert -c conda-forge python=3.11 cudatoolkit==11.8 pip freetype-py
33 | conda activate methylbert
34 | git clone https://github.com/hanyangii/methylbert.git
35 | cd methylbert
36 | pip3 install .
37 | ```
38 | 
39 | ## Quick start
40 | ### Python library
41 | If you want to use _MethylBERT_ as a Python library, please follow the [tutorials](https://methylbert.readthedocs.io/en/latest/).
42 | 
43 | ### Command line
44 | MethylBERT also provides a command-line tool. Before using it, please check [the input file requirements](https://github.com/hanyangii/methylbert/blob/main/tutorials/01_Data_Preparation.md).
45 | ```
46 | > methylbert
47 | MethylBERT v2.0.2
48 | One option must be given from ['preprocess_finetune', 'finetune', 'deconvolute']
49 | ```
50 | `-h` or `--help` lists the available arguments for each function (e.g., `methylbert preprocess_finetune --help`).
51 | 
52 | #### 1. Data Preprocessing to fine-tune MethylBERT
53 | **e.g.)** `methylbert preprocess_finetune -f bulk.bam -d dmrs.csv -r genome.fa -p 0.8 -c 50 -o data/`
54 | ```
55 | -s SC_DATASET, --sc_dataset SC_DATASET
56 |         a file listing all single-cell bam files. The first and second columns must contain the file name and cell type when cell types are given; otherwise, each line must contain a single file path.
57 | -f INPUT_FILE, --input_file INPUT_FILE
58 |         .bam file to be processed
59 | -d F_DMR, --f_dmr F_DMR
60 |         .bed or .csv file containing DMR information
61 | -o OUTPUT_PATH, --output_path OUTPUT_PATH
62 |         a directory where all generated results will be saved
63 | -r F_REF, --f_ref F_REF
64 |         .fasta file containing the reference genome
65 | -nm N_MERS, --n_mers N_MERS
66 |         K for K-mer sequences (default: 3)
67 | -m METHYLCALLER, --methylcaller METHYLCALLER
68 |         methylation caller used to produce the data. It must be either bismark or dorado (default: bismark)
69 | -p SPLIT_RATIO, --split_ratio SPLIT_RATIO
70 |         ratio of the train/test split (default: 0.8)
71 | -nd N_DMRS, --n_dmrs N_DMRS
72 |         number of DMRs to take from the DMR file. If no value is given, all DMRs will be used
73 | -c N_CORES, --n_cores N_CORES
74 |         number of cores for multiprocessing (default: 1)
75 | --seed SEED   random seed number (default: 950410)
76 | --ignore_sex_chromo IGNORE_SEX_CHROMO
77 |         whether DMRs on sex chromosomes (chrX and chrY) will be ignored (default: True)
78 | ```
79 | #### 2. MethylBERT fine-tuning
80 | **e.g.)** `methylbert finetune -c data/train_seq.csv -t data/test_seq.csv -o model/ -l 12 -s 150 -b 256 --gradient_accumulation_steps 4 -e 600 -w 8 --log_freq 1 --eval_freq 1 --warm_up 1 --lr 1e-4 --decrease_steps 200`
81 | ```
82 | -c TRAIN_DATASET, --train_dataset TRAIN_DATASET
83 |         training dataset for fine-tuning the BERT model
84 | -t TEST_DATASET, --test_dataset TEST_DATASET
85 |         test dataset for evaluating the fine-tuned model
86 | -o OUTPUT_PATH, --output_path OUTPUT_PATH
87 |         output directory for the fine-tuned model (e.g., output/bert.model)
88 | -p PRETRAIN, --pretrain PRETRAIN
89 |         path to the saved pre-trained model to restore
90 | -l N_ENCODER, --n_encoder N_ENCODER
91 |         number of encoder blocks. One of [12, 8, 6, 4, 2] must be given. A pre-trained MethylBERT model is downloaded accordingly. Ignored when -p (--pretrain) is given.
92 | -nm N_MERS, --n_mers N_MERS
93 |         n-mers (default: 3)
94 | -s SEQ_LEN, --seq_len SEQ_LEN
95 |         maximum sequence length (default: 150)
96 | -b BATCH_SIZE, --batch_size BATCH_SIZE
97 |         batch size (default: 50)
98 | --valid_batch VALID_BATCH
99 |         batch size for the validation set. If not given, the training-set batch size is used
100 | --corpus_lines CORPUS_LINES
101 |         total number of lines in the corpus
102 | --loss LOSS   loss function for fine-tuning. It can be either 'bce' or 'focal_bce' (default: bce)
103 | --max_grad_norm MAX_GRAD_NORM
104 |         max gradient norm (default: 1.0)
105 | --gradient_accumulation_steps GRADIENT_ACCUMULATION_STEPS
106 |         number of update steps to accumulate before performing a backward/update pass (default: 1)
107 | -e STEPS, --steps STEPS
108 |         number of training steps (default: 600)
109 | --save_freq SAVE_FREQ
110 |         interval (in steps) at which interim models are saved
111 | -w NUM_WORKERS, --num_workers NUM_WORKERS
112 |         number of dataloader workers (default: 20)
113 | --with_cuda WITH_CUDA
114 |         training with CUDA: true or false (default: True)
115 | --log_freq LOG_FREQ   frequency (in steps) to print the loss values (default: 100)
116 | --eval_freq EVAL_FREQ
117 |         evaluate the model every n steps (default: 10)
118 | --lr LR   learning rate of AdamW (default: 4e-4)
119 | --adam_weight_decay ADAM_WEIGHT_DECAY
120 |         weight decay of AdamW (default: 0.01)
121 | --adam_beta1 ADAM_BETA1
122 |         AdamW first beta value (default: 0.9)
123 | --adam_beta2 ADAM_BETA2
124 |         AdamW second beta value (default: 0.98)
125 | --warm_up WARM_UP   number of warm-up steps (default: 100)
126 | --decrease_steps DECREASE_STEPS
127 |         step at which to decrease the learning rate (default: 200)
128 | --seed SEED   seed number (default: 950410)
129 | ```
130 | #### 3. MethylBERT tumour deconvolution
131 | **e.g.)** `methylbert deconvolute -i data/data.txt -m model/ -o res/ -b 128 --adjustment`
132 | ```
133 | -i INPUT_DATA, --input_data INPUT_DATA
134 |         bulk data to deconvolute
135 | -m MODEL_DIR, --model_dir MODEL_DIR
136 |         directory of the fine-tuned MethylBERT model
137 | -o OUTPUT_PATH, --output_path OUTPUT_PATH
138 |         directory to save deconvolution results (default: ./)
139 | -b BATCH_SIZE, --batch_size BATCH_SIZE
140 |         batch size. Please decrease the number if you do not have enough memory to run the software (default: 64)
141 | --save_logit   save logits from the model (default: False)
142 | --adjustment   adjust the estimated tumour purity (default: False)
143 | ```
144 | ## Citation
145 | ```
146 | @article{jeong2025methylbert,
147 |   title={MethylBERT enables read-level DNA methylation pattern identification and tumour deconvolution using a Transformer-based model},
148 |   author={Jeong, Yunhee and Gerh{\"a}user, Clarissa and Sauter, Guido and Schlomm, Thorsten and Rohr, Karl and Lutsik, Pavlo},
149 |   journal={Nature Communications},
150 |   volume={16},
151 |   number={1},
152 |   pages={788},
153 |   year={2025},
154 |   publisher={Nature Publishing Group UK London}
155 | }
156 | ```
157 | 
--------------------------------------------------------------------------------
/config/deconvolute_config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "input_data": "data/data.csv",
3 |   "model_dir": "model/",
4 |   "output_path": "output/",
5 |   "batch_size": 64,
6 |   "save_logit": false,
7 |   "adjustment": false
8 | }
9 | 
--------------------------------------------------------------------------------
/config/finetune_config.json:
--------------------------------------------------------------------------------
1 | {"train_dataset": "data/train_seq.csv",
2 |  "test_dataset": "data/test_seq.csv",
3 |  "output_path": "res/",
4 |  "pretrain": null,
5 |  "n_encoder": 2,
6 |  "n_mers": 3,
7 |  "seq_len": 150,
8 |  "batch_size": 10,
9 |  "valid_batch": -1,
10 |  "corpus_lines": null,
11 |  "loss": "focal_bce",
12 |  "max_grad_norm": 1.0,
13 |  "gradient_accumulation_steps": 1,
14 |  "steps": 2,
15 |  "save_freq": null,
16 |  "num_workers": 8,
17 |  "with_cuda": true,
18 |  "log_freq": 1,
19 |  "eval_freq": 1,
20 |  "lr": 0.0001,
21 |  "adam_weight_decay": 0.01,
22 |  "adam_beta1": 0.9,
23 |  "adam_beta2": 0.98,
24 |  "warm_up": 1,
25 |  "decrease_steps": 200,
26 |  "seed": 950410}
27 | 
--------------------------------------------------------------------------------
/config/preprocess_finetune_config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "sc_dataset": null,
3 |   "input_file": "data/bulk.bam",
4 |   "f_dmr": "data/dmrs.csv",
5 |   "output_path": "res/",
6 |   "f_ref": "data/genome.fa",
7 |   "n_mers": 3,
8 |   "methylcaller": "bismark",
9 |   "split_ratio": 0.8,
10 |   "n_dmrs": -1,
11 |   "n_cores": 50,
12 |   "seed": 950410,
13 |   "ignore_sex_chromo": true
14 | }
15 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | # run: conda env create --file environment.yml
2 | name: methylbert
3 | channels:
4 |   - defaults
5 |   - conda-forge
6 | dependencies:
7 |   - python=3.11
8 |   - cudatoolkit=11.8
9 |   - pip
10 |   - pip:
11 |     - -e .
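A note on the `config/` JSON files above: the MethylBERT argument parser (`get_args` in `src/methylbert/cli.py`) also accepts a single `.json` file as the argument following the sub-command and expands each key/value pair into the corresponding CLI flag, so the shipped configuration files can be used directly. Keys set to `null` are skipped and boolean values are turned into store-true flags, which is why the JSON keys must match the argparse option names exactly. A minimal sketch, assuming the data paths inside the JSON files point to existing files:
```
# run fine-tuning and deconvolution entirely from the shipped JSON configurations
methylbert finetune config/finetune_config.json
methylbert deconvolute config/deconvolute_config.json
```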
-------------------------------------------------------------------------------- /img/introduction_methylbert.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompEpigen/methylbert/f82f83b0bc20e2af1441795536ec9f338fbbb3e7/img/introduction_methylbert.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "methylbert" 7 | version = "2.0.2" 8 | description = "A Transformer-based model for read-level DNA methylation pattern identification and tumour deconvolution" 9 | authors = [ 10 | { name = "Yunhee Jeong", email = "y.jeong@dkfz-heidelberg.de" } 11 | ] 12 | readme = "README.md" 13 | license = {file = "LICENSE"} 14 | 15 | dependencies = [ 16 | "biopython==1.84", 17 | "matplotlib==3.9.2", 18 | "numpy==2.1.1", 19 | "pandas==2.2.2", 20 | "pysam==0.22.1", 21 | "scikit-learn==1.5.1", 22 | "scipy==1.14.1", 23 | "torch==2.4.1", 24 | "tqdm==4.66.5", 25 | "transformers==4.44.2", 26 | "tokenizers==0.19.1", 27 | "urllib3==2.2.2", 28 | "zipp==3.13.0" 29 | ] 30 | requires-python = ">=3.11" 31 | 32 | [tool.setuptools.packages.find] 33 | where = ["src"] 34 | include = ["methylbert*"] 35 | exclude = ["methylbert/cli.py"] 36 | 37 | [tool.setuptools.package-data] 38 | "test" = ["*"] 39 | recipe = ["*.py"] 40 | 41 | [project.entry-points.console_scripts] 42 | methylbert = "methylbert.cli:main" 43 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Bio 2 | biopython<=1.81 3 | matplotlib<3.3,>=3.2 4 | numpy<1.21 5 | pandas<1.4.0 6 | pysam 7 | scikit_learn<1.1.0 8 | scipy<1.7.0 9 | torch>=2.4.0 10 | tqdm 11 | transformers>=4.0.0 12 | urllib3<1.27,>=1.25.4 13 | zipp==3.13.0 14 | -------------------------------------------------------------------------------- /src/methylbert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__="2.0.2" 2 | -------------------------------------------------------------------------------- /src/methylbert/cli.py: -------------------------------------------------------------------------------- 1 | import argparse, sys, os, json 2 | 3 | import pickle as pk 4 | import pandas as pd 5 | import numpy as np 6 | 7 | from methylbert.data.finetune_data_generate import finetune_data_generate 8 | from methylbert.trainer import MethylBertFinetuneTrainer 9 | from methylbert.data.vocab import MethylVocab 10 | from methylbert.data.dataset import MethylBertFinetuneDataset 11 | from methylbert.utils import set_seed 12 | from methylbert.deconvolute import deconvolute 13 | from methylbert import __version__ 14 | 15 | import torch 16 | from torch.utils.data import DataLoader 17 | 18 | OPTIONS = [ 19 | "preprocess_finetune", 20 | "finetune", 21 | "deconvolute" 22 | ] 23 | 24 | def deconvolute_arg_parser(subparsers): 25 | parser = subparsers.add_parser('deconvolute', help='Run MethylBERT tumour deconvolution') 26 | 27 | parser.add_argument("-i", "--input_data", required=True, type=str, help="Bulk data to deconvolute") 28 | parser.add_argument("-m", "--model_dir", required=True, type=str, help="Trained methylbert model") 29 | parser.add_argument("-o", "--output_path", type=str, default="./", 
help="Directory to save deconvolution results. (default: ./)") 30 | 31 | # Running parametesr 32 | parser.add_argument("-b", "--batch_size", type=int, default=64, help="Batch size. Please decrease the number if you do not have enough memory to run the software (default: 64)") 33 | parser.add_argument("--save_logit", default=False, action="store_true", help="Save logits from the model (default: False)") 34 | parser.add_argument("--adjustment", default=False, action="store_true", help="Adjust the estimated tumour purity (default: False)") 35 | 36 | 37 | def finetune_arg_parser(subparsers): 38 | parser = subparsers.add_parser('finetune', help='Run MethylBERT fine-tuning') 39 | 40 | # Data and directory paths 41 | parser.add_argument("-c", "--train_dataset", required=True, type=str, help="train dataset for train bert") 42 | parser.add_argument("-t", "--test_dataset", type=str, default=None, help="test set for evaluate train set") 43 | parser.add_argument("-o", "--output_path", required=True, type=str, help="ex)output/bert.model") 44 | 45 | # For a pre-trained model 46 | parser.add_argument("-p", "--pretrain", type=str, default=None, help="path to the saved pretrained model to restore") 47 | parser.add_argument("-l", "--n_encoder", type=int, default=None, help="number of encoder blocks. One of [12, 8, 6, 4, 2] need to be given. A pre-trained MethylBERT model is downloaded accordingly. Ignored when -p (--pretrain) is given.") 48 | parser.add_argument("--without_pretrain", default=False, action="store_true", help="Use MethylBERT without a pre-trained model.") 49 | 50 | # Hyperparams for input data processing 51 | parser.add_argument("-nm", "--n_mers", type=int, default=3, help="n-mers (default: 3)") 52 | parser.add_argument("-s", "--seq_len", type=int, default=150, help="maximum sequence len (default: 150)") 53 | parser.add_argument("-b", "--batch_size", type=int, default=50, help="number of batch_size (default: 50)") 54 | parser.add_argument("--valid_batch", type=int, default=-1, help="number of batch_size in valid set. If it's not given, valid_set batch size is set same as the train_set batch size") 55 | parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus") 56 | 57 | # Hyperparams for training 58 | parser.add_argument("--loss", type=str, default="bce", help="Loss function for fine-tuning. It can be either \'bce\' or \'focal_bce\' (default: bce)") 59 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm (default: 1.0)") 60 | parser.add_argument( 61 | "--gradient_accumulation_steps", 62 | type=int, 63 | default=1, 64 | help="Number of updates steps to accumulate before performing a backward/update pass. 
(default: 1)", 65 | ) 66 | parser.add_argument("-e", "--steps", type=int, default=600, help="number of training steps (default: 600)") 67 | parser.add_argument("--save_freq", type=int, default=None, help="Steps to save the interim model") 68 | parser.add_argument("-w", "--num_workers", type=int, default=20, help="dataloader worker size (default: 20)") 69 | 70 | parser.add_argument("--with_cuda", default=False, action="store_true", help="training with CUDA (GPU)") 71 | parser.add_argument("--log_freq", type=int, default=100, help="Frequency (steps) to print the loss values (default: 100)") 72 | parser.add_argument("--eval_freq", type=int, default=10, help="Evaluate the model every n iter (default: 10)") 73 | parser.add_argument("--lr", type=float, default=4e-4, help="learning rate of adamW (default: 4e-4)") 74 | parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adamW (default: 0.01)") 75 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="adamW first beta value (default: 0.9)") 76 | parser.add_argument("--adam_beta2", type=float, default=0.98, help="adamW second beta value (default: 0.98)") 77 | parser.add_argument("--warm_up", type=int, default=100, help="steps for warm-up (default: 100)") 78 | parser.add_argument("--decrease_steps", type=int, default=200, help="step to decrease the learning rate (default: 200)") 79 | 80 | # Others 81 | parser.add_argument("--seed", type=int, default=950410, help="seed number (default: 950410)") 82 | 83 | def preprocess_finetune_arg_parser(subparsers): 84 | 85 | parser = subparsers.add_parser('preprocess_finetune', help='Preprocess .bam files for finetuning') 86 | 87 | parser.add_argument("-s", "--sc_dataset", required=False, default=None, type=str, help="a file all single-cell bam files are listed up. The first and second columns must indicate file names and cell types if cell types are given. Otherwise, each line must have one file path.") 88 | parser.add_argument("-f", "--input_file", required=False, default=None, type=str, help=".bam file to be processed") 89 | parser.add_argument("-d", "--f_dmr", required=True, type=str, help=".bed or .csv file DMRs information is contained") 90 | parser.add_argument("-o", "--output_path", required=True, type=str, help="a directory where all generated results will be saved") 91 | parser.add_argument("-r", "--f_ref", required=True, type=str, help=".fasta file containing reference genome") 92 | 93 | parser.add_argument("-nm", "--n_mers", type=int, default=3, help="K for K-mer sequences (default: 3)") 94 | parser.add_argument("-m", "--methylcaller", type=str, default="bismark", help="Used methylation caller. It must be either bismark or dorado (default: bismark)") 95 | parser.add_argument("-p", "--split_ratio", type=float, default=0.8, help="the ratio between train and test dataset (default: 0.8)") 96 | parser.add_argument("-nd", "--n_dmrs", type=int, default=-1, help="Number of DMRs to take from the dmr file. 
If the value is not given, all DMRs will be used") 97 | parser.add_argument("-c", "--n_cores", type=int, default=1, help="number of cores for the multiprocessing (default: 1)") 98 | parser.add_argument("--seed", type=int, default=950410, help="random seed number (default: 950410)") 99 | parser.add_argument("--ignore_sex_chromo", default=False, action="store_true", help="Whether DMRs at sex chromosomes (chrX and chrY) will be ignored") 100 | 101 | def run_finetune(args): 102 | 103 | with open(args.output_path+"/train_param.txt", "w") as f_param: 104 | dict_args = vars(args) 105 | for key in dict_args: 106 | f_param.write(key+"\t"+str(dict_args[key])+"\n") 107 | 108 | # Set seed 109 | set_seed(args.seed) 110 | 111 | #print("On memory: ", args.on_memory) 112 | 113 | # Create a tokenizer 114 | print("Create a tokenizer for %d-mers"%(args.n_mers)) 115 | 116 | tokenizer=MethylVocab(k=args.n_mers) 117 | print("Vocab Size: ", len(tokenizer)) 118 | 119 | torch.set_num_threads(40) 120 | print("CPU info:", torch.get_num_threads(), torch.get_num_interop_threads()) 121 | 122 | 123 | # Load data sets 124 | print("Loading Train Dataset:", args.train_dataset) 125 | train_dataset = MethylBertFinetuneDataset(args.train_dataset, tokenizer, 126 | seq_len=args.seq_len) 127 | 128 | print("%d seqs with %d labels "%(len(train_dataset), train_dataset.num_dmrs())) 129 | print("Loading Test Dataset:", args.test_dataset) 130 | 131 | if args.test_dataset is not None: 132 | test_dataset = MethylBertFinetuneDataset(args.test_dataset, tokenizer, 133 | seq_len=args.seq_len) 134 | 135 | # Create a data loader 136 | print("Creating Dataloader") 137 | local_step_batch_size = int(args.batch_size/args.gradient_accumulation_steps) 138 | print("Local step batch size : ", local_step_batch_size) 139 | 140 | train_data_loader = DataLoader(train_dataset, batch_size=local_step_batch_size, num_workers= args.num_workers, pin_memory=False, shuffle=True) 141 | 142 | if args.valid_batch < 0: 143 | args.valid_batch = args.batch_size 144 | 145 | test_data_loader = DataLoader(test_dataset, batch_size=args.valid_batch, num_workers=args.num_workers, pin_memory=True, shuffle=False) if args.test_dataset is not None else None 146 | 147 | # BERT train 148 | print("Creating BERT Trainer") 149 | trainer = MethylBertFinetuneTrainer(len(tokenizer), save_path=args.output_path+"bert.model/", 150 | train_dataloader=train_data_loader, 151 | test_dataloader=test_data_loader, 152 | lr=args.lr, beta=(args.adam_beta1, args.adam_beta2), 153 | weight_decay=args.adam_weight_decay, 154 | with_cuda=args.with_cuda, 155 | log_freq=args.log_freq, 156 | eval_freq=args.eval_freq, 157 | gradient_accumulation_steps=args.gradient_accumulation_steps, 158 | max_grad_norm = args.max_grad_norm, 159 | warmup_step=args.warm_up, 160 | decrease_steps=args.decrease_steps, 161 | save_freq=args.save_freq, 162 | loss=args.loss) 163 | 164 | if not args.without_pretrain: 165 | # Set up pre-trained model 166 | if ( args.pretrain is None ) and ( args.n_encoder is None ): 167 | raise ValueError("Either -p (--pretrain) or -l (--n_encoder) need to be given to find a pre-trained model.") 168 | elif args.pretrain is None: 169 | print(f"Pre-trained MethylBERT model for {args.n_encoder} encoder blocks is selected.") 170 | args.pretrain = f"hanyangii/methylbert_hg19_{args.n_encoder}l" 171 | 172 | # Load pre-trained model 173 | trainer.load(args.pretrain) 174 | else: 175 | # Download the config file for the given n_encoder 176 | if args.n_encoder is None: 177 | raise ValueError("Please give the 
number of BERT encoder blocks to use with -l option") 178 | trainer.create_model(config_file=f"hanyangii/methylbert_hg19_{args.n_encoder}l") 179 | 180 | # Fine-tune 181 | print("Training Start") 182 | trainer.train(args.steps) 183 | 184 | def run_deconvolute(args): 185 | # Reload training parameters 186 | params = dict() 187 | try: 188 | os.path.exists(args.model_dir+"train_param.txt") 189 | except: 190 | FileNotFoundError(f"{args.model_dir}train_param.txt does not exist. Please check if MethylBERT is fine-tuned") 191 | exit() 192 | 193 | with open(args.model_dir+'train_param.txt', "r") as fp: 194 | for li in fp.readlines(): 195 | li = li.strip().split('\t') 196 | params[li[0]] = li[1] 197 | print("Restored parameters: %s"%params) 198 | 199 | # Create a result directory 200 | if not os.path.exists(args.output_path): 201 | os.mkdir(args.output_path) 202 | print("New directory %s is created"%args.output_path) 203 | 204 | # Restore the model 205 | tokenizer=MethylVocab(k=int(params["n_mers"])) 206 | dataset = MethylBertFinetuneDataset(args.input_data, tokenizer, seq_len=int(params["seq_len"])) 207 | data_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=40) 208 | print("Bulk data (%s) is loaded"%args.input_data) 209 | 210 | restore_dir = os.path.join(args.model_dir, "bert.model/") 211 | trainer = MethylBertFinetuneTrainer(len(tokenizer), save_path='./test', 212 | train_dataloader=data_loader, 213 | test_dataloader=data_loader 214 | ) 215 | 216 | df_train = pd.read_csv(params["train_dataset"], sep="\t") 217 | n_dmrs = max(len(df_train["dmr_label"].unique()), 218 | df_train["dmr_label"].max()+1 # same as 'def num_dmrs()' in data/dataset.py 219 | ) 220 | 221 | trainer.load(restore_dir, 222 | load_fine_tune=True, 223 | n_dmrs=int(n_dmrs) # transformer from_pretrained does not accept np.int 224 | ) 225 | print("Trained model (%s) is restored"%restore_dir) 226 | # Calculate margins 227 | 228 | deconvolute(trainer = trainer, 229 | data_loader = data_loader, 230 | df_train = df_train, 231 | tokenizer = tokenizer, 232 | output_path = args.output_path, 233 | n_grid = 10000, 234 | adjustment = args.adjustment) 235 | 236 | 237 | def run_preprocess(args): 238 | finetune_data_generate(f_dmr=args.f_dmr, 239 | output_dir=args.output_path, 240 | f_ref=args.f_ref, 241 | sc_dataset=args.sc_dataset, 242 | input_file=args.input_file, 243 | n_mers=args.n_mers, 244 | split_ratio=args.split_ratio, 245 | n_dmrs=args.n_dmrs, 246 | n_cores=args.n_cores, 247 | seed=args.seed, 248 | ignore_sex_chromo=args.ignore_sex_chromo, 249 | methyl_caller=args.methylcaller 250 | ) 251 | 252 | def write_args2json(args, f_out): 253 | with open(f_out, "w") as fp: 254 | json.dump(vars(args), fp) 255 | 256 | def get_args(func): 257 | ''' 258 | get a set of arguments for each MethylBERT function 259 | ''' 260 | 261 | # init 262 | parser_init = argparse.ArgumentParser("methylbert") 263 | subparsers = parser_init.add_subparsers(help="Options for MethylBERT") 264 | 265 | # get args 266 | if func == "preprocess_finetune": 267 | preprocess_finetune_arg_parser(subparsers) 268 | elif func == "finetune": 269 | finetune_arg_parser(subparsers) 270 | elif func == "deconvolute": 271 | deconvolute_arg_parser(subparsers) 272 | else: 273 | raise ValueError(f"{func} must be one of {OPTIONS}") 274 | 275 | # Configuration file is given 276 | if (len(sys.argv) >= 3) and (".json" in sys.argv[2]): 277 | f_config = sys.argv[2] 278 | with open(f_config, "r") as fp: 279 | config_dict = json.load(fp) 280 | 281 | args = [func] 282 | for k, v in 
config_dict.items(): 283 | if ( type(v) == bool ) and (v): 284 | args.append(f"--{k}") 285 | elif v is not None: 286 | args.append(f"--{k}") 287 | args.append(f"{v}") 288 | 289 | args = parser_init.parse_args(args) 290 | else: 291 | # parse args 292 | args = parser_init.parse_args() 293 | 294 | # output configuration in a .json file 295 | if not os.path.exists(args.output_path): 296 | os.mkdir(args.output_path) 297 | write_args2json(args, os.path.join(args.output_path, f"{func}_config.json")) 298 | 299 | return args 300 | 301 | def main(args=None): 302 | print(f"MethylBERT v{__version__}") 303 | 304 | if len(sys.argv) == 1: 305 | print(f"One option must be given from {OPTIONS}") 306 | exit() 307 | 308 | selected_option = sys.argv[1] 309 | args = get_args(selected_option) 310 | 311 | # Run the function 312 | if selected_option == "preprocess_finetune": 313 | run_preprocess(args) 314 | elif selected_option == "finetune": 315 | run_finetune(args) 316 | elif selected_option == "deconvolute": 317 | run_deconvolute(args) -------------------------------------------------------------------------------- /src/methylbert/config.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from transformers import BertConfig 4 | 5 | METHYLBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 6 | "hanyangii/methylbert_hg19_12l": "https://huggingface.co/hanyangii/methylbert_hg19_12l/raw/main/config.json", 7 | "hanyangii/methylbert_hg19_8l": "https://huggingface.co/hanyangii/methylbert_hg19_8l/raw/main/config.json", 8 | "hanyangii/methylbert_hg19_6l": "https://huggingface.co/hanyangii/methylbert_hg19_6l/raw/main/config.json", 9 | "hanyangii/methylbert_hg19_4l": "https://huggingface.co/hanyangii/methylbert_hg19_4l/raw/main/config.json", 10 | "hanyangii/methylbert_hg19_2l": "https://huggingface.co/hanyangii/methylbert_hg19_2l/raw/main/config.json" 11 | } 12 | 13 | class MethylBERTConfig(BertConfig): 14 | pretrained_config_archive_map = METHYLBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 15 | loss="bce" 16 | num_labels=-1 17 | 18 | class Config(object): 19 | def __init__(self, config_dict: dict): 20 | for k, v in config_dict.items(): 21 | setattr(self, k, v) 22 | 23 | def get_config(**kwargs): 24 | ''' 25 | Create a Config object for configuration from input 26 | ''' 27 | config = OrderedDict( 28 | [ 29 | ('lr', 1e-4), 30 | ('beta', (0.9, 0.999)), 31 | ('weight_decay', 0.01), 32 | ('warmup_step', 10000), 33 | ('eps', 1e-6), 34 | ('with_cuda', True), 35 | ('log_freq', 10), 36 | ('eval_freq', 1), 37 | ('n_hidden', None), 38 | ("decrease_steps", 200), 39 | ('eval', False), 40 | ('amp', False), 41 | ("gradient_accumulation_steps", 1), 42 | ("max_grad_norm", 1.0), 43 | ("eval", False), 44 | ("save_freq", None), 45 | ("loss", "bce") 46 | ] 47 | ) 48 | 49 | if kwargs is not None: 50 | for key in config.keys(): 51 | if key in kwargs.keys(): 52 | config[key] = kwargs.pop(key) 53 | 54 | return Config(config) -------------------------------------------------------------------------------- /src/methylbert/data/__init__.py: -------------------------------------------------------------------------------- 1 | __version__="2.0.2" -------------------------------------------------------------------------------- /src/methylbert/data/bam.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | 4 | def parse_cigar(cigar: str): 5 | num = 0 6 | cigar_char = list() 7 | cigar_num = list() 8 | cigar = list(cigar) 9 | 
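    # Walk the CIGAR string character by character: digits accumulate into a
    # run length, and each operation letter closes the current run, e.g.
    # "10M2D5M" -> (["M", "D", "M"], [10, 2, 5]).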
10 | for c in cigar: 11 | if c.isdigit() : 12 | num = num*10 + int(c) 13 | else: 14 | cigar_char.append(c) 15 | cigar_num.append(num) 16 | num = 0 17 | return cigar_char, cigar_num 18 | 19 | def handling_cigar(methyl_seq: str, cigarstring: str): 20 | # Handle cigar strings 21 | cigar_list, num_list = parse_cigar(cigarstring) 22 | start_idx = 0 23 | new_seq, new_methyl_seq = "", "" 24 | 25 | for c, n in zip(cigar_list, num_list): 26 | if c not in ["D", "S", "I"]: 27 | new_methyl_seq += methyl_seq[start_idx:start_idx+n] 28 | elif c in ["D", "N"]: 29 | new_methyl_seq += "".join(["D" for nn in range(n)]) 30 | continue 31 | start_idx += n 32 | 33 | return new_methyl_seq 34 | 35 | def process_bismark_read(ref_seq, read): 36 | cigarstring = read.cigarstring 37 | 38 | if read.is_duplicate: 39 | return None 40 | 41 | try: 42 | # XM tag stores cytosine methyl patterns in bismark 43 | xm_tag = read.get_tag("XM") 44 | except KeyError: 45 | # no methylation call for read (check it with read.reference_id) 46 | return None 47 | 48 | # No CpG methylation on the read 49 | if xm_tag.count("z") + xm_tag.count("Z") == 0: 50 | return None 51 | 52 | # Handle the insertion and deletion 53 | xm_tag = handling_cigar(xm_tag, cigarstring) 54 | 55 | # Extract all CpGs 56 | methylatable_sites = [idx for idx, r in enumerate(ref_seq) if ref_seq[idx:idx+2] == "CG"] 57 | 58 | # if there's no CpGs on the read 59 | if len(methylatable_sites) == 0: 60 | return None 61 | 62 | # Paired-end or single-end 63 | is_single_end = not bool(read.flag % 2) 64 | 65 | for idx in methylatable_sites: 66 | methyl_state = None 67 | methyl_idx = -1 68 | 69 | # Taking the complement cytosine's methylation for the reversed read 70 | if idx >= len(xm_tag): 71 | methyl_state = "." 72 | methyl_idx=idx 73 | elif (not read.is_reverse and is_single_end) or (read.is_reverse != read.is_read1 and not is_single_end): 74 | methyl_state = xm_tag[idx] 75 | methyl_idx = idx 76 | elif idx+1 < len(xm_tag): 77 | methyl_state = xm_tag[idx+1] 78 | methyl_idx = idx+1 79 | else: 80 | methyl_state = "." 
81 | methyl_idx = idx+1 82 | 83 | if methyl_state is not None: 84 | if methyl_state in (".", "D"): # Missing or occured by deletion 85 | methyl_state = "C" 86 | 87 | elif (methyl_state in ["x", "h", "X", "H"]): # non-CpG methyl 88 | if (xm_tag[idx] in ["D"]) or (xm_tag[idx+1] in ["D"]): 89 | methyl_state="C" 90 | else: 91 | raise ValueError("Error in the conversion: %d %s %s %s %s\nrefe_seq %s\nread_seq %s\nxmtag_seq %s\n%s\n%s %s %d"%( 92 | idx, 93 | xm_tag[idx], 94 | methyl_state, 95 | "Reverse" if read.is_reverse else "Forward", 96 | "Single" if is_single_end else "Paired", 97 | ref_seq, 98 | read.query_alignment_sequence, 99 | xm_tag, 100 | cigarstring, 101 | read.query_name, 102 | read.reference_name, 103 | read.pos)) 104 | 105 | ref_seq = ref_seq[:idx] + methyl_state + ref_seq[idx+1:] 106 | 107 | # Remove inserted and soft clip bases 108 | #ref_seq = ref_seq.replace("I", "") 109 | #ref_seq = ref_seq.replace("S", "") 110 | 111 | return ref_seq 112 | 113 | 114 | def compare_modifications(mod_base1, mod_base2): 115 | ''' 116 | Evaluate whether two modifications were measured at the same base 117 | ''' 118 | if len(mod_base1) != len(mod_base2): 119 | return False 120 | for b1, b2 in zip(mod_base1, mod_base2): 121 | if b1[0] != b2[0]: 122 | return False 123 | return True 124 | 125 | def process_dorado_read(ref_seq, read): 126 | ''' 127 | Process reference sequences with methylation patterns from Dorado 128 | ''' 129 | 130 | # read sequence bases including soft clipped bases + get forward if the read is aligned with reverse strand 131 | read_seq = read.query_sequence 132 | 133 | if (read_seq is None ) or ("H" in read.cigarstring): 134 | # Secondary alignment and hard-clipped reads are ignored 135 | return None 136 | 137 | if ref_seq.count("CG") == 0: 138 | return None 139 | 140 | # Get modified bases indicationg methylation 141 | methyl_seq = [2 if ref_seq[i:i+2] != "CG" else 0 for i in range(len(read.get_forward_sequence())-1)] 142 | modified_bases = read.modified_bases.copy() # the instance of read cannot be modified 143 | ch_key, cm_key = ('C', int(read.is_reverse), 'h'), ('C', int(read.is_reverse), 'm') 144 | 145 | if (ch_key not in modified_bases.keys()) and (cm_key not in modified_bases.keys()): 146 | # no cytosine modification 147 | return None 148 | elif cm_key not in modified_bases.keys(): 149 | # 5-Methylcytosine missing, add Cm keys with likelihood 0 150 | modified_bases.update({cm_key: list()}) 151 | for base_mod in modified_bases[ch_key]: 152 | modified_bases[cm_key].append((base_mod[0], 0)) 153 | elif ch_key not in modified_bases.keys(): 154 | # 5-Hydroxymethylcytosine missing, add Ch keys with likelihood 0 155 | modified_bases.update({ch_key: list()}) 156 | for base_mod in modified_bases[cm_key]: 157 | modified_bases[ch_key].append((base_mod[0], 0)) 158 | elif not compare_modifications(modified_bases['C', int(read.is_reverse), 'h'], 159 | modified_bases['C', int(read.is_reverse), 'm']): 160 | raise ValueError(f"Modifications are not aligned: {modified_bases['C', int(read.is_reverse), 'h']}, {modified_bases['C', int(read.is_reverse), 'm']}") 161 | 162 | for ch, cm in zip(modified_bases[ch_key], 163 | modified_bases[cm_key]): 164 | sum_prob = ch[1]+cm[1] 165 | methyl_pattern = 1 if sum_prob >= 178 else 0 # consider methylated when the likelihood is > .5 166 | cg_idx = ch[0] - 1 if read.is_reverse else ch[0] 167 | methyl_seq[cg_idx] = methyl_pattern 168 | 169 | methyl_seq = "".join(list(map(str, methyl_seq))) 170 | # Handle cigar strings 171 | methyl_seq = 
handling_cigar(methyl_seq, read.cigarstring) 172 | 173 | #if len(ref_seq) != len(methyl_seq): 174 | # raise ValueError(f"DNA seq and methylation seq have different lengths - {len(ref_seq)}, {len(methyl_seq)}") 175 | 176 | # Match reference seq and methyl pattern 177 | for idx in range(len(ref_seq)): 178 | if idx >= len(methyl_seq): 179 | # methyl seq shorter than ref seq 180 | break 181 | if methyl_seq[idx] in ["2", "D"]: 182 | continue 183 | 184 | if ref_seq[idx:idx+2] != "CG": 185 | # Occured because of variant 186 | continue 187 | 188 | ref_seq = ref_seq[:idx] + ("z" if methyl_seq[idx] == "0" else "Z") + ref_seq[idx+1:] 189 | 190 | return ref_seq 191 | -------------------------------------------------------------------------------- /src/methylbert/data/dataset.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import multiprocessing as mp 3 | import random 4 | from copy import deepcopy 5 | from functools import partial 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | from torch.utils.data import Dataset 11 | 12 | from methylbert.data.vocab import MethylVocab 13 | 14 | 15 | def _line2tokens_pretrain(l, tokenizer, max_len=120): 16 | ''' 17 | convert a text line into a list of tokens converted by tokenizer 18 | 19 | ''' 20 | 21 | l = l.strip().split(" ") 22 | 23 | tokened = [tokenizer.to_seq(b) for b in l] 24 | if len(tokened) > max_len: 25 | return tokened[:max_len] 26 | else: 27 | return tokened + [[tokenizer.pad_index] for k in range(max_len-len(tokened))] 28 | 29 | def _parse_line(l, headers): 30 | # Check the header 31 | if not all([h in headers for h in ["dna_seq", "methyl_seq", "ctype", "dmr_ctype", "dmr_label"]]): 32 | raise ValueError("The header must contain dna_seq, methyl_seq, ctype, dmr_ctype, dmr_label") 33 | 34 | # Separate n-mers tokens and labels from each line 35 | l = l.split("\t") # don't add strip; some columns may be None 36 | if len(headers) == len(l): 37 | l = {k: v for k, v in zip(headers, l)} 38 | else: 39 | raise ValueError(f"Only {len(headers)} elements are in the input file header, whereas the line has {len(l)} elements.") 40 | 41 | # Cell-type label is binary (whether the cell type corresponds to the DMR cell type) 42 | l["ctype_label"] = int(l["ctype"] == l["dmr_ctype"]) 43 | l["dmr_label"] = int(l["dmr_label"]) 44 | 45 | return l 46 | 47 | 48 | def _line2tokens_finetune(l, tokenizer, max_len=150, headers=None): 49 | # parsed line! 
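    # `l` is the dict produced by _parse_line, not a raw text line. The k-mers
    # are mapped to vocabulary indices, and both the DNA tokens and the
    # methylation states are truncated or padded to max_len (pad_index for the
    # DNA sequence, state 2 for the methylation sequence).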
50 | 51 | l["dna_seq"] = l["dna_seq"].split(" ") 52 | l["dna_seq"] = [[f] for f in tokenizer.to_seq(l["dna_seq"])] 53 | l["methyl_seq"] = [int(m) for m in l["methyl_seq"]] 54 | 55 | if len(l["dna_seq"]) > max_len: 56 | l["dna_seq"] = l["dna_seq"][:max_len] 57 | l["methyl_seq"] = l["methyl_seq"][:max_len] 58 | else: 59 | cur_seq_len=len(l["dna_seq"]) 60 | l["dna_seq"] = l["dna_seq"]+[[tokenizer.pad_index] for k in range(max_len-cur_seq_len)] 61 | l["methyl_seq"] = l["methyl_seq"] + [2 for k in range(max_len-cur_seq_len)] 62 | 63 | return l 64 | 65 | class MethylBertDataset(Dataset): 66 | def __init__(self): 67 | pass 68 | 69 | def __len__(self): 70 | return self.lines.shape[0] if type(self.lines) == np.array else len(self.lines) 71 | 72 | 73 | class MethylBertPretrainDataset(MethylBertDataset): 74 | def __init__(self, f_path: str, vocab: MethylVocab, seq_len: int, random_len=False, n_cores=50): 75 | 76 | self.vocab = vocab 77 | self.seq_len = seq_len 78 | self.f_path = f_path 79 | self.random_len = random_len 80 | 81 | # Define a range of tokens to mask based on k-mers 82 | self.mask_list = self._get_mask() 83 | 84 | # Read all text files and convert the raw sequence into tokens 85 | with open(self.f_path, "r") as f_input: 86 | print("Open data : %s"%f_input) 87 | raw_seqs = f_input.read().splitlines() 88 | 89 | print("Total number of sequences : ", len(raw_seqs)) 90 | 91 | # Multiprocessing for the sequence tokenisation 92 | with mp.Pool(n_cores) as pool: 93 | line_labels = pool.map(partial(_line2tokens_pretrain, 94 | tokenizer=self.vocab, 95 | max_len=self.seq_len), raw_seqs) 96 | del raw_seqs 97 | print("Lines are processed") 98 | self.lines = torch.squeeze(torch.tensor(np.array(line_labels, dtype=np.int16))) 99 | del line_labels 100 | gc.collect() 101 | 102 | def __getitem__(self, index): 103 | 104 | dna_seq = self.lines[index].clone() 105 | 106 | # Random len 107 | if self.random_len and np.random.random() < 0.5: 108 | dna_seq = dna_seq[:random.randint(5, self.seq_len)] 109 | 110 | # Padding 111 | if dna_seq.shape[0] < self.seq_len: 112 | pad_num = self.seq_len-dna_seq.shape[0] 113 | dna_seq = torch.cat((dna_seq, 114 | torch.tensor([self.vocab.pad_index for i in range(pad_num)], dtype=torch.int16))) 115 | 116 | # Mask 117 | masked_dna_seq, dna_seq, bert_mask = self._masking(dna_seq) 118 | #print(dna_seq, masked_dna_seq,"\n=============================================\n") 119 | return {"bert_input": masked_dna_seq, 120 | "bert_label": dna_seq, 121 | "bert_mask" : bert_mask} 122 | 123 | def subset_data(self, n_seq: int): 124 | self.lines = random.sample(self.lines, n_seq) 125 | 126 | def _get_mask(self): 127 | ''' 128 | Relative positions from the centre of masked region 129 | e.g) [-1, 0, 1] for 3-mers 130 | ''' 131 | half_length = int(self.vocab.kmers/2) 132 | mask_list = [-1*half_length + i for i in range(half_length)] + [i for i in range(1, half_length+1)] 133 | if self.vocab.kmers % 2 == 0: 134 | mask_list = mask_list[:-1] 135 | 136 | return mask_list 137 | 138 | def _masking(self, inputs: torch.Tensor, threshold=0.15): 139 | """ 140 | Moidfied version of masking token function 141 | Originally developed by Huggingface (datacollator) and DNABERT 142 | 143 | https://github.com/huggingface/transformers/blob/9a24b97b7f304fa1ceaaeba031241293921b69d3/src/transformers/data/data_collator.py#L747 144 | 145 | https://github.com/jerryji1993/DNABERT/blob/bed72fc0694a7b04f7e980dc9ce986e2bb785090/examples/run_pretrain.py#L251 146 | 147 | Added additional tasks to handle each sequence 148 | Lines 
using tokenizer were modified due to different tokenizer object structure 149 | 150 | """ 151 | 152 | labels = inputs.clone() 153 | 154 | # Sample tokens with given probability threshold 155 | probability_matrix = torch.full(labels.shape, threshold) # tensor filled with 0.15 156 | 157 | # Handle special tokens and padding 158 | special_tokens_mask = [ 159 | val < 5 for val in labels.tolist() 160 | ] 161 | probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) 162 | #padding_mask = labels.eq(self.vocab.pad_index) 163 | #probability_matrix.masked_fill_(padding_mask, value=0.0) 164 | 165 | masked_indices = torch.bernoulli(probability_matrix).bool() # get masked tokens based on bernoulli only within non-special tokens 166 | 167 | # change masked indices 168 | masked_index = deepcopy(masked_indices) 169 | 170 | # This function handles each sequence 171 | end = torch.where(probability_matrix!=0)[0].tolist()[-1] # end of the sequence 172 | mask_centers = set(torch.where(masked_index==1)[0].tolist()) # mask locations 173 | 174 | new_centers = deepcopy(mask_centers) 175 | for center in mask_centers: 176 | for mask_number in self.mask_list:# add neighbour loci 177 | current_index = center + mask_number 178 | if current_index <= end and current_index >= 0: 179 | new_centers.add(current_index) 180 | 181 | new_centers = list(new_centers) 182 | 183 | masked_indices[new_centers] = True 184 | 185 | # Avoid loss calculation on unmasked tokens 186 | labels[~masked_indices] = -100 187 | 188 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 189 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices 190 | inputs[indices_replaced] = self.vocab.mask_index 191 | 192 | # 10% of the time, we replace masked input tokens with random word 193 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced 194 | random_words = torch.randint(len(self.vocab), labels.shape, dtype=torch.int16) 195 | inputs[indices_random] = random_words[indices_random] 196 | 197 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 198 | 199 | # Special tokens (SOS, EOS) 200 | if end < inputs.shape[0]: 201 | inputs[end] = self.vocab.eos_index 202 | else: 203 | inputs[-1] = self.vocab.eos_index 204 | 205 | labels = torch.cat((torch.tensor([-100]), labels)) 206 | inputs = torch.cat((torch.tensor([self.vocab.sos_index]), inputs)) 207 | masked_index = torch.cat((torch.tensor([False]), masked_index)) 208 | 209 | 210 | return inputs, labels, masked_index 211 | 212 | class MethylBertFinetuneDataset(MethylBertDataset): 213 | def __init__(self, f_path: str, vocab: MethylVocab, seq_len: int, n_cores: int=10, n_seqs = None): 214 | ''' 215 | MethylBERT dataset 216 | 217 | f_path: str 218 | File path to the processed input file 219 | vocab: MethylVocab 220 | MethylVocab object to convert DNA and methylation pattern sequences 221 | seq_len: int 222 | Length for the processed sequences 223 | n_cores: int 224 | Number of cores for multiprocessing 225 | n_seqs: int 226 | Number of sequences to subset the input (default: None, do not make a subset) 227 | 228 | ''' 229 | self.vocab = vocab 230 | self.seq_len = seq_len 231 | self.f_path = f_path 232 | 233 | # Read all text files and convert the raw sequence into tokens 234 | with open(self.f_path, "r") as f_input: 235 | raw_seqs = f_input.read().splitlines() 236 | 237 | # Check if there's a header 238 | self.headers = 
raw_seqs[0].split("\t") 239 | raw_seqs = raw_seqs[1:] 240 | 241 | if n_seqs is not None: 242 | raw_seqs = raw_seqs[:n_seqs] 243 | print("Total number of sequences : ", len(raw_seqs)) 244 | 245 | # Multiprocessing for the sequence tokenisation 246 | with mp.Pool(n_cores) as pool: 247 | self.lines = pool.map(partial(_parse_line, 248 | headers=self.headers), raw_seqs) 249 | del raw_seqs 250 | gc.collect() 251 | self.set_dmr_labels = set([l["dmr_label"] for l in self.lines]) 252 | 253 | self.ctype_label_count = self._get_cls_num() 254 | print("# of reads in each label: ", self.ctype_label_count) 255 | 256 | def _get_cls_num(self): 257 | # unique labels 258 | ctype_labels=[l["ctype_label"] for l in self.lines] 259 | labels = list(set(ctype_labels)) 260 | label_count = np.zeros(len(labels)) 261 | for l in labels: 262 | label_count[l] = sum(np.array(ctype_labels) == l) 263 | return label_count 264 | 265 | def num_dmrs(self): 266 | return max(len(self.set_dmr_labels), max(self.set_dmr_labels)+1) # +1 is for the label 0 267 | 268 | def subset_data(self, n_seq): 269 | self.lines = self.lines[:n_seq] 270 | 271 | def __getitem__(self, index): 272 | line = deepcopy(self.lines[index]) 273 | 274 | item = _line2tokens_finetune( 275 | l=line, 276 | tokenizer=self.vocab, max_len=self.seq_len, headers=self.headers) 277 | 278 | item["dna_seq"] = torch.squeeze(torch.tensor(np.array(item["dna_seq"], dtype=np.int32))) 279 | item["methyl_seq"] = torch.squeeze(torch.tensor(np.array(item["methyl_seq"], dtype=np.int8))) 280 | 281 | # Special tokens (SOS, EOS) 282 | end = torch.where(item["dna_seq"]!=self.vocab.pad_index)[0].tolist()[-1] + 1 # end of the read 283 | if end < item["dna_seq"].shape[0]: 284 | item["dna_seq"][end] = self.vocab.eos_index 285 | item["methyl_seq"][end] = 2 286 | else: 287 | item["dna_seq"][-1] = self.vocab.eos_index 288 | item["methyl_seq"][-1] = 2 289 | item["dna_seq"] = torch.cat((torch.tensor([self.vocab.sos_index]), item["dna_seq"])) 290 | item["methyl_seq"] = torch.cat((torch.tensor([2]), item["methyl_seq"])) 291 | 292 | return item 293 | 294 | -------------------------------------------------------------------------------- /src/methylbert/data/finetune_data_generate.py: -------------------------------------------------------------------------------- 1 | 2 | import multiprocessing as mp 3 | import os 4 | import random 5 | import re 6 | import sys 7 | import uuid 8 | import warnings 9 | from functools import partial 10 | from typing import List, Optional 11 | 12 | import numpy as np 13 | import pandas as pd 14 | import pysam 15 | from Bio import SeqIO 16 | from sklearn.model_selection import train_test_split 17 | from tqdm.auto import tqdm 18 | 19 | from .bam import process_bismark_read, process_dorado_read 20 | 21 | 22 | # https://gist.github.com/EdwinChan/3c13d3a746bb3ec5082f 23 | def globalize(func): 24 | def result(*args, **kwargs): 25 | return func(*args, **kwargs) 26 | result.__name__ = result.__qualname__ = uuid.uuid4().hex 27 | setattr(sys.modules[result.__module__], result.__name__, result) 28 | return result 29 | 30 | def kmers(seq: str, k: int=3): 31 | ''' 32 | Covert sequences including both reference DNA and methylation patterns into 3-mer DNA and methylation seqs 33 | ''' 34 | converted_seq = list() 35 | methyl = list() 36 | 37 | if k % 2 == 0: 38 | raise ValueError(f"k must be an odd number because of CpG methylation being at the middle of the token. 
The given k is {k}") 39 | 40 | mid = int(k/2) 41 | for seq_idx in range(len(seq)-k): 42 | token = seq[seq_idx:seq_idx+k] 43 | if token[mid] =='z': 44 | m = 0 45 | elif token[mid] =='Z': 46 | m = 1 47 | else: 48 | m = 2 49 | 50 | # six alphabets indicating cytosine methylation in bismark processed files 51 | token = re.sub("[h|H|z|Z|x|X]", "C", token) 52 | converted_seq.append(token) 53 | methyl.append(str(m)) 54 | 55 | return converted_seq, methyl 56 | 57 | 58 | def read_extract(bam_file_path: str, dict_ref: dict, k: int, dmrs: pd.DataFrame, ncores: int=1, methyl_caller: str = "bismark"): 59 | ''' 60 | Extract reads including methylation patterns overlapping with DMRs 61 | and convert those into 3-mer sequences 62 | 63 | bam_file_path: (str) 64 | single-cell bam file path 65 | dict_ref: (dict) 66 | Reference genome given in a dictionary whose key is chromosome and value is DNA sequences 67 | k: (int) 68 | k value for k-mers 69 | dmrs: (pd.Dataframe) 70 | dataframe where DMR information is stored (chromo, start, end are required) 71 | ncores: (int) 72 | Number of cores to be used for parallel processing, default: 1 73 | ''' 74 | 75 | def _reads_overlapping(aln, chromo, start, end, methyl_caller="bismark"): 76 | ''' 77 | seq_list = list() 78 | dna_seq = list() 79 | xm_tags = list() 80 | ''' 81 | 82 | if methyl_caller not in ["bismark", "dorado"]: 83 | raise ValueError(f"Methylation caller must be one of [bismark, dorado]") 84 | 85 | # get all reads overlapping with chromo:start-end 86 | fetched_reads = aln.fetch(chromo, start, end, until_eof=True) 87 | processed_reads = list() 88 | 89 | for read in fetched_reads: 90 | # Only select fully overlapping reads 91 | if (read.pos < start) or ((read.pos+read.query_alignment_length) > end): 92 | continue 93 | 94 | # Remove case-specific mode occured by the quality 95 | ref_seq = dict_ref[chromo][read.pos:(read.pos+read.query_alignment_length)].upper() 96 | 97 | if methyl_caller == "bismark": 98 | ref_seq = process_bismark_read(ref_seq, read) 99 | elif methyl_caller == "dorado": 100 | ref_seq = process_dorado_read(ref_seq, read) 101 | 102 | if ref_seq is None: 103 | continue 104 | 105 | # When there is no CpG methylation patterns after processing 106 | if "z" not in ref_seq.lower(): 107 | continue 108 | 109 | # K-mers 110 | s, m = kmers(ref_seq, k=3) 111 | 112 | if len(s) != len(m): 113 | raise ValueError("DNA and methylation sequences have different length (%d and %d)"%(len(s), len(m))) 114 | 115 | # Add processed results as a tag 116 | read.setTag("RF", value=" ".join(s), replace=True) # reference sequence 117 | read.setTag("ME", value="".join(m), replace=True) # methylation pattern sequence 118 | 119 | # Process back to a dictionary 120 | read_tags = {t:v for t, v in read.get_tags()} 121 | read = read.to_dict() 122 | read.update(read_tags) 123 | processed_reads.append(read) 124 | 125 | processed_reads = pd.DataFrame(processed_reads) 126 | if "tags" in processed_reads.keys(): 127 | processed_reads = processed_reads.drop(columns=["tags"]) 128 | 129 | return processed_reads 130 | 131 | @globalize 132 | def _get_methylseq(dmr, bam_file_path: str, k: int, methyl_caller: str): 133 | ''' 134 | Return a dictionary of DNA seq, cell type and methylation seq processed in a 3-mer seq 135 | ''' 136 | aln = pysam.AlignmentFile(bam_file_path, "rb") 137 | 138 | processed_reads = _reads_overlapping(aln, 139 | dmr["chr"], int(dmr["start"]), int(dmr["end"]), 140 | methyl_caller) 141 | if processed_reads.shape[0] > 0: 142 | processed_reads = 
processed_reads.assign(dmr_ctype = dmr["ctype"], 143 | dmr_label = dmr["dmr_id"]) 144 | return processed_reads 145 | 146 | if ncores > 1: 147 | with mp.Pool(ncores) as pool: 148 | # Convert read sequences to k-mer sequences 149 | seqs = pool.map(partial(_get_methylseq, 150 | bam_file_path = bam_file_path, k=k, methyl_caller=methyl_caller), 151 | dmrs.to_dict("records")) 152 | else: 153 | seqs = [_get_methylseq(dmr, bam_file_path = bam_file_path, k=k, 154 | methyl_caller = methyl_caller) 155 | for dmr in dmrs.to_dict("records")] 156 | 157 | # Filter None values that means no overlapping read with the given DMR 158 | seqs = list(filter(lambda i: i is not None, seqs)) 159 | if len(seqs) > 0: 160 | return pd.concat(seqs, ignore_index=True) 161 | else: 162 | warnings.warn(f"Zero reads were extracted from {bam_file_path}") 163 | return pd.DataFrame([]) 164 | 165 | def finetune_data_generate( 166 | f_dmr: str, 167 | output_dir: str, 168 | f_ref: str, 169 | sc_dataset: str = None, 170 | input_file: str = None, 171 | n_mers: int = 3, 172 | split_ratio: float = None, 173 | train_valid_test_ratio: List[float] = None, 174 | use_file_name: bool = False, 175 | n_dmrs: int = -1, 176 | n_cores: int = 1, 177 | seed: int = 950410, 178 | ignore_sex_chromo: bool = True, 179 | methyl_caller: str = "bismark", 180 | verbose: int = 2, 181 | read_extract_sequences_func: Optional[callable] = None 182 | ): 183 | 184 | # Setup random seed 185 | random.seed(seed) 186 | np.random.seed(seed) 187 | 188 | # Check split ratio 189 | if (split_ratio is not None) and (train_valid_test_ratio is not None): 190 | raise ValueError("Only either of 'split_ratio (float)' or train_valid_test_ratio 'List[float]' must be given.") 191 | elif split_ratio is not None: 192 | split_ratios = [split_ratio, 1-split_ratio, 0.0] 193 | elif train_valid_test_ratio is not None: 194 | if ( np.sum(train_valid_test_ratio) != 1.0 ) or \ 195 | ( len(train_valid_test_ratio) != 3 ): 196 | raise ValueError("'train_valid_test_ratio' must be a list with 3 float values whose sum is 1.0") 197 | split_ratios = train_valid_test_ratio 198 | else: 199 | split_ratios = [1.0, 0.0, 0.0] # output must be one single file 200 | 201 | # Setup output files 202 | if not os.path.exists(output_dir): 203 | os.mkdir(output_dir) 204 | fp_dmr = os.path.join(output_dir, "dmrs.csv") # File to save selected DMRs 205 | 206 | # Reference genome 207 | record_iter = SeqIO.parse(f_ref, "fasta") 208 | 209 | # Save the reference genome into a dictionary with chr as a key value 210 | dict_ref = dict() 211 | for r in record_iter: 212 | dict_ref[str(r.id)] = str(r.seq.upper()) 213 | del record_iter 214 | 215 | # Load DMRs into a dataframe 216 | dmrs = pd.read_csv(f_dmr, sep="\t", index_col=None) 217 | if ("chr" not in dmrs.keys()) or \ 218 | ("start" not in dmrs.keys()) or \ 219 | ("end" not in dmrs.keys()): 220 | ValueError("The .csv file for DMRs must contain chr, start and end in the header.") 221 | 222 | # Remove chrX, chrY, chrM and so on in DMRs 223 | # Genome style 224 | if "chr" in str(dmrs["chr"][0]): 225 | regex_expr = "chr\d+" if ignore_sex_chromo else "chr[\d+|X|Y]" 226 | old_keys = list(dict_ref.keys()) 227 | for k in old_keys: 228 | if "chr" not in k: dict_ref[f"chr{k}"] = dict_ref.pop(k) 229 | else: # NCBI style genome 230 | dmrs["chr"] = dmrs["chr"].astype(str) 231 | regex_expr = "\d+" if ignore_sex_chromo else "[\d+|X|Y]" 232 | old_keys = list(dict_ref.keys()) 233 | for k in old_keys: 234 | if "chr" in k: dict_ref[k.split("chr")[1]] = dict_ref.pop(k) 235 | 236 | dmrs = 
dmrs[dmrs["chr"].str.contains(regex_expr, regex=True)] 237 | 238 | if dmrs.shape[0] == 0: 239 | ValueError("Could not find any DMRs. Please make sure chromosomes have \'chr\' at the beginning.") 240 | 241 | # Sort by statistics if available 242 | if "areaStat" in dmrs.keys(): 243 | if verbose > 0: 244 | print("DMRs sorted by areaStat") 245 | dmrs["abs_areaStat"] = dmrs["areaStat"].abs() 246 | dmrs = dmrs.sort_values(by="abs_areaStat", ascending=False) 247 | elif "diff.Methy" in dmrs.keys(): 248 | if verbose > 0: 249 | print("DMRs sorted by diff.Methy") 250 | dmrs["abs_diff.Methy"] = dmrs["diff.Methy"].abs() 251 | dmrs = dmrs.sort_values(by="abs_diff.Methy", ascending=False) 252 | else: 253 | if verbose > 0: 254 | print("Could not find any statistics to sort DMRs") 255 | 256 | # Select top n dmrs based on 257 | if n_dmrs > 0: 258 | if verbose > 0: 259 | print(f"{n_dmrs} are selected based on the statistics") 260 | list_dmrs = list() 261 | for c in dmrs["ctype"].unique(): # For the case when multiple cell types are given 262 | ctype_dmrs = dmrs[dmrs["ctype"]==c] 263 | if ctype_dmrs.shape[0] > n_dmrs: 264 | list_dmrs.append(ctype_dmrs[:n_dmrs]) 265 | else: 266 | list_dmrs.append(ctype_dmrs) 267 | dmrs = pd.concat(list_dmrs) 268 | del list_dmrs 269 | 270 | # Newly assign dmr label from 0 271 | if "dmr_id" not in dmrs.keys(): 272 | dmrs["dmr_id"] = range(len(dmrs)) 273 | 274 | # Save DMRs in a new file 275 | dmrs.to_csv(fp_dmr, sep="\t", index=False) 276 | if verbose > 2: 277 | print(dmrs.head()) 278 | 279 | if verbose > 0: 280 | print(f"Number of DMRs to extract sequence reads: {len(dmrs)}") 281 | 282 | # check whether the input is a file or a file list 283 | if ( not sc_dataset ) and ( not input_file ): 284 | ValueError("Please provide either a list of input files or a file path. Both are given.") 285 | elif ( not sc_dataset ): 286 | # one input file in the list 287 | sc_files = [input_file] 288 | if use_file_name: 289 | if verbose > 0: 290 | print('When only one input file is give, file name cannot be used for the train-test split. We set use_file_name=False. Read names will be used for the split') 291 | use_file_name = False 292 | elif ( not input_file ): 293 | # Collect train data (single-cell samples) 294 | train_sc_samples = [] 295 | with open(sc_dataset, "r") as fp_sc_dataset: 296 | sc_files = fp_sc_dataset.readlines() 297 | 298 | if ( len(sc_files) < 10 ) and ( use_file_name ): 299 | warnings.warn("We do not encourage to users to set use_file_name=True with the number of input bam files < 10. 
It can cause an unexpected error.") 300 | else: 301 | raise ValueError("Either a list of input files or a file path must be given.") 302 | 303 | # Collect reads from the .bam files 304 | df_reads = list() 305 | tqdm_bar = tqdm(total=len(sc_files), 306 | desc="Collecting reads from .bam files") 307 | 308 | # file/read name - cell type pair for stratification in train test split 309 | files_lbl_map = {} 310 | for f_sc in sc_files: 311 | f_sc = f_sc.strip().split("\t") 312 | f_sc_bam = f_sc[0] 313 | 314 | if read_extract_sequences_func is None: 315 | extracted_reads = read_extract( 316 | f_sc_bam, dict_ref, k=3, dmrs=dmrs, 317 | ncores=n_cores, methyl_caller=methyl_caller 318 | ) 319 | else: 320 | # custom function 321 | extracted_reads = read_extract_sequences_func( 322 | f_sc_bam, dict_ref, k=3, dmrs=dmrs, 323 | ncores=n_cores, methyl_caller=methyl_caller 324 | ) 325 | 326 | if extracted_reads is None: 327 | continue 328 | 329 | # cell type for the single-cell data 330 | ''' 331 | if "RG" in extracted_reads.columns: 332 | extracted_reads = extracted_reads.rename(columns={"RG":"read_ctype"}) 333 | #extracted_reads["ctype"] = [c.split("_")[1][:3]+"-"+c.split("_")[1][3] for c in extracted_reads["read_ctype"]] # mouse single-cell 334 | extracted_reads["ctype"] = [c.split("_")[1] for c in extracted_reads["read_ctype"]] # tumour pseudo bulk 335 | else: 336 | ''' 337 | if len(f_sc) > 1: 338 | if verbose > 1: 339 | print(f"{f_sc_bam} processing ({f_sc[1]})...") 340 | extracted_reads["ctype"] = f_sc[1] 341 | else: 342 | extracted_reads["ctype"] = "NA" 343 | extracted_reads = extracted_reads.rename(columns={"RF":"dna_seq", 344 | "ME":"methyl_seq"}) 345 | 346 | if(extracted_reads.shape[0] > 0): 347 | if use_file_name: 348 | filename = os.path.basename(f_sc_bam) 349 | if len(f_sc) > 1: 350 | files_lbl_map[filename] = extracted_reads['ctype'][0] 351 | extracted_reads["filename"] = filename 352 | else: 353 | for name, ctype in zip(extracted_reads['name'], 354 | extracted_reads['ctype']): 355 | files_lbl_map[name] = ctype 356 | 357 | df_reads.append(extracted_reads) 358 | 359 | tqdm_bar.update() 360 | 361 | # Integrate all reads and shuffle 362 | if len(df_reads) > 0: 363 | df_reads = pd.concat(df_reads, ignore_index=True) \ 364 | .sample(frac=1) \ 365 | .reset_index(drop=True) # sample is for shuffling 366 | 367 | if verbose > 1: 368 | print("Fine-tuning data generated:", df_reads.head()) 369 | else: 370 | ValueError("Could not find any reads overlapping with the given DMRs. 
Please try different regions.") 371 | 372 | if verbose > 1: 373 | print("Total sequences per cell type") 374 | print(df_reads["ctype"].value_counts()) 375 | 376 | # Split the data into train and train/valid/test by patient/bam file 377 | if split_ratios[0] != 1.0: 378 | fp_train_seq = os.path.join(output_dir, "train_seq.csv") 379 | fp_test_seq = os.path.join(output_dir, "test_seq.csv") 380 | 381 | split_key = "filename" if use_file_name else "name" 382 | 383 | val_test_size = 1 - split_ratios[0] 384 | train_files, test_files = train_test_split( 385 | list(files_lbl_map.keys()), 386 | test_size=val_test_size, random_state=seed, 387 | stratify=list(files_lbl_map.values()) 388 | ) 389 | 390 | if split_ratios[-1] > 0.0: 391 | fp_val_seq = os.path.join(output_dir, "val_seq.csv") 392 | test_size = split_ratios[2] / (split_ratios[1] + split_ratios[2]) 393 | val_files, test_files = train_test_split( 394 | test_files, 395 | test_size=test_size, random_state=seed, 396 | stratify=[files_lbl_map[e] for e in test_files] 397 | ) 398 | 399 | df_reads.loc[df_reads[split_key].isin(val_files), :] \ 400 | .to_csv(fp_val_seq, sep="\t", header=True, index=None) 401 | 402 | # Write train & test files (adding a final column because of sep="\t") 403 | df_reads["non_null_col"] = "" 404 | df_reads.loc[df_reads[split_key].isin(train_files), :] \ 405 | .to_csv(fp_train_seq, sep="\t", header=True, index=None) 406 | df_reads.loc[df_reads[split_key].isin(test_files), :] \ 407 | .to_csv(fp_test_seq, sep="\t", header=True, index=None) 408 | 409 | if verbose > 0: 410 | print("Size - train %d seqs , valid %d seqs "% \ 411 | (df_reads.loc[df_reads[split_key].isin(train_files), :].shape[0], 412 | df_reads.loc[df_reads[split_key].isin(test_files), :].shape[0])) 413 | 414 | else: 415 | fp_data_seq = os.path.join(output_dir, "data.csv") 416 | df_reads.to_csv(fp_data_seq, sep="\t", header=True, index=None) 417 | 418 | return df_reads 419 | -------------------------------------------------------------------------------- /src/methylbert/data/genome.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import numpy as np 4 | 5 | 6 | def _kmers(original_string, kmer=3): 7 | sentence = "" 8 | #original_string = original_string.replace("\n", "") 9 | i = 0 10 | while i < len(original_string)-kmer: 11 | sentence += original_string[i:i+kmer] + " " 12 | i += 1 13 | 14 | return sentence[:-1].strip("\"") 15 | 16 | def pretrain_data_preprocess(f_ref, k=3, seq_len=510, f_output=None): 17 | ''' 18 | Generate N bp length k-mers sequence from reference genome as pretrain data 19 | 20 | f_ref : str 21 | path to the reference fasta file 22 | k : int 23 | Number for k-mers (default=3) 24 | seq_len: int 25 | Base-pair length of generated sequences (default=510) 26 | f_output : str 27 | path to the output file, an appropriate name 28 | will be automatically assigned if not given 29 | 30 | ''' 31 | 32 | fp_ref = open(f_ref, "r") 33 | if f_output == None: 34 | f_output = f_ref + "_%dmers.txt"%(k) 35 | 36 | fp_out = open(f_output, "w") 37 | line = fp_ref.readline().strip().upper() 38 | cur_line="" # keeping an incomplete N bp line 39 | collect_data=False 40 | valid_chromosomes = ["CHR"+str(i) for i in range(22)] 41 | valid_chromosomes += ["CHRX", "CHRY"] 42 | 43 | while line: 44 | n_missing = line.count("N") 45 | 46 | if n_missing > 0: 47 | # Missing DNA base in the line -> reset the line 48 | #line = fp_ref.readline().strip().upper() 49 | line=fp_ref.readline().strip().upper() 50 | 
cur_line="" 51 | continue 52 | elif ( line.count(">") > 0 ): 53 | # New chromosome or there are some missing bases at the middle 54 | # We restart a 510 bp piece 55 | chromosome = line.split(">")[1] 56 | collect_data = chromosome in valid_chromosomes 57 | if collect_data: 58 | print("Collect sequences in %s"%(chromosome)) 59 | line = fp_ref.readline().strip().upper() 60 | cur_line = "" 61 | continue 62 | 63 | if collect_data: 64 | cur_line += line 65 | line_length = len(cur_line) 66 | if line_length >= seq_len: 67 | new_line = cur_line[:seq_len] 68 | cur_line = cur_line[seq_len:] 69 | sentence = _kmers(new_line, kmer=k) 70 | fp_out.write(sentence + "\n") 71 | 72 | # get a new line 73 | line = fp_ref.readline().strip().upper() 74 | 75 | 76 | def parse_args(): 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument( 79 | "--kmer", 80 | default=1, 81 | type=int, 82 | help="K-mer", 83 | ) 84 | parser.add_argument( 85 | "--length", 86 | default=10000, 87 | type=int, 88 | help="Length of the sampled sequence", 89 | ) 90 | parser.add_argument( 91 | "--file_path", 92 | default=None, 93 | type=str, 94 | help="The path of the file to be processed", 95 | ) 96 | parser.add_argument( 97 | "--output_path", 98 | default=None, 99 | type=str, 100 | help="The path of the processed data", 101 | ) 102 | args = parser.parse_args() 103 | return args 104 | 105 | if __name__ == "__main__": 106 | args = parse_args() 107 | pretrain_data_preprocess(args.file_path, k=args.kmer, 108 | seq_len=args.length, f_output=args.output_path) 109 | 110 | -------------------------------------------------------------------------------- /src/methylbert/data/vocab.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import tqdm 3 | from collections import Counter 4 | import itertools 5 | 6 | ''' 7 | class TorchVocab(object): 8 | """Defines a vocabulary object that will be used to numericalize a field. 9 | Attributes: 10 | freqs: A collections.Counter object holding the frequencies of tokens 11 | in the data used to build the Vocab. 12 | stoi: A collections.defaultdict instance mapping token strings to 13 | numerical identifiers. 14 | itos: A list of token strings indexed by their numerical identifiers. 15 | """ 16 | 17 | def __init__(self, counter, max_size=None, min_freq=1, specials=['', ''], 18 | vectors=None, unk_init=None, vectors_cache=None): 19 | """Create a Vocab object from a collections.Counter. 20 | Arguments: 21 | counter: collections.Counter object holding the frequencies of 22 | each value found in the data. 23 | max_size: The maximum size of the vocabulary, or None for no 24 | maximum. Default: None. 25 | min_freq: The minimum frequency needed to include a token in the 26 | vocabulary. Values less than 1 will be set to 1. Default: 1. 27 | specials: The list of special tokens (e.g., padding or eos) that 28 | will be prepended to the vocabulary in addition to an 29 | token. Default: [''] 30 | vectors: One of either the available pretrained vectors 31 | or custom pretrained vectors (see Vocab.load_vectors); 32 | or a list of aforementioned vectors 33 | unk_init (callback): by default, initialize out-of-vocabulary word vectors 34 | to zero vectors; can be any function that takes in a Tensor and 35 | returns a Tensor of the same size. Default: torch.Tensor.zero_ 36 | vectors_cache: directory for cached vectors. 
Default: '.vector_cache' 37 | """ 38 | self.freqs = counter 39 | counter = counter.copy() 40 | min_freq = max(min_freq, 1) 41 | 42 | self.itos = list(specials) 43 | # frequencies of special tokens are not counted when building vocabulary 44 | # in frequency order 45 | for tok in specials: 46 | del counter[tok] 47 | 48 | max_size = None if max_size is None else max_size + len(self.itos) 49 | 50 | # sort by frequency, then alphabetically 51 | words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) 52 | words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) 53 | 54 | for word, freq in words_and_frequencies: 55 | if freq < min_freq or len(self.itos) == max_size: 56 | break 57 | self.itos.append(word) 58 | 59 | # stoi is simply a reverse dict for itos 60 | self.stoi = {tok: i for i, tok in enumerate(self.itos)} 61 | 62 | self.vectors = None 63 | if vectors is not None: 64 | self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache) 65 | else: 66 | assert unk_init is None and vectors_cache is None 67 | 68 | def __eq__(self, other): 69 | if self.freqs != other.freqs: 70 | return False 71 | if self.stoi != other.stoi: 72 | return False 73 | if self.itos != other.itos: 74 | return False 75 | if self.vectors != other.vectors: 76 | return False 77 | return True 78 | 79 | def __len__(self): 80 | return len(self.itos) 81 | 82 | def vocab_rerank(self): 83 | self.stoi = {word: i for i, word in enumerate(self.itos)} 84 | 85 | def extend(self, v, sort=False): 86 | words = sorted(v.itos) if sort else v.itos 87 | for w in words: 88 | if w not in self.stoi: 89 | self.itos.append(w) 90 | self.stoi[w] = len(self.itos) - 1 91 | 92 | 93 | class Vocab(TorchVocab): 94 | def __init__(self, counter, max_size=None, min_freq=1): 95 | self.pad_index = 0 96 | self.unk_index = 1 97 | self.eos_index = 2 98 | self.sos_index = 3 99 | self.mask_index = 4 100 | super().__init__(counter, specials=["", "", "", "", ""], 101 | max_size=max_size, min_freq=min_freq) 102 | 103 | def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list: 104 | pass 105 | 106 | def from_seq(self, seq, join=False, with_pad=False): 107 | pass 108 | 109 | @staticmethod 110 | def load_vocab(vocab_path: str) -> 'Vocab': 111 | with open(vocab_path, "rb") as f: 112 | return pickle.load(f) 113 | 114 | def save_vocab(self, vocab_path): 115 | with open(vocab_path, "wb") as f: 116 | pickle.dump(self, f) 117 | ''' 118 | 119 | # Building Vocab with text files 120 | class MethylVocab(object): 121 | def __init__(self, k: int=3): 122 | ''' 123 | Create a look-up table to convert 3-mer tokens to numerical identifiers 124 | 125 | k: int 126 | k to create k-mer sequences 127 | ''' 128 | print("Building Vocab") 129 | self.kmers=k 130 | 131 | # Create a look up table with 3-mer tokens 132 | bases = ["A","G","T","C"] 133 | 134 | vocabs = list(itertools.product(bases, repeat=self.kmers)) 135 | vocabs = sorted(["".join(e) for e in vocabs]) #alphabetical orders 136 | 137 | # Set up special tokens 138 | special_tokens=["", "", "", "", ""] 139 | self.pad_index = 0 140 | self.unk_index = 1 141 | self.eos_index = 2 142 | self.sos_index = 3 143 | self.mask_index = 4 144 | 145 | self.itos = list(special_tokens) + vocabs 146 | self.stoi = {t: i for i, t in enumerate(self.itos)} 147 | 148 | def __len__(self): 149 | return len(self.itos) 150 | 151 | def to_seq(self, sequence) -> list: 152 | ''' 153 | Convert a 3-mer sequence 154 | 155 | sequence: str or list(str) 156 | A 3-mer sequence to convert. 
It can be given as either a string or a list of 3-mer strings 157 | 158 | ''' 159 | if isinstance(sequence, str): 160 | sentence = sequence.split() 161 | 162 | seq = [self.stoi.get(kmer, self.unk_index) for kmer in sequence] 163 | 164 | return seq 165 | 166 | def from_seq(self, seq, join=False, with_pad=False): 167 | words = [self.itos[idx] 168 | if idx < len(self.itos) 169 | else "<%d>" % idx 170 | for idx in seq 171 | if not with_pad or idx != self.pad_index] 172 | 173 | return " ".join(words) if join else words 174 | ''' 175 | @staticmethod 176 | def load_vocab(vocab_path: str) -> 'WordVocab': 177 | with open(vocab_path, "rb") as f: 178 | return pickle.load(f) 179 | 180 | 181 | class GenomeVocab(Vocab): 182 | def __init__(self, k=3, max_size=None, min_freq=1, cpg=True, chg=True, chh=False): 183 | print("Building Vocab") 184 | self.kmers=k 185 | counter = Counter() 186 | 187 | # Create a look up table with 3-mers 188 | import itertools 189 | 190 | bases = ["A","G","T","C", "N"] 191 | 192 | 193 | vocabs = list(itertools.product(bases, repeat=self.kmers)) 194 | vocabs = ["".join(elem) for elem in vocabs] 195 | for word in vocabs: 196 | counter[word] += 1 197 | 198 | super().__init__(counter, max_size=max_size, min_freq=min_freq) 199 | 200 | def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False): 201 | if isinstance(sentence, str): 202 | sentence = sentence.split() 203 | 204 | seq = [self.stoi.get(word, self.unk_index) for word in sentence] 205 | 206 | if with_eos: 207 | seq += [self.eos_index] # this would be index 1 208 | if with_sos: 209 | seq = [self.sos_index] + seq 210 | 211 | origin_seq_len = len(seq) 212 | 213 | if seq_len is None: 214 | pass 215 | elif len(seq) <= seq_len: 216 | seq += [self.pad_index for _ in range(seq_len - len(seq))] 217 | else: 218 | seq = seq[:seq_len] 219 | 220 | return (seq, origin_seq_len) if with_len else seq 221 | 222 | def from_seq(self, seq, join=False, with_pad=False): 223 | words = [self.itos[idx] 224 | if idx < len(self.itos) 225 | else "<%d>" % idx 226 | for idx in seq 227 | if not with_pad or idx != self.pad_index] 228 | 229 | return " ".join(words) if join else words 230 | 231 | @staticmethod 232 | def load_vocab(vocab_path: str) -> 'WordVocab': 233 | with open(vocab_path, "rb") as f: 234 | return pickle.load(f) 235 | ''' -------------------------------------------------------------------------------- /src/methylbert/deconvolute.py: -------------------------------------------------------------------------------- 1 | import argparse, os, logging 2 | import pandas as pd 3 | import numpy as np 4 | 5 | from methylbert.trainer import MethylBertFinetuneTrainer 6 | from methylbert.data.vocab import MethylVocab 7 | from methylbert.data.dataset import MethylBertFinetuneDataset 8 | from methylbert.utils import _moment 9 | 10 | from torch.utils.data import DataLoader 11 | from torch import Tensor 12 | from torch.nn.functional import softmax 13 | 14 | from tqdm import tqdm 15 | import pickle as pk 16 | 17 | from scipy.optimize import minimize 18 | 19 | def likelihood_fun(theta, margins, prob): 20 | ''' 21 | theta should be 2 x 1 22 | prob should be reads x 2 23 | ''' 24 | #margins, prob = args 25 | prob = np.divide(prob, margins) 26 | prob = np.log(np.matmul(theta.T, prob.T)) 27 | return np.sum(prob) 28 | 29 | def nll_multi_celltype(thetas, *args): 30 | 31 | df_res, margins = args 32 | nll = 0 33 | 34 | for idx, ctype in enumerate(margins.keys()): 35 | ctype_margins = np.array([1-margins[ctype], margins[ctype]]) 36 | prob = 
df_res.loc[df_res["dmr_ctype"]==ctype, "P_ctype"].to_numpy() 37 | prob = np.concatenate([np.array(1-prob).reshape(-1,1), 38 | np.array(prob).reshape(-1, 1)], axis=1) 39 | prob = np.divide(prob, ctype_margins) 40 | prob = np.log(np.matmul(np.array([1-thetas[idx], thetas[idx]]).reshape([1,2]), 41 | prob.T)) 42 | nll -= np.sum(prob) 43 | return nll 44 | 45 | def grid_search(logits, margins, n_grid, verbose=True): 46 | ''' 47 | return estimated_proportions (list), fisher info, likelihood 48 | ''' 49 | grid = np.zeros([1, n_grid]) 50 | 51 | if verbose: 52 | logging.info("Grid search (n=%d) for deconvolution", n_grid) 53 | pbar = tqdm(total=n_grid) 54 | for m_theta in range(0, n_grid): 55 | theta = m_theta*(1/n_grid) 56 | theta = np.array([1-theta, theta]).reshape([2,1]) 57 | if logits.shape[1] != theta.shape[0]: 58 | raise ValueError(f"Dimensions are wrong: theta {theta.shape}, prob {logits.shape}") 59 | grid[0, m_theta] = likelihood_fun(theta, margins, logits) 60 | if verbose: 61 | pbar.update(1) 62 | if verbose: 63 | pbar.close() 64 | 65 | # Fisher info calculation 66 | fi = np.var([grid[0, f+1] - grid[0, f] for f in range(n_grid-1)]) 67 | argmax_idx = np.argmax(grid, axis=1) 68 | estimates = float(argmax_idx)*(1/n_grid) 69 | return [estimates, 1-estimates], fi, grid[0, argmax_idx] 70 | 71 | 72 | def grid_search_regions(logits, margins, n_grid, regions): 73 | 74 | def skewness_test(x, *args): 75 | region_purity, estimates = args 76 | a = np.multiply(np.array(region_purity), x) 77 | m2 = _moment(a, 2, axis=0, mean=estimates) 78 | m3 = _moment(a, 3, axis=0, mean=estimates) 79 | with np.errstate(all='ignore'): 80 | zero = (m2 <= (np.finfo(m2.dtype).resolution * estimates)**2) 81 | vals = np.where(zero, np.nan, m3 / m2**1.5) 82 | return (vals[()])**2 83 | 84 | regions = pd.DataFrame({ 85 | "logit1" : logits[:, 0], 86 | "logit2" : logits[:, 1], 87 | "region" : regions if isinstance(regions, list) else regions.tolist() 88 | }) 89 | 90 | dmr_labels = regions["region"].unique() 91 | region_purity = np.zeros(len(dmr_labels)) 92 | fi, likelihood = np.zeros(len(dmr_labels)), np.zeros(len(dmr_labels)) 93 | for idx, dmr_label in enumerate(dmr_labels): 94 | dmr_logits = regions[regions["region"] == dmr_label] 95 | purities, fi[idx], likelihood[idx] = grid_search(np.array(dmr_logits.loc[:, ["logit1", "logit2"]]), margins, n_grid, verbose=False) 96 | region_purity[idx] = purities[0] 97 | 98 | # EM algorithm for the adjustment 99 | weights = np.ones(len(dmr_labels)) 100 | prev_mean = np.inf 101 | estimates = np.mean(np.multiply(region_purity,weights)) 102 | 103 | for iterration in tqdm(range(10)): 104 | prev_mean = estimates 105 | x = minimize(skewness_test, weights, args=(region_purity, estimates)) 106 | weights = x["x"] 107 | estimates = np.mean(np.multiply(region_purity,weights)) 108 | 109 | if abs(estimates - prev_mean) < 0.0001: 110 | break 111 | 112 | estimates = np.clip(estimates, 0, 1) 113 | 114 | return [estimates, 1-estimates], list(fi), dmr_labels, list(likelihood), #list(region_purity), list(weights) 115 | 116 | def optimise_nll_deconvolute(reads : pd.DataFrame, 117 | margins : pd.Series): 118 | ''' 119 | Deconvolution for multiple cell types 120 | ''' 121 | 122 | estimates = minimize(nll_multi_celltype, 123 | margins.to_numpy(), 124 | args=(reads, margins), 125 | method='SLSQP', 126 | bounds=[(1e-10, 1-1e-10) for i in range(margins.shape[0])], 127 | constraints={'type': 'eq', 'fun': lambda x: np.sum(x)-1}) 128 | return pd.DataFrame.from_dict({"cell_type":margins.keys().tolist(), 129 | 
"pred":estimates.x}) 130 | 131 | def purity_estimation(reads : pd.DataFrame, 132 | margins : pd.Series, 133 | n_grid : int, 134 | adjustment : bool): 135 | 136 | margins = margins[["N","T"]].tolist() 137 | 138 | # Tumour-normal deconvolution 139 | if adjustment: 140 | # Adjustment applied 141 | estimation, fi, dmr_labels, likelihood = \ 142 | grid_search_regions(reads.loc[:,["P_N", "P_ctype"]].to_numpy(), 143 | margins, 144 | n_grid, 145 | reads["dmr_label"]) 146 | else: 147 | estimation, fi, likelihood = grid_search(reads.loc[:,["P_N", "P_ctype"]].to_numpy(), 148 | margins, 149 | n_grid) 150 | 151 | # Dataframe for the results 152 | deconv_res = pd.DataFrame.from_dict({"cell_type":["T", "N"], "pred":estimation}) 153 | if type(fi) is not list: 154 | fi = [fi] 155 | fi_res = pd.DataFrame.from_dict({"fi":fi, "likelihood": likelihood}) 156 | else: 157 | fi_res = pd.DataFrame.from_dict({"dmr_label":dmr_labels, 158 | "fi":fi, 159 | "likelihood": likelihood}).sort_values("dmr_label") 160 | 161 | return deconv_res, fi_res 162 | 163 | def deconvolute(trainer : MethylBertFinetuneTrainer, 164 | data_loader : DataLoader, 165 | df_train : pd.DataFrame, 166 | tokenizer : MethylVocab, 167 | output_path : str = "./", 168 | n_grid : int = 10000, 169 | adjustment : bool = False): 170 | ''' 171 | Tumour deconvolution for the given bulk 172 | 173 | trainer: MethylBertFinetuneTrainer 174 | Fine-tuned methylbert model contained in a MethylBertFinetuneTrainer object 175 | data_loader: torch.utils.data.DataLoader 176 | DataLoader containing sequencing reads from the bulk 177 | df_train: pandas.DataFrame 178 | DataFrame containing the training data. This is for calculating margins (marginal probability in the Bayes' theorem) 179 | output_path: str (defalut: "./") 180 | Directory to save the results 181 | n_grid: int (default: 10000) 182 | Number of grids for the grid-search algorithm. The higher the number is, the more precise the tumour purity estimation will be 183 | adjustment: bool (default: False) 184 | Whether you want to conduct the estimation adjustment or not 185 | ''' 186 | 187 | if not os.path.exists(output_path): 188 | os.mkdir(output_path) 189 | 190 | # Read classification 191 | total_res, logits = trainer.read_classification(data_loader=data_loader, 192 | tokenizer=tokenizer, 193 | logit=True) 194 | total_res = total_res.drop(columns=["ctype_label"]) 195 | 196 | # Save the classification results 197 | total_res["n_cpg"]=total_res["methyl_seq"].apply(lambda x: x.count("0") + x.count("1")) 198 | total_res["P_ctype"] = logits[:,1] 199 | total_res.to_csv(output_path+"/res.csv", sep="\t", header=True, index=False) 200 | total_res["P_N"] = logits[:,0] 201 | 202 | # Select reads which contain methylation patterns 203 | total_res = total_res[total_res["n_cpg"]>0] 204 | assert total_res.shape[0] != 0, "There are no reads selected for deconvolution. It may mean all of the reads do not have CpG methylation." 
205 | 206 | # Calculate prior from training data 207 | margins = df_train.value_counts("ctype", normalize=True) 208 | 209 | print("Margins : ", margins) 210 | print(total_res.head()) 211 | 212 | if len(margins.keys()) == 2: 213 | # purity estimation 214 | deconv_res, fi_res = purity_estimation(reads = total_res, 215 | margins = margins, 216 | n_grid = n_grid, 217 | adjustment = adjustment) 218 | deconv_res.to_csv(output_path+"/deconvolution.csv", sep="\t", header=True, index=False) 219 | fi_res.to_csv(output_path+"/FI.csv", sep="\t", header=True, index=False) 220 | elif len(margins.keys()) > 2: 221 | # multiple cell-type deconvolution 222 | deconv_res = optimise_nll_deconvolute(reads = total_res, margins = margins) 223 | deconv_res.to_csv(output_path+"/deconvolution.csv", sep="\t", header=True, index=False) 224 | else: 225 | raise RuntimeError(f"There are less than two cell types in the training data set. {margins.keys()} Neither purity estimation nor deconvolution can be performed.") 226 | 227 | 228 | -------------------------------------------------------------------------------- /src/methylbert/function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn.modules.loss import _Loss 4 | 5 | def sigmoid_focal_loss( 6 | inputs: torch.Tensor, 7 | targets: torch.Tensor, 8 | alpha: float = 0.1, 9 | gamma: float = 2, 10 | reduction: str = "none", 11 | ) -> torch.Tensor: 12 | """ 13 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 14 | This code is from : https://pytorch.org/vision/main/_modules/torchvision/ops/focal_loss.html 15 | 16 | Args: 17 | inputs (Tensor): A float tensor of arbitrary shape. 18 | The predictions for each example. 19 | targets (Tensor): A float tensor with the same shape as inputs. Stores the binary 20 | classification label for each element in inputs 21 | (0 for the negative class and 1 for the positive class). 22 | alpha (float): Weighting factor in range (0,1) to balance 23 | positive vs negative examples or -1 for ignore. Default: ``0.25``. 24 | gamma (float): Exponent of the modulating factor (1 - p_t) to 25 | balance easy vs hard examples. Default: ``2``. 26 | reduction (string): ``'none'`` | ``'mean'`` | ``'sum'`` 27 | ``'none'``: No reduction will be applied to the output. 28 | ``'mean'``: The output will be averaged. 29 | ``'sum'``: The output will be summed. Default: ``'none'``. 30 | Returns: 31 | Loss tensor with the reduction option applied. 
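    Note:
        In this implementation ``alpha`` defaults to ``0.1`` (see the function
        signature above), not the ``0.25`` quoted from torchvision.

    Example (illustrative sketch; the tensor values are made up)::

        import torch
        inputs = torch.tensor([1.5, -0.3, 0.2])   # raw, unnormalised logits
        targets = torch.tensor([1.0, 0.0, 1.0])   # binary labels as floats
        loss = sigmoid_focal_loss(inputs, targets, alpha=0.25, gamma=2,
                                  reduction="mean")  # scalar loss tensor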
32 | """ 33 | # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py 34 | 35 | 36 | p = torch.sigmoid(inputs) 37 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 38 | p_t = p * targets + (1 - p) * (1 - targets) 39 | loss = ce_loss * ((1 - p_t) ** gamma) 40 | 41 | if alpha >= 0: 42 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 43 | loss = alpha_t * loss 44 | 45 | # Check reduction option and return loss accordingly 46 | if reduction == "none": 47 | pass 48 | elif reduction == "mean": 49 | loss = loss.mean() 50 | elif reduction == "sum": 51 | loss = loss.sum() 52 | else: 53 | raise ValueError( 54 | f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'" 55 | ) 56 | return loss 57 | 58 | 59 | class FocalLoss(_Loss): 60 | 61 | def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None: 62 | super().__init__(size_average, reduce, reduction) 63 | 64 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 65 | return sigmoid_focal_loss(input, target, reduction=self.reduction) -------------------------------------------------------------------------------- /src/methylbert/network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | import numpy as np 6 | import math, os 7 | from copy import deepcopy 8 | 9 | from transformers import BertPreTrainedModel, BertModel, BertForMaskedLM 10 | from methylbert.function import FocalLoss 11 | from methylbert.config import MethylBERTConfig 12 | 13 | METHYLBERT_PRETRAINED_MODEL_ARCHIVE_MAP = { 14 | "hanyangii/methylbert_hg19_12l": "https://huggingface.co/hanyangii/methylbert_hg19_12l/resolve/main/pytorch_model.bin", 15 | "hanyangii/methylbert_hg19_8l": "https://huggingface.co/hanyangii/methylbert_hg19_8l/resolve/main/pytorch_model.bin", 16 | "hanyangii/methylbert_hg19_6l": "https://huggingface.co/hanyangii/methylbert_hg19_6l/resolve/main/pytorch_model.bin", 17 | "hanyangii/methylbert_hg19_4l": "https://huggingface.co/hanyangii/methylbert_hg19_4l/resolve/main/pytorch_model.bin", 18 | "hanyangii/methylbert_hg19_2l": "https://huggingface.co/hanyangii/methylbert_hg19_2l/resolve/main/pytorch_model.bin", 19 | } 20 | 21 | class MethylBertEmbeddedDMR(BertPreTrainedModel): 22 | pretrained_model_archive_map = METHYLBERT_PRETRAINED_MODEL_ARCHIVE_MAP 23 | config_class = MethylBERTConfig 24 | base_model_prefix = "methylbert" 25 | 26 | def __init__(self, config, seq_len=150): 27 | # from pretrained - calls the init 28 | super().__init__(config) 29 | self.num_labels = config.num_labels 30 | 31 | if config.loss not in ["bce", "focal_bce"]: 32 | raise ValueError(f"loss must be bce or focal_bce. 
{config.loss} is given.") 33 | 34 | self.loss = config.loss 35 | self.classification_loss_fct = self._setup_loss(self.loss) 36 | self.bert = BertModel(config) 37 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 38 | self.read_classifier = nn.Sequential( 39 | nn.Linear((config.hidden_size+1)*(seq_len+1), seq_len+1), 40 | nn.Dropout(0.05),#config.hidden_dropout_prob), 41 | nn.ReLU(), 42 | nn.LayerNorm(seq_len+1, eps=config.layer_norm_eps), 43 | nn.Linear(seq_len+1, 2) 44 | ) 45 | 46 | self.seq_len = seq_len 47 | 48 | self.dmr_encoder = nn.Sequential( 49 | nn.Embedding(num_embeddings=self.num_labels, embedding_dim = seq_len+1), 50 | ) 51 | 52 | self.init_weights() 53 | 54 | def _setup_loss(self, loss): 55 | if loss == "bce": 56 | print("Cross entropy loss assigned") 57 | return nn.CrossEntropyLoss() # this function requires unnormalised logits 58 | elif loss == "focal_bce": 59 | print("Focal loss assigned") 60 | return FocalLoss() 61 | 62 | 63 | def check_model_status(self): 64 | print("Bert model training mode : %s"%(self.bert.training)) 65 | print("Dropout training mode : %s"%(self.dropout.training)) 66 | print("Read classifier training mode : %s"%(self.read_classifier.training)) 67 | 68 | def from_pretrained_read_classifier(self, pretrained_model_name_or_path, device="cpu"): 69 | self.read_classifier.load_state_dict(torch.load(pretrained_model_name_or_path, map_location=device)) 70 | 71 | def from_pretrained_dmr_encoder(self, pretrained_model_name_or_path, device="cpu"): 72 | self.dmr_encoder.load_state_dict(torch.load(pretrained_model_name_or_path, map_location=device)) 73 | 74 | def forward( 75 | self, 76 | step, 77 | input_ids=None, 78 | attention_mask=None, 79 | token_type_ids=None, 80 | position_ids=None, 81 | head_mask=None, 82 | inputs_embeds=None, 83 | labels=None, 84 | ctype_label=None 85 | ): 86 | 87 | outputs = self.bert( 88 | input_ids, 89 | attention_mask=attention_mask, 90 | token_type_ids=token_type_ids, 91 | position_ids=position_ids, 92 | head_mask=head_mask, 93 | inputs_embeds=inputs_embeds, 94 | ) 95 | 96 | sequence_output = outputs[0] 97 | sequence_output = self.dropout(sequence_output) 98 | 99 | #DMR info 100 | encoded_dmr = self.dmr_encoder(labels.view(-1)) 101 | 102 | sequence_output = torch.cat((sequence_output, encoded_dmr.unsqueeze(-1)), axis=-1) 103 | 104 | ctype_logits = self.read_classifier(sequence_output.view(-1,(self.seq_len+1)*769)) 105 | 106 | loss = self.classification_loss_fct(ctype_logits.view(-1, 2), 107 | F.one_hot(ctype_label, num_classes=2).to(torch.float32).view(-1, 2)) 108 | ctype_logits = ctype_logits.softmax(dim=1) 109 | 110 | outputs = {"loss": loss, 111 | "dmr_logits":sequence_output, 112 | "classification_logits": ctype_logits} 113 | 114 | return outputs # (loss), logits, (hidden_states), (attentions) 115 | 116 | 117 | -------------------------------------------------------------------------------- /src/methylbert/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import warnings 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import torch.cuda.amp as amp 9 | import torch.nn as nn 10 | from sklearn.metrics import accuracy_score, auc, roc_curve 11 | from torch.cuda.amp import GradScaler 12 | from torch.optim import Adam, AdamW 13 | from torch.optim.lr_scheduler import LambdaLR 14 | from torch.utils.data import DataLoader 15 | from tqdm.auto import tqdm 16 | from transformers import (BertConfig, BertForMaskedLM, 17 | 
BertForSequenceClassification) 18 | 19 | from methylbert.config import MethylBERTConfig, get_config 20 | from methylbert.data.vocab import MethylVocab 21 | from methylbert.network import MethylBertEmbeddedDMR 22 | from methylbert.utils import get_dna_seq 23 | 24 | torch.set_warn_always(False) # one warning per process 25 | 26 | def learning_rate_scheduler(optimizer, num_warmup_steps: int, num_training_steps: int, decrease_steps: int): 27 | """ 28 | Modified version of get_linear_schedule_with_warmup from transformers 29 | Learning rate scheduler including warm-up, retaining and decrease 30 | 31 | optimizer: torch.optim.Optimizer 32 | Optimizer 33 | num_warmup_steps: int 34 | Initial steps for linear warm-up 35 | num_training_steps: int 36 | Total training steps 37 | decrease_steps: 38 | Steps when the learning rate decrease starts 39 | """ 40 | 41 | def lr_lambda(current_step): 42 | if current_step <= num_warmup_steps: # warm-up 43 | return float(current_step) / float(max(1, num_warmup_steps)) 44 | elif current_step >= decrease_steps: # decrease 45 | return max( 46 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - decrease_steps)) 47 | ) 48 | return 1 # Otherwise, keep the current learning rate 49 | 50 | return LambdaLR(optimizer, lr_lambda, last_epoch = -1) 51 | 52 | 53 | 54 | class MethylBertTrainer(object): 55 | def __init__(self, 56 | vocab_size: int, 57 | save_path: str = "", 58 | train_dataloader: DataLoader = None, 59 | test_dataloader: DataLoader = None, 60 | **kwargs): 61 | 62 | # Setup config 63 | self._config = get_config(**kwargs) 64 | 65 | # Setup dataloader 66 | self.train_data = train_dataloader 67 | self.test_data = test_dataloader 68 | 69 | # Setup cuda device for BERT training, argument -c, --cuda should be true 70 | self._config.amp = torch.cuda.is_available() and self._config.with_cuda 71 | if self._config.with_cuda and torch.cuda.device_count() < 1: 72 | print("No detected GPU device. 
Load the model on CPU") 73 | self._config.with_cuda = False 74 | print("The model is loaded on %s"%("GPU" if self._config.with_cuda else "CPU")) 75 | self.device = torch.device("cuda:0" if self._config.with_cuda else "cpu") 76 | 77 | # To save the best model 78 | self.min_loss = np.inf 79 | self.save_path = save_path 80 | if save_path and not os.path.exists(save_path): 81 | os.mkdir(save_path) 82 | self.f_train = os.path.join(self.save_path, "train.csv") 83 | self.f_eval = os.path.join(self.save_path, "eval.csv") 84 | 85 | 86 | def save(self, file_path="output/bert_trained.model"): 87 | ''' 88 | Saving the current BERT model on file_path 89 | 90 | file_path: str 91 | model output path which gonna be file_path+"ep%d" % epoch 92 | ''' 93 | self.bert.to("cpu") 94 | self.bert.save_pretrained(file_path) 95 | self.bert.to(self.device) 96 | print("Step:%d Model Saved on:" % self.step, file_path) 97 | 98 | def _setup_model(self): 99 | ''' 100 | Load the model to the designated device (CPU or GPU) and create an optimiser 101 | 102 | ''' 103 | self.model = self.bert.to(self.device) 104 | print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()])) 105 | 106 | # Distributed GPU training if CUDA can detect more than 1 GPU 107 | if self._config.with_cuda and torch.cuda.device_count() > 1: 108 | print("Using %d GPUs for BERT" % torch.cuda.device_count()) 109 | self.model = nn.DataParallel(self.model) 110 | 111 | if not self._config.eval: 112 | # Setting the AdamW optimizer with hyper-param 113 | self.optim = AdamW(self.model.parameters(), 114 | lr=self._config.lr, betas=self._config.beta, eps=self._config.eps, weight_decay=self._config.weight_decay) 115 | 116 | def train(self, steps: int = 0, verbose: int = 1): 117 | ''' 118 | Train MethylBERT over given steps 119 | 120 | steps: int 121 | number of steps to train the model 122 | ''' 123 | return self._iteration(steps, self.train_data, verbose) 124 | 125 | def test(self, test_dataloader: DataLoader): 126 | ''' 127 | Test/Evaluation of MethylBERT model with given data 128 | 129 | test_dataloader: DataLoader 130 | Data loader for test data 131 | ''' 132 | pass 133 | 134 | def load(self, file_path: str): 135 | ''' 136 | Restore the BERT model store in the given path 137 | ''' 138 | print(f"Restore the pretrained model from {file_path}") 139 | self.bert = BertForMaskedLM.from_pretrained(file_path, 140 | num_labels=self.train_data.dataset.num_dmrs(), 141 | output_attentions=True, 142 | output_hidden_states=True, 143 | hidden_dropout_prob=0.01, 144 | vocab_size = len(self.train_data.dataset.vocab)) 145 | 146 | # Initialize the model 147 | self._setup_model() 148 | 149 | def create_model(self, config_file=None): 150 | """ 151 | Create a new BERT MLM model from the configuration 152 | :param config_file: path to the configuration file 153 | """ 154 | pass 155 | 156 | def _acc(self, pred, label): 157 | """ 158 | Calculate accuracy between the predicted and the ground-truth values 159 | 160 | :param pred: predicted values 161 | :param label: ground-truth values 162 | """ 163 | 164 | if type(pred).__module__ != np.__name__: 165 | pred = pred.numpy() 166 | if type(label).__module__ != np.__name__: 167 | label = label.numpy() 168 | 169 | if len(pred.shape) > 1: 170 | pred = pred.flatten() 171 | if len(label.shape) > 1: 172 | label = label.flatten() 173 | 174 | return accuracy_score(y_true=label, y_pred=pred) 175 | 176 | 177 | class MethylBertPretrainTrainer(MethylBertTrainer): 178 | 179 | def __init__(self, *args, **kwargs): 180 | 
super().__init__(*args, **kwargs) 181 | pass 182 | 183 | def create_model(self, *args, **kwargs): 184 | config = BertConfig(vocab_size = len(self.train_data.dataset.vocab), *args, **kwargs) 185 | self.bert = BertForMaskedLM(config) 186 | self._setup_model() 187 | 188 | def _eval_iteration(self, data_loader): 189 | """ 190 | loop over the data_loader for evaluation 191 | 192 | :param data_loader: torch.utils.data.DataLoader for test 193 | :return: DataFrame, 194 | """ 195 | 196 | predict_res = {"prediction": [], "input": [], "label": [], "mask": []} 197 | 198 | mean_loss = 0 199 | self.model.eval() 200 | 201 | for i, batch in enumerate(data_loader): 202 | 203 | data = {key: value.to(self.device) for key, value in batch.items()} 204 | 205 | with torch.no_grad(): 206 | with torch.autocast(device_type="cuda" if self._config.with_cuda else "cpu", 207 | enabled=self._config.amp): 208 | mask_lm_output = self.model.forward(input_ids = data["input"], 209 | masked_lm_labels = data["label"]) 210 | 211 | mean_loss += mask_lm_output[0].mean().item()/len(data_loader) 212 | predict_res["prediction"].append(np.argmax(mask_lm_output[1].cpu().detach(), axis=-1)) 213 | predict_res["input"].append(data["input"].cpu().detach()) 214 | predict_res["label"].append(data["label"].cpu().detach()) 215 | predict_res["mask"].append(data["mask"].cpu().detach()) 216 | 217 | if self._config.eval: 218 | print("Batch %d/%d is done...."%(i, len(data_loader))) 219 | 220 | del mask_lm_output 221 | del data 222 | 223 | # Integrate all results 224 | predict_res["prediction"] = np.concatenate(predict_res["prediction"], axis=0) 225 | predict_res["input"] = np.concatenate(predict_res["input"], axis=0) 226 | predict_res["label"] = np.concatenate(predict_res["label"], axis=0) 227 | predict_res["mask"] = np.concatenate(predict_res["mask"], axis=0) 228 | 229 | self.model.train() 230 | return predict_res, np.mean(mean_loss) 231 | 232 | 233 | def _iteration(self, steps, data_loader, verbose): 234 | """ 235 | loop over the data_loader for training or testing 236 | if on train status, backward operation is activated 237 | and also auto save the model every epoch 238 | 239 | :param steps: total steps to train 240 | :param data_loader: torch.utils.data.DataLoader for training 241 | :param warm_up: number of steps for warming up the learning rate 242 | :return: None 243 | """ 244 | predict_res = {"prediction": [], "input": [], "label": []} 245 | self.step = 0 246 | 247 | if os.path.exists(self.f_train): 248 | os.remove(self.f_train) 249 | 250 | with open(self.f_train, "a") as f_perform: 251 | f_perform.write("step\tloss\tacc\tlr\n") 252 | 253 | if os.path.exists(self.f_eval): 254 | os.remove(self.f_eval) 255 | 256 | with open(self.f_eval, "a") as f_perform: 257 | f_perform.write("step\ttest_acc\ttest_loss\n") 258 | 259 | 260 | # Set up a learning rate scheduler 261 | self.scheduler = learning_rate_scheduler(self.optim, 262 | num_warmup_steps=self._config.warmup_step, 263 | num_training_steps=steps, 264 | decrease_steps=self._config.decrease_steps) 265 | 266 | # Set up configuration for train iteration 267 | global_step_loss = 0 268 | local_step = 0 269 | 270 | epochs = steps // (len(data_loader) // self._config.gradient_accumulation_steps) + 1 271 | self.model.zero_grad() 272 | self.model.train() 273 | train_prediction_res = {"prediction":[], "label":[]} 274 | 275 | scaler = GradScaler() if self._config.amp else None 276 | 277 | duration = 0 278 | for epoch in range(epochs): 279 | for i, batch in enumerate(data_loader): 280 | # 0. 
batch_data will be sent into the device(GPU or cpu) 281 | data = {key: value.to(self.device) for key, value in batch.items()} 282 | 283 | start = time.time() 284 | 285 | with torch.autocast(device_type="cuda" if self._config.with_cuda else "cpu", 286 | enabled=self._config.amp): 287 | mask_lm_output = self.model.forward(input_ids = data["bert_input"], 288 | masked_lm_labels = data["bert_label"]) 289 | 290 | loss = mask_lm_output[0] 291 | 292 | # Concatenate predicted sequences for the evaluation 293 | train_prediction_res["prediction"].append(np.argmax(mask_lm_output[1].cpu().detach(), axis=-1)) 294 | train_prediction_res["label"].append(data["bert_label"].cpu().detach()) 295 | 296 | # Calculate loss and back-propagation 297 | if "cuda" in self.device.type: 298 | loss = loss.mean() 299 | loss = loss/self._config.gradient_accumulation_steps 300 | scaler.scale(loss).backward() if self._config.amp else loss.backward() 301 | 302 | global_step_loss += loss.item() 303 | duration += time.time() - start 304 | 305 | # Gradient accumulation 306 | if (local_step+1) % self._config.gradient_accumulation_steps == 0: 307 | 308 | if self._config.amp: 309 | scaler.unscale_(self.optim) 310 | nn.utils.clip_grad_norm_(self.model.parameters(), self._config.max_grad_norm) 311 | scaler.step(self.optim) 312 | scaler.update() 313 | else: 314 | nn.utils.clip_grad_norm_(self.model.parameters(), self._config.max_grad_norm) 315 | self.optim.step() 316 | 317 | self.scheduler.step() 318 | self.model.zero_grad() 319 | 320 | # Evaluation with both train and testdata 321 | if self.test_data is not None and self.step % self._config.eval_freq == 0 and self.step > 0: 322 | 323 | test_pred, test_loss = self._eval_iteration(self.test_data) 324 | idces = np.where(test_pred["label"]>=0) 325 | test_pred_acc = self._acc(test_pred["prediction"][idces[0], idces[1]], 326 | test_pred["label"][idces[0], idces[1]]) 327 | 328 | with open(self.f_eval, "a") as f_perform: 329 | f_perform.write("\t".join([str(self.step), str(test_pred_acc), str(test_loss)]) +"\n") 330 | 331 | del test_pred 332 | 333 | if self.step % self._config.log_freq == 0: 334 | print("\nTrain Step %d iter - loss : %f / lr : %f"%(self.step, global_step_loss, self.optim.param_groups[0]["lr"])) 335 | print(f"Running time for iter = {duration}") 336 | 337 | if self.min_loss > global_step_loss: 338 | print("Step %d loss (%f) is lower than the current min loss (%f). 
Save the model at %s"%(self.step, global_step_loss, self.min_loss, self.save_path)) 339 | self.save(self.save_path) 340 | self.min_loss = global_step_loss 341 | 342 | # Save the step info (step, loss, lr, acc) 343 | with open(self.f_train, "a") as f_perform: 344 | 345 | train_prediction_res["prediction"] = np.concatenate(train_prediction_res["prediction"], axis=0) 346 | train_prediction_res["label"] = np.concatenate(train_prediction_res["label"], axis=0) 347 | 348 | idces = np.where(train_prediction_res["label"]>=0) 349 | train_pred_acc = self._acc(train_prediction_res["prediction"][idces[0], idces[1]], 350 | train_prediction_res["label"][idces[0], idces[1]]) 351 | 352 | f_perform.write("\t".join([str(self.step), str(global_step_loss), str(train_pred_acc), str(self.optim.param_groups[0]["lr"])])+"\n") 353 | 354 | self.step += 1 355 | 356 | duration=0 357 | global_step_loss = 0 358 | del train_prediction_res 359 | train_prediction_res = {"prediction":[], "label":[]} 360 | 361 | if steps == self.step: 362 | break 363 | local_step+=1 364 | 365 | if steps == self.step: 366 | break 367 | 368 | 369 | class MethylBertFinetuneTrainer(MethylBertTrainer): 370 | def __init__(self, *args, **kwargs): 371 | super().__init__(*args, **kwargs) 372 | 373 | def summary(self): 374 | ''' 375 | Print the summary of the MethylBERT model 376 | ''' 377 | 378 | print(self.model) 379 | 380 | def create_model(self, config_file: str = None): 381 | ''' 382 | Create a new MethylBERT model from the configuration 383 | ''' 384 | 385 | config = MethylBERTConfig.from_pretrained(config_file, 386 | num_labels=self.train_data.dataset.num_dmrs(), 387 | output_attentions=True, 388 | output_hidden_states=True, 389 | hidden_dropout_prob=0.01, 390 | vocab_size = len(self.train_data.dataset.vocab), 391 | loss=self._config.loss) 392 | 393 | self.bert = MethylBertEmbeddedDMR(config=config, 394 | seq_len=self.train_data.dataset.seq_len) 395 | 396 | # Initialize the BERT Language Model, with BERT model 397 | self._setup_model() 398 | 399 | def _eval_iteration(self, data_loader: DataLoader, return_logits: bool = False): 400 | """ 401 | loop over the data_loader for eval/test 402 | 403 | :param data_loader: torch.utils.data.DataLoader for test 404 | :return: DataFrame, 405 | """ 406 | 407 | predict_res = {"dmr_label":[], "pred_ctype_label":[], "ctype_label":[]} 408 | logits = list() 409 | 410 | mean_loss = 0 411 | self.model.eval() 412 | with torch.no_grad(): 413 | for i, batch in enumerate(data_loader): 414 | # 0. 
batch_data will be sent into the device(GPU or cpu) 415 | data = {key: value.to(self.device) for key, value in batch.items() if type(value) != list} 416 | 417 | with torch.autocast(device_type="cuda" if self._config.with_cuda else "cpu", enabled=self._config.amp): 418 | mask_lm_output = self.model.forward(step=self.step, 419 | input_ids = data["dna_seq"], 420 | token_type_ids=data["methyl_seq"], 421 | labels = data["dmr_label"], 422 | ctype_label=data["ctype_label"]) 423 | 424 | loss = mask_lm_output["loss"].mean().item() if "cuda" in self.device.type else mask_lm_output["loss"].item() 425 | mean_loss += loss/len(data_loader) 426 | 427 | if self._config.with_cuda and torch.cuda.device_count() > 1: 428 | torch.cuda.synchronize() 429 | 430 | predict_res["dmr_label"].append(data["dmr_label"].detach().cpu()) 431 | predict_res["pred_ctype_label"].append(torch.argmax(mask_lm_output["classification_logits"], dim=-1).detach().cpu()) 432 | predict_res["ctype_label"].append(data["ctype_label"].detach().cpu()) 433 | 434 | if return_logits: 435 | logits.append(mask_lm_output["classification_logits"].detach().cpu().numpy()) 436 | 437 | del mask_lm_output 438 | del data 439 | 440 | predict_res["dmr_label"] = np.concatenate(predict_res["dmr_label"], axis=0) 441 | predict_res["ctype_label"] = np.concatenate(predict_res["ctype_label"], axis=0) 442 | predict_res["pred_ctype_label"] = np.concatenate(predict_res["pred_ctype_label"], axis=0) 443 | 444 | self.model.train() 445 | 446 | if not return_logits: 447 | return predict_res, mean_loss 448 | else: 449 | return predict_res, mean_loss, np.concatenate(return_logits, axis=0) 450 | 451 | 452 | 453 | def _iteration(self, steps, data_loader, verbose = 1): 454 | """ 455 | loop over the data_loader for training or testing 456 | if on train status, backward operation is activated 457 | and also auto save the model every peoch 458 | 459 | :param steps: total steps to train 460 | :param data_loader: torch.utils.data.DataLoader for training 461 | :param warm_up: number of steps for warming up the learning rate 462 | :return: None 463 | """ 464 | 465 | self.step = 0 466 | 467 | if os.path.exists(self.f_train): 468 | os.remove(self.f_train) 469 | 470 | with open(self.f_train, "w") as f_perform: 471 | f_perform.write("step\tloss\tctype_acc\tlr\n") 472 | 473 | if os.path.exists(self.f_eval): 474 | os.remove(self.f_eval) 475 | 476 | with open(self.f_eval, "w") as f_perform: 477 | f_perform.write("step\tloss\tctype_acc\n") 478 | 479 | 480 | # Set up a learning rate scheduler 481 | self.scheduler = learning_rate_scheduler(self.optim, 482 | num_warmup_steps=self._config.warmup_step, 483 | num_training_steps=steps, 484 | decrease_steps=self._config.decrease_steps) 485 | global_step_loss = 0 486 | local_step = 0 487 | 488 | epochs = steps // (len(data_loader) // self._config.gradient_accumulation_steps) + 1 489 | 490 | self.model.zero_grad() 491 | if verbose > 0: 492 | print(self.model.training) 493 | self.model.train() 494 | train_prediction_res = {"dmr_label":[], "pred_ctype_label":[], "ctype_label":[]} 495 | 496 | scaler = GradScaler() if self._config.amp else None 497 | 498 | duration = 0 499 | epoch_progress_bar = tqdm(total=epochs, desc="Training...") 500 | for epoch in range(epochs): 501 | steps_progress_bar = tqdm(total=min(steps, len(data_loader)), 502 | desc=f"Epoch {epoch+1}/{epochs}") 503 | for i, batch in enumerate(data_loader): 504 | # 0. 
batch_data will be sent into the device(GPU or cpu) 505 | data = {key: value.to(self.device) for key, value in batch.items() if type(value) != list} 506 | 507 | start = time.time() 508 | with torch.autocast(device_type="cuda" if self._config.with_cuda else "cpu", 509 | enabled=self._config.amp): 510 | mask_lm_output = self.model.forward(step=self.step, 511 | input_ids=data["dna_seq"], 512 | token_type_ids=data["methyl_seq"], 513 | labels=data["dmr_label"], 514 | ctype_label=data["ctype_label"]) 515 | loss = mask_lm_output["loss"] 516 | 517 | # Concatenate predicted sequences for the evaluation 518 | train_prediction_res["dmr_label"].append(data["dmr_label"].detach().cpu()) 519 | 520 | 521 | # Cell-type classification 522 | train_prediction_res["pred_ctype_label"].append(np.argmax(mask_lm_output["classification_logits"].cpu().detach(), axis=-1)) 523 | train_prediction_res["ctype_label"].append(data["ctype_label"].detach().cpu()) 524 | 525 | 526 | # Calculate loss and back-propagation 527 | loss = mask_lm_output["loss"].mean() if "cuda" in self.device.type else mask_lm_output["loss"] 528 | loss = loss/self._config.gradient_accumulation_steps 529 | scaler.scale(loss).backward(retain_graph=True) if self._config.amp else loss.backward(retain_graph=True) 530 | 531 | loss_val = loss.item() 532 | global_step_loss += loss_val 533 | 534 | duration += time.time() - start 535 | # Gradient accumulation 536 | if (local_step+1) % self._config.gradient_accumulation_steps == 0: 537 | gradient_accum_start = time.time() 538 | if self._config.amp: 539 | scaler.unscale_(self.optim) 540 | nn.utils.clip_grad_norm_(self.model.parameters(), self._config.max_grad_norm) 541 | scaler.step(self.optim) 542 | scaler.update() 543 | else: 544 | nn.utils.clip_grad_norm_(self.model.parameters(), self._config.max_grad_norm) 545 | self.optim.step() 546 | 547 | self.scheduler.step() 548 | self.model.zero_grad() 549 | 550 | if (local_step+1) % self._config.eval_freq == 0 or local_step == 0: 551 | # Evaluation 552 | eval_pred, eval_loss = self._eval_iteration(self.test_data) 553 | eval_acc = self._acc(eval_pred["pred_ctype_label"], eval_pred["ctype_label"]) 554 | 555 | with open(self.f_eval, "a") as f_perform: 556 | f_perform.write("\t".join([str(self.step), str(eval_loss), str(eval_acc)]) +"\n") 557 | 558 | del eval_pred 559 | 560 | if self.step % self._config.log_freq == 0: 561 | if verbose > 0: 562 | print("\nTrain Step %d iter - loss : %f / lr : %f"%(self.step, global_step_loss, self.optim.param_groups[0]["lr"])) 563 | print(f"Running time for iter = {duration}") 564 | 565 | if self.min_loss > eval_loss: 566 | if verbose > 0: 567 | print("Step %d loss (%f) is lower than the current min loss (%f). 
Save the model at %s"%(self.step, eval_loss, self.min_loss, self.save_path)) 568 | self.save(self.save_path) 569 | self.min_loss = eval_loss 570 | 571 | # For saving an interim model to track the training 572 | if ( type(self._config.save_freq) == int ) and (self.step % self._config.save_freq == 0): 573 | step_save_dir=self.save_path.replace("bert.model", "bert.model_step%d"%(self.step)) 574 | if verbose > 0: 575 | print("Step %d: Save an interim model at %s"%(self.step, step_save_dir)) 576 | if not os.path.exists(step_save_dir): 577 | os.mkdir(step_save_dir) 578 | self.save(step_save_dir) 579 | 580 | # Save the step info (step, loss, lr, acc) 581 | with open(self.f_train, "a") as f_perform: 582 | 583 | train_prediction_res["dmr_label"] = np.concatenate(train_prediction_res["dmr_label"], axis=0) 584 | train_prediction_res["pred_ctype_label"] = np.concatenate(train_prediction_res["pred_ctype_label"], axis=0) 585 | train_prediction_res["ctype_label"] = np.concatenate(train_prediction_res["ctype_label"], axis=0) 586 | train_ctype_acc = self._acc(train_prediction_res["pred_ctype_label"], train_prediction_res["ctype_label"]) 587 | 588 | f_perform.write("\t".join([str(self.step), str(global_step_loss), str(train_ctype_acc), str(self.optim.param_groups[0]["lr"])])+"\n") 589 | 590 | steps_progress_bar.set_postfix(eval_loss=eval_loss) 591 | # Reset prediction result 592 | del train_prediction_res 593 | train_prediction_res = {"dmr_label":[], "pred_ctype_label":[], "ctype_label":[]} 594 | 595 | self.step += 1 596 | duration=0 597 | global_step_loss = 0 598 | 599 | steps_progress_bar.update() 600 | 601 | if steps == self.step: 602 | break 603 | local_step+=1 604 | 605 | steps_progress_bar.close() 606 | epoch_progress_bar.update() 607 | 608 | if steps == self.step: 609 | break 610 | 611 | def save(self, file_path: str="output/bert_trained.model"): 612 | ''' 613 | Save the MethylBERT model in the given path 614 | ''' 615 | self.bert.to("cpu") 616 | self.bert.save_pretrained(file_path) 617 | 618 | if hasattr(self.bert, "read_classifier"): 619 | torch.save(self.bert.read_classifier.state_dict(), os.path.dirname(file_path)+"/read_classification_model.pickle") 620 | 621 | if hasattr(self.bert, "dmr_encoder"): 622 | torch.save(self.bert.dmr_encoder.state_dict(), os.path.dirname(file_path)+"/dmr_encoder.pickle") 623 | 624 | self.bert.to(self.device) 625 | print("Step:%d Model Saved on:" % self.step, file_path) 626 | 627 | def load(self, dir_path: str, n_dmrs: int=None, load_fine_tune: bool=False): 628 | ''' 629 | Load pre-trained / fine-tuned MethylBERT model 630 | dir_path: str 631 | Directory to the saved bert model. It must contain "config.json" and "pytorch_model.bin" files 632 | n_dmrs: int (default: None) 633 | Number of DMRs to reconstruct the MethylBERT model. If the number is not given, the trainer auto-calculates the number from the same data 634 | load_fine_tune: bool (default: False) 635 | Whether the loaded model is a fine-tuned model including num_dmrs or a pre-trained model without num_dmrs 636 | ''' 637 | print(f"Restore the pretrained model {dir_path}") 638 | 639 | if load_fine_tune: 640 | ''' 641 | if n_dmrs is not None: 642 | raise ValueError("You cannot give a new number of DMRs for loading a fine-tuned model. The model should contains one. 
Please set either n_dmrs=None or load_fine_tune=False") 643 | ''' 644 | 645 | self.bert = MethylBertEmbeddedDMR.from_pretrained(dir_path, 646 | output_attentions=True, 647 | output_hidden_states=True, 648 | seq_len = self.train_data.dataset.seq_len, 649 | loss=self._config.loss, 650 | num_labels=n_dmrs 651 | ) 652 | 653 | try: 654 | self.bert.from_pretrained_dmr_encoder(os.path.dirname(dir_path)+"/dmr_encoder.pickle", self.device) 655 | print("Restore DMR encoder from %s"%(os.path.dirname(dir_path)+"/dmr_encoder.pickle")) 656 | except FileNotFoundError: 657 | print(os.path.dirname(dir_path)+"/dmr_encoder.pickle is not found.") 658 | 659 | try: 660 | self.bert.from_pretrained_read_classifier(os.path.dirname(dir_path)+"/read_classification_model.pickle", self.device) 661 | print("Restore read classification FCN model from %s"%(os.path.dirname(dir_path)+"/read_classification_model.pickle")) 662 | except FileNotFoundError: 663 | print(os.path.dirname(dir_path)+"/read_classification_model.pickle is not found.") 664 | else: 665 | self.bert = MethylBertEmbeddedDMR.from_pretrained(dir_path, 666 | num_labels=self.train_data.dataset.num_dmrs() if not n_dmrs else n_dmrs, 667 | output_attentions=True, 668 | output_hidden_states=True, 669 | seq_len = self.train_data.dataset.seq_len, 670 | loss=self._config.loss 671 | ) 672 | 673 | self._setup_model() 674 | 675 | def read_classification(self, data_loader: DataLoader = None, tokenizer: MethylVocab = None, logit: bool = False): 676 | ''' 677 | Classify sequencing reads into cell types 678 | 679 | data_loader: torch.utils.data.DataLoader 680 | DataLoader containing reads to classify. If nothing is given, the trainer tries to assign 'test_data' 681 | tokenizer: methylbert.data.vocab.MethylVocab 682 | Tokenizer used to convert the tokenised DNA sequences in the result back into DNA strings 683 | logit: bool (default: False) 684 | Whether to return the calculated classification logits together with the classification results 685 | ''' 686 | 687 | if data_loader is None: 688 | if self.test_data is None: 689 | raise ValueError("There is no test_data assigned to the trainer. Please give a DataLoader as an input.") 690 | else: 691 | data_loader = self.test_data 692 | 693 | # classification 694 | res = dict() 695 | logits = list() 696 | self.model.eval() 697 | 698 | pbar = tqdm(total=len(data_loader)) 699 | for i, batch in enumerate(data_loader): 700 | 701 | # 0. 
batch_data will be sent into the device(GPU or cpu) 702 | data = dict() 703 | 704 | for k, v in batch.items(): 705 | if type(v) != list: 706 | data[k] = v.to(self.device) 707 | if k not in res.keys(): 708 | res[k] = v.numpy() if type(v) == torch.Tensor else v 709 | else: 710 | res[k] = np.concatenate([res[k], v.numpy() if type(v) == torch.Tensor else v], axis=0) 711 | 712 | with torch.no_grad(): 713 | with torch.autocast(device_type="cuda" if self._config.with_cuda else "cpu", 714 | enabled=self._config.amp): 715 | mask_lm_output = self.model.forward(step=0, 716 | input_ids = data["dna_seq"], 717 | token_type_ids=data["methyl_seq"], 718 | labels = data["dmr_label"], 719 | ctype_label=data["ctype_label"]) 720 | 721 | if "pred" in res.keys(): 722 | res["pred"] = np.concatenate([res["pred"], np.argmax(mask_lm_output["classification_logits"].cpu().detach(), axis=-1)], axis=0) 723 | else: 724 | res["pred"] = np.argmax(mask_lm_output["classification_logits"].cpu().detach(), axis=-1) 725 | 726 | if logit: 727 | logits.append(mask_lm_output["classification_logits"].cpu().detach().numpy()) 728 | 729 | del mask_lm_output 730 | del data 731 | 732 | pbar.update(1) 733 | pbar.close() 734 | 735 | if logit: 736 | logits = np.concatenate(logits, axis=0) 737 | res["dna_seq"]=[get_dna_seq(s, tokenizer) for s in res["dna_seq"]] 738 | res["methyl_seq"]=["".join([str(mm) for mm in m]) for m in res["methyl_seq"]] 739 | 740 | res = pd.DataFrame(res) 741 | 742 | return (res, logits) if logit else res 743 | -------------------------------------------------------------------------------- /src/methylbert/utils.py: -------------------------------------------------------------------------------- 1 | import warnings, random 2 | import numpy as np 3 | import torch 4 | 5 | def get_dna_seq(tokens, tokenizer): 6 | # Convert n-mers tokens into a DNA sequence 7 | seq = tokenizer.from_seq(tokens) 8 | seq = [s for s in seq if "<" not in s] 9 | 10 | seq = seq[0][0] + "".join([s[1] for s in seq]) + seq[-1][-1] 11 | 12 | return seq 13 | 14 | def set_seed(seed: int): 15 | """ 16 | Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if 17 | installed). 18 | 19 | Args: 20 | seed (:obj:`int`): The seed to set. 21 | """ 22 | random.seed(seed) 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | 27 | 28 | def _moment(a, moment, axis, *, mean=None): 29 | if np.abs(moment - np.round(moment)) > 0: 30 | raise ValueError("All moment parameters must be integers") 31 | 32 | # moment of empty array is the same regardless of order 33 | if a.size == 0: 34 | return np.mean(a, axis=axis) 35 | 36 | dtype = a.dtype.type if a.dtype.kind in 'fc' else np.float64 37 | 38 | if moment == 0 or (moment == 1 and mean is None): 39 | # By definition the zeroth moment is always 1, and the first *central* 40 | # moment is 0. 
41 | shape = list(a.shape) 42 | del shape[axis] 43 | 44 | if len(shape) == 0: 45 | return dtype(1.0 if moment == 0 else 0.0) 46 | else: 47 | return (np.ones(shape, dtype=dtype) if moment == 0 48 | else np.zeros(shape, dtype=dtype)) 49 | else: 50 | # Exponentiation by squares: form exponent sequence 51 | n_list = [moment] 52 | current_n = moment 53 | while current_n > 2: 54 | if current_n % 2: 55 | current_n = (current_n - 1) / 2 56 | else: 57 | current_n /= 2 58 | n_list.append(current_n) 59 | 60 | # Starting point for exponentiation by squares 61 | mean = (a.mean(axis, keepdims=True) if mean is None 62 | else dtype(mean)) 63 | a_zero_mean = a - mean 64 | 65 | eps = np.finfo(a_zero_mean.dtype).resolution * 10 66 | with np.errstate(divide='ignore', invalid='ignore'): 67 | rel_diff = np.max(np.abs(a_zero_mean), axis=axis) / np.abs(mean) 68 | with np.errstate(invalid='ignore'): 69 | precision_loss = np.any(rel_diff < eps) 70 | n = a.shape[axis] if axis is not None else a.size 71 | if precision_loss and n > 1: 72 | message = ("Precision loss occurred in moment calculation due to " 73 | "catastrophic cancellation. This occurs when the data " 74 | "are nearly identical. Results may be unreliable.") 75 | warnings.warn(message, RuntimeWarning, stacklevel=4) 76 | 77 | if n_list[-1] == 1: 78 | s = a_zero_mean.copy() 79 | else: 80 | s = a_zero_mean**2 81 | 82 | # Perform multiplications 83 | for n in n_list[-2::-1]: 84 | s = s**2 85 | if n % 2: 86 | s *= a_zero_mean 87 | return np.mean(s, axis) -------------------------------------------------------------------------------- /test/data/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "gradient_checkpointing": false, 7 | "hidden_act": "gelu", 8 | "hidden_dropout_prob": 0.1, 9 | "hidden_size": 768, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 3072, 12 | "layer_norm_eps": 1e-12, 13 | "max_position_embeddings": 512, 14 | "model_type": "bert", 15 | "num_attention_heads": 12, 16 | "num_hidden_layers": 12, 17 | "pad_token_id": 0, 18 | "position_embedding_type": "absolute", 19 | "transformers_version": "4.6.0.dev0", 20 | "type_vocab_size": 2, 21 | "use_cache": true, 22 | "vocab_size": 30522 23 | } 24 | -------------------------------------------------------------------------------- /test/data/dmrs.csv: -------------------------------------------------------------------------------- 1 | chr start end length nCG meanMethy1 meanMethy2 diff.Methy areaStat abs_areaStat abs_diff.Methy ctype 2 | chr7 1268957 1277884.0 8928 753 0.793277869414822 0.129746654776829 0.6635312146379929 5722.0917896134315 5722.0917896134315 0.6635312146379929 T 3 | chr10 134597480 134602875.0 5396 670 0.861028679587746 0.140400099393177 0.7206285801945691 6144.08933128302 6144.08933128302 0.7206285801945691 T 4 | chr4 1395812 1402597.0 6786 663 0.831162403454465 0.18527153716544886 0.645890866289016 4941.4100890910795 4941.4100890910795 0.645890866289016 T 5 | chr5 2748155 2754016.0 5862 573 0.8148203766439749 0.17002488685453798 0.644795489789437 4189.2114605573 4189.2114605573 0.644795489789437 T 6 | chr16 54315654 54322597.0 6944 546 0.747227502020603 0.10176442809439196 0.645463073926211 4141.2109332293 4141.2109332293 0.645463073926211 T 7 | chr16 54962053 54967980.0 5928 546 0.783631162448667 0.0960950064378455 0.6875361560108221 4714.55179945095 4714.55179945095 0.6875361560108221 T 8 | chr5 1881585 1888224.0 6640 528 
0.7987680443168008 0.12564640421328802 0.673121640103512 4171.561771876471 4171.561771876471 0.673121640103512 T 9 | chr10 131765062 131771581.0 6520 524 0.801528125038318 0.113918236245095 0.687609888793223 4299.66145327898 4299.66145327898 0.687609888793223 T 10 | chr7 157481091 157490504.0 9414 512 0.757144730252902 0.113882929422453 0.6432618008304479 3901.620411746861 3901.620411746861 0.6432618008304479 T 11 | chr18 76736906 76741580.0 4675 510 0.8294750161348671 0.10440258352730697 0.7250724326075592 4684.60838145462 4684.60838145462 0.7250724326075592 T 12 | chr2 176943087 176950973.0 7887 495 0.74654568134662 0.129027940983311 0.617517740363309 3539.5383752478706 3539.5383752478706 0.617517740363309 T 13 | chr5 170735063 170741271.0 6209 492 0.751652265379491 0.11562214253027805 0.636030122849213 3526.901391965229 3526.901391965229 0.636030122849213 T 14 | chr9 139738643 139742599.0 3957 492 0.8808959302812751 0.2310969689783081 0.6497989613029671 3963.5884338135706 3963.5884338135706 0.6497989613029671 T 15 | chr2 171568240 171574738.0 6499 491 0.794196442179253 0.0979056415153933 0.696290800663859 4251.7326592334 4251.7326592334 0.696290800663859 T 16 | chr4 20253098 20257098.0 4001 314 0.684188806286741 0.09329578183853 0.590893024448211 1964.11178773372 1964.11178773372 0.590893024448211 T 17 | chr7 49812430 49815938.0 3509 313 0.825644026891989 0.0778718880265556 0.747772138865433 3131.28977483213 3131.28977483213 0.747772138865433 T 18 | chr9 969337 974269.0 4933 312 0.7340507947651159 0.132192993291835 0.601857801473281 1977.04725239638 1977.04725239638 0.601857801473281 T 19 | chr2 287142 290512.0 3371 312 0.786974615931658 0.0756112977105267 0.711363318221131 2814.2267772240198 2814.2267772240198 0.711363318221131 T 20 | chr13 95362948 95365922.0 2975 311 0.7809442759215401 0.0812341789033101 0.69971009701823 2698.29385912429 2698.29385912429 0.69971009701823 T 21 | chr7 156795124 156799496.0 4373 309 0.809464690245395 0.115485490870645 0.6939791993747491 2671.03304242754 2671.03304242754 0.6939791993747491 T 22 | -------------------------------------------------------------------------------- /test/data/processed/dmrs.csv: -------------------------------------------------------------------------------- 1 | chr start end length nCG meanMethy1 meanMethy2 diff.Methy areaStat abs_areaStat abs_diff.Methy ctype dmr_id 2 | chr10 134597480 134602875.0 5396 670 0.861028679587746 0.140400099393177 0.7206285801945691 6144.08933128302 6144.08933128302 0.7206285801945691 T 0 3 | chr7 1268957 1277884.0 8928 753 0.793277869414822 0.129746654776829 0.6635312146379929 5722.091789613432 5722.091789613432 0.6635312146379929 T 1 4 | chr4 1395812 1402597.0 6786 663 0.831162403454465 0.1852715371654488 0.645890866289016 4941.41008909108 4941.41008909108 0.645890866289016 T 2 5 | chr16 54962053 54967980.0 5928 546 0.783631162448667 0.0960950064378455 0.6875361560108221 4714.55179945095 4714.55179945095 0.6875361560108221 T 3 6 | chr18 76736906 76741580.0 4675 510 0.8294750161348671 0.1044025835273069 0.7250724326075592 4684.60838145462 4684.60838145462 0.7250724326075592 T 4 7 | chr10 131765062 131771581.0 6520 524 0.801528125038318 0.113918236245095 0.687609888793223 4299.66145327898 4299.66145327898 0.687609888793223 T 5 8 | chr2 171568240 171574738.0 6499 491 0.794196442179253 0.0979056415153933 0.696290800663859 4251.7326592334 4251.7326592334 0.696290800663859 T 6 9 | chr5 2748155 2754016.0 5862 573 0.8148203766439749 0.1700248868545379 0.644795489789437 4189.2114605573 4189.2114605573 
0.644795489789437 T 7 10 | chr5 1881585 1888224.0 6640 528 0.7987680443168008 0.125646404213288 0.673121640103512 4171.561771876471 4171.561771876471 0.673121640103512 T 8 11 | chr16 54315654 54322597.0 6944 546 0.747227502020603 0.1017644280943919 0.645463073926211 4141.2109332293 4141.2109332293 0.645463073926211 T 9 12 | -------------------------------------------------------------------------------- /test/test_deconvolute.py: -------------------------------------------------------------------------------- 1 | from methylbert.data.vocab import MethylVocab 2 | from methylbert.data.dataset import MethylBertFinetuneDataset 3 | from torch.utils.data import DataLoader 4 | from methylbert.deconvolute import deconvolute 5 | from methylbert.trainer import MethylBertFinetuneTrainer 6 | 7 | import pandas as pd 8 | import os 9 | 10 | def test_adjustment(trainer, tokenizer, data_loader, output_path, df_train): 11 | deconvolute(trainer = trainer, 12 | tokenizer = tokenizer, 13 | data_loader = data_loader, 14 | output_path = output_path, 15 | df_train = df_train, 16 | adjustment = True) 17 | 18 | assert pd.read_csv(os.path.join(output_path, "FI.csv"), sep="\t").shape[0] == len(df_train["dmr_label"].unique()) 19 | 20 | def test_multi_cell_type(trainer, tokenizer, data_loader, output_path, df_train): 21 | deconvolute(trainer = trainer, 22 | tokenizer = tokenizer, 23 | data_loader = data_loader, 24 | output_path = output_path, 25 | df_train = df_train, 26 | adjustment = False) 27 | 28 | assert pd.read_csv(os.path.join(output_path, "deconvolution.csv"), sep="\t").shape[0] == 3 29 | 30 | if __name__=="__main__": 31 | f_bulk = "data/processed/test_seq.csv" 32 | f_train = "data/processed/train_seq.csv" 33 | model_dir="res/" 34 | #model_dir = "tmp/bert.model/" 35 | out_dir = "res/deconvolution/" 36 | 37 | tokenizer = MethylVocab(k=3) 38 | 39 | dataset = MethylBertFinetuneDataset(f_bulk, 40 | tokenizer, 41 | seq_len=150) 42 | data_loader = DataLoader(dataset, batch_size=50, num_workers=20) 43 | df_train = pd.read_csv(f_train, sep="\t") 44 | 45 | trainer = MethylBertFinetuneTrainer(len(tokenizer), 46 | train_dataloader=data_loader, 47 | test_dataloader=data_loader, 48 | ) 49 | trainer.load(model_dir) 50 | 51 | test_adjustment(trainer, tokenizer, data_loader, out_dir, df_train) 52 | # multiple cell type 53 | model_dir = "data/multi_cell_type/res/bert.model/" 54 | f_bulk = "data/multi_cell_type/test_seq.csv" 55 | f_train = "data/multi_cell_type/train_seq.csv" 56 | dataset = MethylBertFinetuneDataset(f_bulk, 57 | tokenizer, 58 | seq_len=150) 59 | data_loader = DataLoader(dataset, batch_size=50, num_workers=20) 60 | df_train = pd.read_csv(f_train, sep="\t") 61 | 62 | trainer = MethylBertFinetuneTrainer(len(tokenizer), 63 | train_dataloader=data_loader, 64 | test_dataloader=data_loader, 65 | ) 66 | trainer.load(model_dir) 67 | test_multi_cell_type(trainer, tokenizer, data_loader, out_dir, df_train) 68 | 69 | print("Everything passed!") 70 | -------------------------------------------------------------------------------- /test/test_finetune.py: -------------------------------------------------------------------------------- 1 | from methylbert.data import finetune_data_generate as fdg 2 | from methylbert.data.dataset import MethylBertFinetuneDataset 3 | from methylbert.data.vocab import MethylVocab 4 | from methylbert.trainer import MethylBertFinetuneTrainer 5 | 6 | from torch.utils.data import DataLoader 7 | import pandas as pd 8 | import os, shutil, json 9 | 10 | def load_data(train_dataset: str, test_dataset: 
str, batch_size: int = 64, num_workers: int = 40): 11 | tokenizer=MethylVocab(k=3) 12 | 13 | # Load data sets 14 | train_dataset = MethylBertFinetuneDataset(train_dataset, tokenizer, seq_len=150) 15 | test_dataset = MethylBertFinetuneDataset(test_dataset, tokenizer, seq_len=150) 16 | 17 | if len(test_dataset) > 500: 18 | test_dataset.subset_data(500) 19 | 20 | # Create a data loader 21 | print("Creating Dataloader") 22 | local_step_batch_size = int(batch_size/4) 23 | print("Local step batch size : ", local_step_batch_size) 24 | 25 | train_data_loader = DataLoader(train_dataset, batch_size=local_step_batch_size, 26 | num_workers=num_workers, pin_memory=False, shuffle=True) 27 | 28 | test_data_loader = DataLoader(test_dataset, batch_size=local_step_batch_size, 29 | num_workers=num_workers, pin_memory=True, 30 | shuffle=False) if test_dataset is not None else None 31 | 32 | return tokenizer, train_data_loader, test_data_loader 33 | 34 | def test_finetune_no_pretrain(tokenizer : MethylVocab, 35 | save_path : str, 36 | train_data_loader : DataLoader, 37 | test_data_loader : DataLoader, 38 | pretrain_model : str, 39 | steps : int=10): 40 | 41 | trainer = MethylBertFinetuneTrainer(vocab_size = len(tokenizer), 42 | save_path=save_path, 43 | train_dataloader=train_data_loader, 44 | test_dataloader=test_data_loader, 45 | with_cuda=False) 46 | 47 | trainer.create_model(config_file=os.path.join(pretrain_model, "config.json")) 48 | 49 | trainer.train(steps) 50 | 51 | assert os.path.exists(os.path.join(save_path, "config.json")) 52 | assert os.path.exists(os.path.join(save_path, "dmr_encoder.pickle")) 53 | #assert os.path.exists(os.path.join(save_path, "pytorch_model.bin")) 54 | assert os.path.exists(os.path.join(save_path, "read_classification_model.pickle")) 55 | assert os.path.exists(os.path.join(save_path, "eval.csv")) 56 | assert os.path.exists(os.path.join(save_path, "train.csv")) 57 | assert steps == pd.read_csv(os.path.join(save_path, "train.csv")).shape[0] 58 | 59 | def test_finetune_no_pretrain_focal(tokenizer : MethylVocab, 60 | save_path : str, 61 | train_data_loader : DataLoader, 62 | test_data_loader : DataLoader, 63 | pretrain_model : str, 64 | steps : int=10): 65 | 66 | trainer = MethylBertFinetuneTrainer(vocab_size = len(tokenizer), 67 | save_path=save_path, 68 | train_dataloader=train_data_loader, 69 | test_dataloader=test_data_loader, 70 | with_cuda=False, 71 | loss="focal_bce") 72 | 73 | trainer.create_model(config_file=os.path.join(pretrain_model, "config.json")) 74 | 75 | trainer.train(steps) 76 | assert os.path.exists(os.path.exists(os.path.join(save_path, "config.json"))) 77 | 78 | with open(os.path.join(save_path, "config.json")) as fp: 79 | config = json.load(fp) 80 | assert config["loss"] == "focal_bce" 81 | 82 | def test_finetune_no_pretrain_focal(tokenizer : MethylVocab, 83 | save_path : str, 84 | train_data_loader : DataLoader, 85 | test_data_loader : DataLoader, 86 | pretrain_model : str, 87 | steps : int=10): 88 | 89 | trainer = MethylBertFinetuneTrainer(vocab_size = len(tokenizer), 90 | save_path=save_path, 91 | train_dataloader=train_data_loader, 92 | test_dataloader=test_data_loader, 93 | with_cuda=False, 94 | loss="focal_bce") 95 | 96 | trainer.create_model(config_file=os.path.join(pretrain_model, "config.json")) 97 | 98 | trainer.train(steps) 99 | assert os.path.exists(os.path.exists(os.path.join(save_path, "config.json"))) 100 | 101 | with open(os.path.join(save_path, "config.json")) as fp: 102 | config = json.load(fp) 103 | assert config["loss"] == "focal_bce" 104 
| 105 | 106 | def test_finetune_focal_multicelltype(tokenizer : MethylVocab, 107 | save_path : str, 108 | train_data_loader : DataLoader, 109 | test_data_loader : DataLoader, 110 | pretrain_model : str, 111 | steps : int=10): 112 | 113 | trainer = MethylBertFinetuneTrainer(vocab_size = len(tokenizer), 114 | save_path=save_path+"bert.model/", 115 | train_dataloader=train_data_loader, 116 | test_dataloader=test_data_loader, 117 | with_cuda=False, 118 | loss="focal_bce") 119 | trainer.load(pretrain_model) 120 | trainer.train(steps) 121 | 122 | assert os.path.exists(os.path.join(save_path, "bert.model/config.json")) 123 | assert os.path.exists(os.path.join(save_path, "bert.model/dmr_encoder.pickle")) 124 | #assert os.path.exists(os.path.join(save_path, "bert.model/pytorch_model.bin")) 125 | assert os.path.exists(os.path.join(save_path, "bert.model/read_classification_model.pickle")) 126 | 127 | def test_finetune(tokenizer : MethylVocab, 128 | save_path : str, 129 | train_data_loader : DataLoader, 130 | test_data_loader : DataLoader, 131 | pretrain_model : str, 132 | steps : int=10): 133 | 134 | trainer = MethylBertFinetuneTrainer(vocab_size = len(tokenizer), 135 | save_path=save_path+"bert.model/", 136 | train_dataloader=train_data_loader, 137 | test_dataloader=test_data_loader, 138 | with_cuda=False) 139 | trainer.load(pretrain_model) 140 | trainer.train(steps) 141 | 142 | assert os.path.exists(os.path.join(save_path, "bert.model/config.json")) 143 | assert os.path.exists(os.path.join(save_path, "bert.model/dmr_encoder.pickle")) 144 | #assert os.path.exists(os.path.join(save_path, "bert.model/pytorch_model.bin")) 145 | assert os.path.exists(os.path.join(save_path, "bert.model/read_classification_model.pickle")) 146 | 147 | def reset_dir(dirname): 148 | if os.path.exists(dirname): 149 | shutil.rmtree(dirname) 150 | os.mkdir(dirname) 151 | 152 | if __name__=="__main__": 153 | # For data processing 154 | f_bam_list = "data/bam_list.txt" 155 | f_dmr = "data/dmrs.csv" 156 | f_ref = "data/genome.fa" 157 | out_dir = "data/processed/" 158 | 159 | # Process data for fine-tuning 160 | fdg.finetune_data_generate( 161 | sc_dataset = f_bam_list, 162 | f_dmr = f_dmr, 163 | f_ref = f_ref, 164 | output_dir=out_dir, 165 | split_ratio = 0.7, 166 | n_cores=10, 167 | n_dmrs=10 168 | ) 169 | 170 | tokenizer, train_data_loader, test_data_loader = \ 171 | load_data(train_dataset = os.path.join(out_dir, "train_seq.csv"), 172 | test_dataset = os.path.join(out_dir, "test_seq.csv")) 173 | 174 | # For fine-tuning 175 | model_dir="data/pretrained_model/" 176 | save_path="res/" 177 | train_step=3 178 | 179 | # Test 180 | # amp warning issue 181 | # https://github.com/pytorch/pytorch/issues/67598 182 | 183 | reset_dir(save_path) 184 | test_finetune(tokenizer, save_path, train_data_loader, test_data_loader, "hanyangii/methylbert_hg19_2l", train_step) 185 | test_finetune(tokenizer, save_path, train_data_loader, test_data_loader, model_dir, train_step) 186 | 187 | reset_dir(save_path) 188 | test_finetune_no_pretrain(tokenizer, save_path, train_data_loader, test_data_loader, model_dir, train_step) 189 | 190 | reset_dir(save_path) 191 | test_finetune_no_pretrain_focal(tokenizer, save_path, train_data_loader, test_data_loader, model_dir, train_step) 192 | 193 | #TODO 194 | #reset_dir(save_path) 195 | #test_finetune_savefreq(tokenizer, save_path, train_data_loader, test_data_loader, model_dir, train_step, save_freq=1) 196 | 197 | # Multiple cell type 198 | out_dir="data/multi_cell_type/" 199 | tokenizer, train_data_loader, 
test_data_loader = \ 200 | load_data(train_dataset = os.path.join(out_dir, "train_seq.csv"), 201 | test_dataset = os.path.join(out_dir, "test_seq.csv")) 202 | 203 | # For fine-tuning 204 | model_dir="data/pretrained_model/" 205 | save_path="data/multi_cell_type/res/" 206 | 207 | reset_dir(save_path) 208 | test_finetune_focal_multicelltype(tokenizer, save_path, train_data_loader, test_data_loader, model_dir) 209 | 210 | print("Everything passed!") 211 | -------------------------------------------------------------------------------- /test/test_finetune_preprocess.py: -------------------------------------------------------------------------------- 1 | from methylbert.data import finetune_data_generate as fdg 2 | import pandas as pd 3 | import os 4 | 5 | def test_split_ratio(bam_file: str, f_dmr: str, f_ref: str, out_dir = "tmp/", split_ratio=0.5): 6 | fdg.finetune_data_generate( 7 | input_file = bam_file, 8 | f_dmr = f_dmr, 9 | f_ref = f_ref, 10 | output_dir=out_dir, 11 | split_ratio = split_ratio, 12 | n_cores=1 13 | ) 14 | 15 | assert os.path.exists(out_dir+"train_seq.csv") 16 | assert os.path.exists(out_dir+"test_seq.csv") 17 | assert os.path.exists(out_dir+"dmrs.csv") 18 | 19 | n_test_seqs = pd.read_csv(out_dir+"test_seq.csv", sep="\t").shape[0] 20 | n_train_seqs = pd.read_csv(out_dir+"train_seq.csv", sep="\t").shape[0] 21 | assert (n_train_seqs/(n_train_seqs+n_test_seqs) <= split_ratio + 0.05) and (n_train_seqs/(n_train_seqs+n_test_seqs) >= split_ratio - 0.05) 22 | 23 | print("test_split_ratio passed!") 24 | 25 | def test_multi_cores(bam_file: str, f_dmr: str, f_ref: str, out_dir = "tmp/", n_cores=4): 26 | fdg.finetune_data_generate( 27 | input_file = bam_file, 28 | f_dmr = f_dmr, 29 | f_ref = f_ref, 30 | output_dir=out_dir, 31 | split_ratio = 1.0, 32 | n_cores=n_cores 33 | ) 34 | 35 | assert os.path.exists(out_dir+"data.csv") 36 | assert os.path.exists(out_dir+"dmrs.csv") 37 | 38 | print("test_multi_cores passed!") 39 | 40 | def test_dmr_subset(bam_file: str, f_dmr: str, f_ref: str, out_dir = "tmp/", n_dmrs=10): 41 | fdg.finetune_data_generate( 42 | input_file = bam_file, 43 | f_dmr = f_dmr, 44 | f_ref = f_ref, 45 | output_dir=out_dir, 46 | split_ratio = 1.0, 47 | n_cores=1, 48 | n_dmrs=n_dmrs 49 | ) 50 | 51 | assert pd.read_csv(out_dir+"dmrs.csv", sep="\t").shape[0] == n_dmrs 52 | 53 | print("test_dmr_subset passed!") 54 | 55 | def test_list_bam_file(f_bam_file_list: str, f_dmr: str, f_ref: str, out_dir = "tmp/"): 56 | fdg.finetune_data_generate( 57 | sc_dataset = f_bam_file_list, 58 | f_dmr = f_dmr, 59 | f_ref = f_ref, 60 | output_dir=out_dir, 61 | split_ratio = 1.0, 62 | n_cores=1 63 | ) 64 | 65 | assert os.path.exists(out_dir+"data.csv") 66 | assert os.path.exists(out_dir+"dmrs.csv") 67 | 68 | res = pd.read_csv(out_dir+"data.csv", sep="\t") 69 | assert "T" in res["ctype"].tolist() 70 | assert "N" in res["ctype"].tolist() 71 | 72 | print("test_list_bam_file passed!") 73 | 74 | def test_single_bam_file(bam_file: str, f_dmr: str, f_ref: str, out_dir = "tmp/"): 75 | fdg.finetune_data_generate( 76 | input_file = bam_file, 77 | f_dmr = f_dmr, 78 | f_ref = f_ref, 79 | output_dir=out_dir, 80 | split_ratio = 1.0, 81 | n_cores=1 82 | ) 83 | 84 | assert os.path.exists(out_dir+"data.csv") 85 | assert os.path.exists(out_dir+"dmrs.csv") 86 | 87 | print("test_single_bam_file passed!") 88 | 89 | def test_dorado_aligned_file(bam_file: str, f_dmr: str, f_ref: str, out_dir = "tmp/"): 90 | fdg.finetune_data_generate( 91 | input_file = bam_file, 92 | f_dmr = f_dmr, 93 | f_ref = f_ref, 94 | 
output_dir=out_dir, 95 | split_ratio = 1.0, 96 | n_cores=1, 97 | methyl_caller="dorado" 98 | ) 99 | 100 | assert os.path.exists(out_dir+"data.csv") 101 | assert os.path.exists(out_dir+"dmrs.csv") 102 | 103 | print("test_dorado_aligned_file passed!") 104 | 105 | 106 | def test_multi_cell_type(f_bam_file_list: str, f_dmr: str, f_ref: str, out_dir = "tmp/"): 107 | fdg.finetune_data_generate( 108 | sc_dataset = f_bam_file_list, 109 | f_dmr = f_dmr, 110 | f_ref = f_ref, 111 | output_dir=out_dir, 112 | split_ratio = 0.8, 113 | n_cores=1 114 | ) 115 | 116 | assert os.path.exists(out_dir+"train_seq.csv") 117 | assert os.path.exists(out_dir+"test_seq.csv") 118 | assert os.path.exists(out_dir+"dmrs.csv") 119 | 120 | res = pd.read_csv(out_dir+"train_seq.csv", sep="\t") 121 | assert "T" in res["ctype"].tolist() 122 | assert "N" in res["ctype"].tolist() 123 | assert "P" in res["ctype"].tolist() 124 | 125 | print("test_multi_cell_type passed!") 126 | 127 | def test_ncbi_genome_style(bam_file: str, f_dmr: str, f_ref: str, out_dir = "tmp/"): 128 | 129 | new_dmr_file = "data/ncbi_dmr.csv" 130 | df_dmr = pd.read_csv(f_dmr, sep="\t", index_col=False) 131 | df_dmr.loc[:, "chr"] = df_dmr["chr"].apply(lambda x : x.split("hr")[1]) 132 | df_dmr.to_csv(new_dmr_file, sep="\t") 133 | 134 | fdg.finetune_data_generate( 135 | input_file = bam_file, 136 | f_dmr = new_dmr_file, 137 | f_ref = f_ref, 138 | output_dir=out_dir, 139 | split_ratio = 1.0, 140 | n_cores=1 141 | ) 142 | 143 | assert os.path.exists(out_dir+"data.csv") 144 | assert os.path.exists(out_dir+"dmrs.csv") 145 | 146 | print("test_single_bam_file passed!") 147 | 148 | 149 | if __name__=="__main__": 150 | f_bam = "data/T_sample.bam" 151 | f_bam_list = "data/bam_list.txt" 152 | f_dmr = "data/dmrs.csv" 153 | f_ref = "data/genome.fa" 154 | 155 | test_single_bam_file(bam_file = f_bam, f_dmr=f_dmr, f_ref=f_ref) 156 | test_list_bam_file(f_bam_file_list = f_bam_list, f_dmr=f_dmr, f_ref=f_ref) 157 | test_dmr_subset(bam_file = f_bam, f_dmr=f_dmr, f_ref=f_ref, n_dmrs=10) 158 | test_multi_cores(bam_file = f_bam, f_dmr=f_dmr, f_ref=f_ref, n_cores=4) 159 | test_split_ratio(bam_file = f_bam, f_dmr=f_dmr, f_ref=f_ref, split_ratio=0.7) 160 | 161 | 162 | f_bam_list = "data/multi_cell_type/bam_list.txt" 163 | f_dmr = "data/multi_cell_type/dmrs.csv" 164 | out_dir = "data/multi_cell_type/" 165 | test_multi_cell_type(f_bam_file_list = f_bam_list, f_dmr=f_dmr, f_ref=f_ref, out_dir=out_dir) 166 | 167 | f_dorado = "data/dorado_aligned.bam" 168 | f_ref_hg38="data/hg38_genome.fa" 169 | f_bam = "data/ncbi_genome_sample.bam" 170 | test_ncbi_genome_style(bam_file = f_bam, f_dmr=f_dmr, f_ref=f_ref_hg38) 171 | test_dorado_aligned_file(bam_file = f_dorado, f_dmr=f_dmr, f_ref=f_ref_hg38) 172 | 173 | print("Everything passed!") 174 | -------------------------------------------------------------------------------- /tutorials/01_Data_Preparation.md: -------------------------------------------------------------------------------- 1 | # Data Preparation for your own BAM/SAM file to run _MethylBERT_ 2 | 3 | ## Input requirements 4 | 5 | In order to run _MethylBERT_, these files are required: 6 | 1. Input bulk sample as a BAM/SAM file 7 | 2. Reference genome as a FASTA file 8 | 3. DMRs as a tab-separated .csv file 9 | 4. (Optional, in case you want to fine-tune the MethylBERT model with your data) Pure tumour and normal samples as BAM/SAM files 10 | 11 | #### 1. 
BAM/SAM File format 12 | _MethylBERT_ currently supports only [bismark](https://www.bioinformatics.babraham.ac.uk/projects/bismark/)-aligned samples where read-level methylation calls are given with `XM` tag. `XM` tage stores methylation calls as follows: 13 | - `x` : Unmethylated cytosine at CHH 14 | - `X` : Methylated cytosine at CHH 15 | - `h` : Unmethylated cytosine at CHG context 16 | - `H` : Methylated cytosine at CHG context 17 | - `z` : Unmethylated cytosine at CpG context 18 | - `Z` : Methylated cytosine at CpG context 19 | 20 | Each sequence read has its methylation call with `XM` tag like: 21 | ``` 22 | SRR5390326.sra.2060072_2060072_length=150 16 chr1 3000485 42 118M * 0 0 23 | AATTTCAACTCTAAATTTAATTATTTCCTACTATCTACTCATCTTAAATAAATTTACTTCCTTTTATTCTAAAACTTCTAAATTTACTATCAAACTACTAATATATACTCTAATTTCC 24 | JA-FFJJJFJJJJJJJJJJJJJJFJJJJJJJJJFJJJJFJJFJJJJJJJJJJJJJJJJFJJJJJJJJJJJJJJJFJJFJFJFJJJFJJJJJJJFJJAJJ\n", 186 | "\n", 199 | "\n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | "
nameflagref_nameref_posmap_qualitycigarnext_ref_namenext_ref_poslengthseq...PGXGNMXMXRdna_seqmethyl_seqdmr_ctypedmr_labelctype
0SRR10165464.6790597_6790597_length=15183chr217694354140151M=176943475-217AATTAACAATTTTCATCATAATCTACACATTATTAACATCAAACTT......MarkDuplicatesGA37h...hh........z.........x..........h.............CTGAT ATT TTG TGG GGC GCA CAA AAT ATT TTT TTT TT...2222222222220222222222222222222222222222222222...T12N
1SRR10165994.18752987_18752987_length=149163chr715748661640149M=157486650183AGGCACGCGACCACCCTAAACCTCGAACAAAACTAAAAAAACGCAA......MarkDuplicatesGA51..Z...Z.Zx.......xhh....Zx...xhh...hhhhh..Z..x...GACCG CGC GCA CAC ACG CGC GCG CGG GGC GCC CCA CA...1222121222222222222222122222222222222222122222...T11T
2SRR10165994.2935274_2935274_length=15083chr7127022242150M=1269981-391ACGAACATTAAAACGCACGGAACCGCCGCGACGCGGACTCGCTCTT......MarkDuplicatesGA27h.Z.h....hhh..Z...ZX.h..Z..Z.Zx.Z.ZX....Z........CTGCG CGA GAG AGC GCA CAT ATT TTG TGG GGG GGA GA...1222222222221222122222122121221212222212222222...T1T
3SRR10165464.56090327_56090327_length=151163chr217694951142149M=176949602242AGGATTTCTTACTACATAACCACAAAAATACATTAAACCCACACCT......MarkDuplicatesGA36h.Z.......h....z.hh..z.zx.hh.h....hhh...z.z......GAGCG CGC GCT CTT TTT TTC TCT CTT TTG TGC GCT CT...1222222222222022222020222222222222222202022222...T12N
4SRR10165464.47924911_47924911_length=150147chr7127248042151M=1272378-253AATTATTGGGAGTTTGATGTTGATAAGTAAAGTGTTGGAGTGTGGG......MarkDuplicatesCT31......z.....h...................z.xz......z......GAAAT ATT TTA TAT ATC TCG CGG GGG GGA GAG AGC GC...2222202222222222222222222222222022022222202220...T1N
\n", 349 | "

5 rows × 22 columns

\n", 350 | "" 351 | ], 352 | "text/plain": [ 353 | " name flag ref_name ref_pos \\\n", 354 | "0 SRR10165464.6790597_6790597_length=151 83 chr2 176943541 \n", 355 | "1 SRR10165994.18752987_18752987_length=149 163 chr7 157486616 \n", 356 | "2 SRR10165994.2935274_2935274_length=150 83 chr7 1270222 \n", 357 | "3 SRR10165464.56090327_56090327_length=151 163 chr2 176949511 \n", 358 | "4 SRR10165464.47924911_47924911_length=150 147 chr7 1272480 \n", 359 | "\n", 360 | " map_quality cigar next_ref_name next_ref_pos length \\\n", 361 | "0 40 151M = 176943475 -217 \n", 362 | "1 40 149M = 157486650 183 \n", 363 | "2 42 150M = 1269981 -391 \n", 364 | "3 42 149M = 176949602 242 \n", 365 | "4 42 151M = 1272378 -253 \n", 366 | "\n", 367 | " seq ... PG XG \\\n", 368 | "0 AATTAACAATTTTCATCATAATCTACACATTATTAACATCAAACTT... ... MarkDuplicates GA \n", 369 | "1 AGGCACGCGACCACCCTAAACCTCGAACAAAACTAAAAAAACGCAA... ... MarkDuplicates GA \n", 370 | "2 ACGAACATTAAAACGCACGGAACCGCCGCGACGCGGACTCGCTCTT... ... MarkDuplicates GA \n", 371 | "3 AGGATTTCTTACTACATAACCACAAAAATACATTAAACCCACACCT... ... MarkDuplicates GA \n", 372 | "4 AATTATTGGGAGTTTGATGTTGATAAGTAAAGTGTTGGAGTGTGGG... ... MarkDuplicates CT \n", 373 | "\n", 374 | " NM XM XR \\\n", 375 | "0 37 h...hh........z.........x..........h............. CT \n", 376 | "1 51 ..Z...Z.Zx.......xhh....Zx...xhh...hhhhh..Z..x... GA \n", 377 | "2 27 h.Z.h....hhh..Z...ZX.h..Z..Z.Zx.Z.ZX....Z........ CT \n", 378 | "3 36 h.Z.......h....z.hh..z.zx.hh.h....hhh...z.z...... GA \n", 379 | "4 31 ......z.....h...................z.xz......z...... GA \n", 380 | "\n", 381 | " dna_seq \\\n", 382 | "0 GAT ATT TTG TGG GGC GCA CAA AAT ATT TTT TTT TT... \n", 383 | "1 CCG CGC GCA CAC ACG CGC GCG CGG GGC GCC CCA CA... \n", 384 | "2 GCG CGA GAG AGC GCA CAT ATT TTG TGG GGG GGA GA... \n", 385 | "3 GCG CGC GCT CTT TTT TTC TCT CTT TTG TGC GCT CT... \n", 386 | "4 AAT ATT TTA TAT ATC TCG CGG GGG GGA GAG AGC GC... \n", 387 | "\n", 388 | " methyl_seq dmr_ctype dmr_label ctype \n", 389 | "0 2222222222220222222222222222222222222222222222... T 12 N \n", 390 | "1 1222121222222222222222122222222222222222122222... T 11 T \n", 391 | "2 1222222222221222122222122121221212222212222222... T 1 T \n", 392 | "3 1222222222222022222020222222222222222202022222... T 12 N \n", 393 | "4 2222202222222222222222222222222022022222202220... 
T 1 N \n", 394 | "\n", 395 | "[5 rows x 22 columns]" 396 | ] 397 | }, 398 | "execution_count": 11, 399 | "metadata": {}, 400 | "output_type": "execute_result" 401 | } 402 | ], 403 | "source": [ 404 | "import pandas as pd\n", 405 | "pd.read_csv(\"tmp/test_seq.csv\", sep='\\t').head()" 406 | ] 407 | } 408 | ], 409 | "metadata": { 410 | "kernelspec": { 411 | "display_name": "dnabert", 412 | "language": "python", 413 | "name": "dnabert" 414 | }, 415 | "language_info": { 416 | "codemirror_mode": { 417 | "name": "ipython", 418 | "version": 3 419 | }, 420 | "file_extension": ".py", 421 | "mimetype": "text/x-python", 422 | "name": "python", 423 | "nbconvert_exporter": "python", 424 | "pygments_lexer": "ipython3", 425 | "version": "3.6.13" 426 | } 427 | }, 428 | "nbformat": 4, 429 | "nbformat_minor": 5 430 | } 431 | -------------------------------------------------------------------------------- /tutorials/03_Preprocessing_bulk_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "715b70e8-5218-4542-bb99-33d72ecb44d6", 6 | "metadata": {}, 7 | "source": [ 8 | "# Preprocessing for bulk data \n", 9 | "\n", 10 | "The bulk sample you want to deconvolute using _MethylBERT_ also needs to be preprocessed using `finetune_data_generate` function. " 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "e1168bb7-e72e-43ba-b858-23d6dc05abf0", 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 23 | "DMRs sorted by areaStat\n", 24 | " chr start end length nCG meanMethy1 meanMethy2 \\\n", 25 | "1 chr10 134597480 134602875.0 5396 670 0.861029 0.140400 \n", 26 | "0 chr7 1268957 1277884.0 8928 753 0.793278 0.129747 \n", 27 | "2 chr4 1395812 1402597.0 6786 663 0.831162 0.185272 \n", 28 | "5 chr16 54962053 54967980.0 5928 546 0.783631 0.096095 \n", 29 | "9 chr18 76736906 76741580.0 4675 510 0.829475 0.104403 \n", 30 | "\n", 31 | " diff.Methy areaStat abs_areaStat abs_diff.Methy ctype dmr_id \n", 32 | "1 0.720629 6144.089331 6144.089331 0.720629 T 0 \n", 33 | "0 0.663531 5722.091790 5722.091790 0.663531 T 1 \n", 34 | "2 0.645891 4941.410089 4941.410089 0.645891 T 2 \n", 35 | "5 0.687536 4714.551799 4714.551799 0.687536 T 3 \n", 36 | "9 0.725072 4684.608381 4684.608381 0.725072 T 4 \n", 37 | "Number of DMRs to extract sequence reads: 20\n", 38 | "Fine-tuning data generated: name flag ref_name ref_pos \\\n", 39 | "0 SRR10166000.9089788_9089788_length=151 147 chr10 131767360 \n", 40 | "1 SRR10165998.65829390_65829390_length=150 163 chr4 20254248 \n", 41 | "2 SRR10165467.85837758_85837758_length=151 99 chr4 1401206 \n", 42 | "3 SRR10165995.16747267_16747267_length=149 83 chr2 176945656 \n", 43 | "4 SRR10165995.46034072_46034072_length=151 99 chr4 20253524 \n", 44 | "\n", 45 | " map_quality cigar next_ref_name next_ref_pos length \\\n", 46 | "0 42 151M = 131767187 -324 \n", 47 | "1 23 151M = 20254343 244 \n", 48 | "2 40 151M = 1401285 227 \n", 49 | "3 40 149M = 176945572 -233 \n", 50 | "4 40 151M = 20253771 398 \n", 51 | "\n", 52 | " seq ... NM \\\n", 53 | "0 GTGGAGTGTCGTTGCGTAGTCGGGAGTCGGGAGTAGAATAGTTTGG... ... 49 \n", 54 | "1 GGGGATTCTACCTTTACCATCAAATATCTACCGCGAAACTACGACT... ... 35 \n", 55 | "2 AAAATGAGAGATTGTTTGTTTTTTTTAATTTGTTTTTAAAAGGGGG... ... 40 \n", 56 | "3 AAATAACTTAATCTACTTCTCTCCGACCAAACCCAACCCCAAATAC... ... 35 \n", 57 | "4 TCGGATTTGGTGTTATTTATTTGGGAAGCGTCCGGACGGCGGAGCT... ... 
2 \n", 58 | "\n", 59 | " XM XR \\\n", 60 | "0 ........xZ.x..Z.x..xZ.....xZ.....x....x..hx...... GA \n", 61 | "1 H..............h......xh.h...x..Z.Zx.h..x.Zx..... GA \n", 62 | "2 ...........x..h....hhh.h....hxz.hhhhh............ CT \n", 63 | "3 x...hh...hh.............Z.....h.........z.h...... CT \n", 64 | "4 .Z...h......................Z.hXZ...Z..Z....H.... CT \n", 65 | "\n", 66 | " PG RG \\\n", 67 | "0 MarkDuplicates-287B47C6 diffuse_large_B_cell_lymphoma_test_8 \n", 68 | "1 MarkDuplicates-3DAAB091 diffuse_large_B_cell_lymphoma_test_8 \n", 69 | "2 MarkDuplicates-36E4BA78 Bcell_noncancer_test_8 \n", 70 | "3 MarkDuplicates-74536757 diffuse_large_B_cell_lymphoma_test_8 \n", 71 | "4 MarkDuplicates-74536757 diffuse_large_B_cell_lymphoma_test_8 \n", 72 | "\n", 73 | " dna_seq \\\n", 74 | "0 GTG TGG GGA GAG AGT GTG TGC GCC CCG CGC GCT CT... \n", 75 | "1 GTT TTT TTC TCT CTT TTC TCT CTA TAC ACC CCT CT... \n", 76 | "2 AAA AAA AAT ATG TGA GAG AGA GAG AGA GAC ACT CT... \n", 77 | "3 GAA AAT ATG TGG GGC GCT CTT TTG TGG GGT GTC TC... \n", 78 | "4 TCG CGG GGA GAC ACT CTT TTG TGG GGT GTG TGT GT... \n", 79 | "\n", 80 | " methyl_seq dmr_ctype dmr_label ctype \n", 81 | "0 2222222212222122222122222212222222222222222222... T 5 NA \n", 82 | "1 2222222222222222222222222222221212222222122222... T 19 NA \n", 83 | "2 2222222222222222222222222222202222222222222222... T 2 NA \n", 84 | "3 2222222222222222222222122222222222222202222222... T 12 NA \n", 85 | "4 1222222222222222222222222221222122212212222222... T 19 NA \n", 86 | "\n", 87 | "[5 rows x 23 columns]\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "from methylbert.data import finetune_data_generate as fdg\n", 93 | "\n", 94 | "f_bam = \"../test/data/bulk.bam\"\n", 95 | "f_dmr = \"../test/data/dmrs.csv\"\n", 96 | "f_ref = \"../../../genome/hg19.fa\"\n", 97 | "out_dir = \"tmp/\"\n", 98 | "\n", 99 | "fdg.finetune_data_generate(\n", 100 | " input_file = f_bam,\n", 101 | " f_dmr = f_dmr,\n", 102 | " f_ref = f_ref,\n", 103 | " output_dir=out_dir,\n", 104 | " n_mers=3, # 3-mer DNA sequences \n", 105 | " n_cores=20\n", 106 | ")" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "id": "01a4e20f-6589-4111-b3ce-b88956dfe926", 112 | "metadata": {}, 113 | "source": [ 114 | "This process generates a new file `data.csv` where the preprocessed bulk data is contained. " 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 4, 120 | "id": "681b9ccf-5afc-42c3-b2d9-0b7995effbd7", 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | "data.csv dmrs.csv test_seq.csv train_seq.csv\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "ls tmp/" 133 | ] 134 | }, 135 | { 136 | "cell_type": "markdown", 137 | "id": "a709eb70-0dd2-446c-907b-a2dff96c016f", 138 | "metadata": {}, 139 | "source": [ 140 | "Since the cell-type information is not given with the bulk sample, `ctype` column only contains `NaN` value. " 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 3, 146 | "id": "c52e4fef-32bb-4f68-9710-43cf8f1c76c7", 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "data": { 151 | "text/html": [ 152 | "
\n", 153 | "\n", 166 | "\n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | "
nameflagref_nameref_posmap_qualitycigarnext_ref_namenext_ref_poslengthseq...NMXMXRPGRGdna_seqmethyl_seqdmr_ctypedmr_labelctype
0SRR10166000.9089788_9089788_length=151147chr1013176736042151M=131767187-324GTGGAGTGTCGTTGCGTAGTCGGGAGTCGGGAGTAGAATAGTTTGG......49........xZ.x..Z.x..xZ.....xZ.....x....x..hx......GAMarkDuplicates-287B47C6diffuse_large_B_cell_lymphoma_test_8GTG TGG GGA GAG AGT GTG TGC GCC CCG CGC GCT CT...2222222212222122222122222212222222222222222222...T5NaN
1SRR10165998.65829390_65829390_length=150163chr42025424823151M=20254343244GGGGATTCTACCTTTACCATCAAATATCTACCGCGAAACTACGACT......35H..............h......xh.h...x..Z.Zx.h..x.Zx.....GAMarkDuplicates-3DAAB091diffuse_large_B_cell_lymphoma_test_8GTT TTT TTC TCT CTT TTC TCT CTA TAC ACC CCT CT...2222222222222222222222222222221212222222122222...T19NaN
2SRR10165467.85837758_85837758_length=15199chr4140120640151M=1401285227AAAATGAGAGATTGTTTGTTTTTTTTAATTTGTTTTTAAAAGGGGG......40...........x..h....hhh.h....hxz.hhhhh............CTMarkDuplicates-36E4BA78Bcell_noncancer_test_8AAA AAA AAT ATG TGA GAG AGA GAG AGA GAC ACT CT...2222222222222222222222222222202222222222222222...T2NaN
3SRR10165995.16747267_16747267_length=14983chr217694565640149M=176945572-233AAATAACTTAATCTACTTCTCTCCGACCAAACCCAACCCCAAATAC......35x...hh...hh.............Z.....h.........z.h......CTMarkDuplicates-74536757diffuse_large_B_cell_lymphoma_test_8GAA AAT ATG TGG GGC GCT CTT TTG TGG GGT GTC TC...2222222222222222222222122222222222222202222222...T12NaN
4SRR10165995.46034072_46034072_length=15199chr42025352440151M=20253771398TCGGATTTGGTGTTATTTATTTGGGAAGCGTCCGGACGGCGGAGCT......2.Z...h......................Z.hXZ...Z..Z....H....CTMarkDuplicates-74536757diffuse_large_B_cell_lymphoma_test_8TCG CGG GGA GAC ACT CTT TTG TGG GGT GTG TGT GT...1222222222222222222222222221222122212212222222...T19NaN
\n", 316 | "

5 rows × 23 columns

\n", 317 | "
" 318 | ], 319 | "text/plain": [ 320 | " name flag ref_name ref_pos \\\n", 321 | "0 SRR10166000.9089788_9089788_length=151 147 chr10 131767360 \n", 322 | "1 SRR10165998.65829390_65829390_length=150 163 chr4 20254248 \n", 323 | "2 SRR10165467.85837758_85837758_length=151 99 chr4 1401206 \n", 324 | "3 SRR10165995.16747267_16747267_length=149 83 chr2 176945656 \n", 325 | "4 SRR10165995.46034072_46034072_length=151 99 chr4 20253524 \n", 326 | "\n", 327 | " map_quality cigar next_ref_name next_ref_pos length \\\n", 328 | "0 42 151M = 131767187 -324 \n", 329 | "1 23 151M = 20254343 244 \n", 330 | "2 40 151M = 1401285 227 \n", 331 | "3 40 149M = 176945572 -233 \n", 332 | "4 40 151M = 20253771 398 \n", 333 | "\n", 334 | " seq ... NM \\\n", 335 | "0 GTGGAGTGTCGTTGCGTAGTCGGGAGTCGGGAGTAGAATAGTTTGG... ... 49 \n", 336 | "1 GGGGATTCTACCTTTACCATCAAATATCTACCGCGAAACTACGACT... ... 35 \n", 337 | "2 AAAATGAGAGATTGTTTGTTTTTTTTAATTTGTTTTTAAAAGGGGG... ... 40 \n", 338 | "3 AAATAACTTAATCTACTTCTCTCCGACCAAACCCAACCCCAAATAC... ... 35 \n", 339 | "4 TCGGATTTGGTGTTATTTATTTGGGAAGCGTCCGGACGGCGGAGCT... ... 2 \n", 340 | "\n", 341 | " XM XR \\\n", 342 | "0 ........xZ.x..Z.x..xZ.....xZ.....x....x..hx...... GA \n", 343 | "1 H..............h......xh.h...x..Z.Zx.h..x.Zx..... GA \n", 344 | "2 ...........x..h....hhh.h....hxz.hhhhh............ CT \n", 345 | "3 x...hh...hh.............Z.....h.........z.h...... CT \n", 346 | "4 .Z...h......................Z.hXZ...Z..Z....H.... CT \n", 347 | "\n", 348 | " PG RG \\\n", 349 | "0 MarkDuplicates-287B47C6 diffuse_large_B_cell_lymphoma_test_8 \n", 350 | "1 MarkDuplicates-3DAAB091 diffuse_large_B_cell_lymphoma_test_8 \n", 351 | "2 MarkDuplicates-36E4BA78 Bcell_noncancer_test_8 \n", 352 | "3 MarkDuplicates-74536757 diffuse_large_B_cell_lymphoma_test_8 \n", 353 | "4 MarkDuplicates-74536757 diffuse_large_B_cell_lymphoma_test_8 \n", 354 | "\n", 355 | " dna_seq \\\n", 356 | "0 GTG TGG GGA GAG AGT GTG TGC GCC CCG CGC GCT CT... \n", 357 | "1 GTT TTT TTC TCT CTT TTC TCT CTA TAC ACC CCT CT... \n", 358 | "2 AAA AAA AAT ATG TGA GAG AGA GAG AGA GAC ACT CT... \n", 359 | "3 GAA AAT ATG TGG GGC GCT CTT TTG TGG GGT GTC TC... \n", 360 | "4 TCG CGG GGA GAC ACT CTT TTG TGG GGT GTG TGT GT... \n", 361 | "\n", 362 | " methyl_seq dmr_ctype dmr_label ctype \n", 363 | "0 2222222212222122222122222212222222222222222222... T 5 NaN \n", 364 | "1 2222222222222222222222222222221212222222122222... T 19 NaN \n", 365 | "2 2222222222222222222222222222202222222222222222... T 2 NaN \n", 366 | "3 2222222222222222222222122222222222222202222222... T 12 NaN \n", 367 | "4 1222222222222222222222222221222122212212222222... 
T 19 NaN \n", 368 | "\n", 369 | "[5 rows x 23 columns]" 370 | ] 371 | }, 372 | "execution_count": 3, 373 | "metadata": {}, 374 | "output_type": "execute_result" 375 | } 376 | ], 377 | "source": [ 378 | "import pandas as pd\n", 379 | "pd.read_csv(\"tmp/data.csv\", sep=\"\\t\").head()" 380 | ] 381 | } 382 | ], 383 | "metadata": { 384 | "kernelspec": { 385 | "display_name": "dnabert", 386 | "language": "python", 387 | "name": "dnabert" 388 | }, 389 | "language_info": { 390 | "codemirror_mode": { 391 | "name": "ipython", 392 | "version": 3 393 | }, 394 | "file_extension": ".py", 395 | "mimetype": "text/x-python", 396 | "name": "python", 397 | "nbconvert_exporter": "python", 398 | "pygments_lexer": "ipython3", 399 | "version": "3.6.13" 400 | } 401 | }, 402 | "nbformat": 4, 403 | "nbformat_minor": 5 404 | } 405 | -------------------------------------------------------------------------------- /tutorials/05_tumour_deconvolution.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "f48c7917-8d6c-47fd-9378-ace38e238cdc", 6 | "metadata": {}, 7 | "source": [ 8 | "# Tumour deconvolution with the fine-tuned `MethylBERT` model \n", 9 | "\n", 10 | "### Load the bulk data and the fine-tuned model\n", 11 | "\n", 12 | "Please load your preprocessed bulk data following the [tutorial](https://github.com/hanyangii/methylbert/blob/main/tutorials/03_Preprocessing_bulk_data.ipynb) into the `MethylBertFinetuneDataset` and `DataLoader`." 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "id": "57ed608b-9a4a-42a7-b38a-7f1831f51d2c", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from methylbert.utils import set_seed\n", 23 | "\n", 24 | "set_seed(42)\n", 25 | "seq_len=150\n", 26 | "n_mers=3\n", 27 | "batch_size=5\n", 28 | "num_workers=20\n", 29 | "output_path=\"tmp/deconvolution/\"" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "id": "7543b0f3-943f-4989-9ca5-67a1c2a27ac1", 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "Building Vocab\n", 43 | "Total number of sequences : 3070\n", 44 | "# of reads in each label: [3070.]\n" 45 | ] 46 | } 47 | ], 48 | "source": [ 49 | "from methylbert.data.vocab import MethylVocab\n", 50 | "from methylbert.data.dataset import MethylBertFinetuneDataset\n", 51 | "from torch.utils.data import DataLoader\n", 52 | "\n", 53 | "tokenizer = MethylVocab(k=n_mers)\n", 54 | "dataset = MethylBertFinetuneDataset(\"tmp/data.csv\", \n", 55 | " tokenizer, \n", 56 | " seq_len=seq_len)\n", 57 | "data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "8179433f-55d6-4f04-8d23-8a816ebdad48", 63 | "metadata": {}, 64 | "source": [ 65 | "### Load the fine-tuned `MethylBERT` model\n", 66 | "\n", 67 | "`load` function in `MethylBertFinetuneTrainer` automatically detects `config.json`, `pytorch_model.bin`, `dmr_encoder.pickle` and `read_classification_model.pickle` files in the given directory and load the fine-tuned model. " 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "id": "d732913f-01be-4bae-8064-aecdcd7b63d4", 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "No detected GPU device. 
Load the model on CPU\n", 81 | "The model is loaded on CPU\n", 82 | "Restore the pretrained model tmp/fine_tune/\n", 83 | "Restore DMR encoder from tmp/fine_tune/dmr_encoder.pickle\n", 84 | "Restore read classification FCN model from tmp/fine_tune/read_classification_model.pickle\n", 85 | "Total Parameters: 32754130\n" 86 | ] 87 | } 88 | ], 89 | "source": [ 90 | "from methylbert.trainer import MethylBertFinetuneTrainer\n", 91 | "\n", 92 | "restore_dir = \"tmp/fine_tune/\"\n", 93 | "trainer = MethylBertFinetuneTrainer(len(tokenizer), \n", 94 | " train_dataloader=data_loader, \n", 95 | " test_dataloader=data_loader,\n", 96 | " )\n", 97 | "trainer.load(restore_dir) # Load the fine-tuned MethylBERT model" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "id": "7aef54f0-8565-4e89-85bf-b42aa2b64206", 103 | "metadata": {}, 104 | "source": [ 105 | "### Deconvolution\n", 106 | "For the deconvolution, the training data as a `pandas.DataFrame` object is required for the maginal probability of cell types. " 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 4, 112 | "id": "270cb5c1-da7a-4cc3-812f-20c17643c230", 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "name": "stderr", 117 | "output_type": "stream", 118 | "text": [ 119 | " 0%| | 0/614 [00:00\n", 222 | "\n", 235 | "\n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | "
\n", 256 | "" 257 | ], 258 | "text/plain": [ 259 | " cell_type pred\n", 260 | "0 T 0.0\n", 261 | "1 N 1.0" 262 | ] 263 | }, 264 | "execution_count": 5, 265 | "metadata": {}, 266 | "output_type": "execute_result" 267 | } 268 | ], 269 | "source": [ 270 | "import os\n", 271 | "pd.read_csv(os.path.join(output_path, \"deconvolution.csv\"), sep=\"\\t\").head()" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": 6, 277 | "id": "6647712f-1d4a-4ba3-b770-83ec04e12fad", 278 | "metadata": {}, 279 | "outputs": [ 280 | { 281 | "data": { 282 | "text/html": [ 283 | "
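`deconvolution.csv` holds one row per cell type. A small sketch for reading it back and extracting the tumour fraction, assuming the `pred` column holds the estimated proportion of each cell type in the bulk sample (`T` for tumour, `N` for normal):

```python
import os
import pandas as pd

deconv = pd.read_csv(os.path.join(output_path, "deconvolution.csv"), sep="\t")
# Assumption: 'pred' is the estimated fraction of each cell type
tumour_purity = float(deconv.loc[deconv["cell_type"] == "T", "pred"].iloc[0])
print(f"Estimated tumour purity: {tumour_purity:.3f}")
```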
" 312 | ], 313 | "text/plain": [ 314 | " fi\n", 315 | "0 4.243251" 316 | ] 317 | }, 318 | "execution_count": 6, 319 | "metadata": {}, 320 | "output_type": "execute_result" 321 | } 322 | ], 323 | "source": [ 324 | "pd.read_csv(os.path.join(output_path, \"FI.csv\"), sep=\"\\t\").head()" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 7, 330 | "id": "dbf770b9-8cbc-476e-9603-0770ecbed7b5", 331 | "metadata": {}, 332 | "outputs": [ 333 | { 334 | "data": { 335 | "text/html": [ 336 | "
" 502 | ], 503 | "text/plain": [ 504 | " name flag ref_name ref_pos \\\n", 505 | "0 SRR10166000.9089788_9089788_length=151 147 chr10 131767360 \n", 506 | "1 SRR10165998.65829390_65829390_length=150 163 chr4 20254248 \n", 507 | "2 SRR10165467.85837758_85837758_length=151 99 chr4 1401206 \n", 508 | "3 SRR10165995.16747267_16747267_length=149 83 chr2 176945656 \n", 509 | "4 SRR10165995.46034072_46034072_length=151 99 chr4 20253524 \n", 510 | "\n", 511 | " map_quality cigar next_ref_name next_ref_pos length \\\n", 512 | "0 42 151M = 131767187 -324 \n", 513 | "1 23 151M = 20254343 244 \n", 514 | "2 40 151M = 1401285 227 \n", 515 | "3 40 149M = 176945572 -233 \n", 516 | "4 40 151M = 20253771 398 \n", 517 | "\n", 518 | " seq ... \\\n", 519 | "0 GTGGAGTGTCGTTGCGTAGTCGGGAGTCGGGAGTAGAATAGTTTGG... ... \n", 520 | "1 GGGGATTCTACCTTTACCATCAAATATCTACCGCGAAACTACGACT... ... \n", 521 | "2 AAAATGAGAGATTGTTTGTTTTTTTTAATTTGTTTTTAAAAGGGGG... ... \n", 522 | "3 AAATAACTTAATCTACTTCTCTCCGACCAAACCCAACCCCAAATAC... ... \n", 523 | "4 TCGGATTTGGTGTTATTTATTTGGGAAGCGTCCGGACGGCGGAGCT... ... \n", 524 | "\n", 525 | " XM XR \\\n", 526 | "0 ........xZ.x..Z.x..xZ.....xZ.....x....x..hx...... GA \n", 527 | "1 H..............h......xh.h...x..Z.Zx.h..x.Zx..... GA \n", 528 | "2 ...........x..h....hhh.h....hxz.hhhhh............ CT \n", 529 | "3 x...hh...hh.............Z.....h.........z.h...... CT \n", 530 | "4 .Z...h......................Z.hXZ...Z..Z....H.... CT \n", 531 | "\n", 532 | " PG RG \\\n", 533 | "0 MarkDuplicates-287B47C6 diffuse_large_B_cell_lymphoma_test_8 \n", 534 | "1 MarkDuplicates-3DAAB091 diffuse_large_B_cell_lymphoma_test_8 \n", 535 | "2 MarkDuplicates-36E4BA78 Bcell_noncancer_test_8 \n", 536 | "3 MarkDuplicates-74536757 diffuse_large_B_cell_lymphoma_test_8 \n", 537 | "4 MarkDuplicates-74536757 diffuse_large_B_cell_lymphoma_test_8 \n", 538 | "\n", 539 | " dna_seq \\\n", 540 | "0 GTGGAGTGCCGCTGCGCAGCCGGGAGCCGGGAGCAGAACAGCCTGG... \n", 541 | "1 GTTTCTTCTACCTTTGCCATCAGGTGTCTGCCGCGGAGCTGCGGCT... \n", 542 | "2 AAAATGAGAGACTGCTTGTCCCTCTTAACCCGCCCCCAAAAGGGGG... \n", 543 | "3 GAATGGCTTGGTCTACTTCTCTCCGACCAAGCCCAACCCCGAGTAC... \n", 544 | "4 TCGGACTTGGTGTTATTTATTTGGGAAGCGCCCGGACGGCGGAGCT... \n", 545 | "\n", 546 | " methyl_seq dmr_ctype dmr_label \\\n", 547 | "0 2222222221222212222212222221222222222222222222... T 5 \n", 548 | "1 2222222222222222222222222222222121222222212222... T 19 \n", 549 | "2 2222222222222222222222222222220222222222222222... T 2 \n", 550 | "3 2222222222222222222222212222222222222220222222... T 12 \n", 551 | "4 2122222222222222222222222222122212221221222222... T 19 \n", 552 | "\n", 553 | " ctype pred \n", 554 | "0 NaN 0 \n", 555 | "1 NaN 0 \n", 556 | "2 NaN 0 \n", 557 | "3 NaN 0 \n", 558 | "4 NaN 0 \n", 559 | "\n", 560 | "[5 rows x 24 columns]" 561 | ] 562 | }, 563 | "execution_count": 7, 564 | "metadata": {}, 565 | "output_type": "execute_result" 566 | } 567 | ], 568 | "source": [ 569 | "pd.read_csv(os.path.join(output_path, \"res.csv\"), sep=\"\\t\").head()" 570 | ] 571 | }, 572 | { 573 | "cell_type": "markdown", 574 | "id": "9ba25384-4d34-421f-9086-ac5b98088cc4", 575 | "metadata": {}, 576 | "source": [ 577 | "### Deconvolution with the estimate adjustment\n", 578 | "\n", 579 | "_MethylBERT_ supports the tumour purity estimation adjustment considering the different distribution of tumour-derived reads in DMRs. \n", 580 | "\n", 581 | "You can turn `adjustment` option on for this. 
" 582 | ] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "execution_count": 8, 587 | "id": "9345dd97-63c5-4e79-a52d-1786ae1faaa3", 588 | "metadata": {}, 589 | "outputs": [ 590 | { 591 | "name": "stderr", 592 | "output_type": "stream", 593 | "text": [ 594 | " 0%| | 0/614 [00:00\n", 696 | "\n", 709 | "\n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | "
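The adjusted run only differs from the previous one in switching the adjustment on, so that the per-DMR distribution of tumour-derived reads is taken into account and a per-DMR `fi` value is written to `FI.csv`. A minimal sketch under the same assumptions as the earlier deconvolution call (the `adjustment` keyword name is itself an assumption):

```python
# Same call as before, but with the purity-estimate adjustment enabled
deconvolute(
    trainer=trainer,
    tokenizer=tokenizer,
    data_loader=data_loader,
    df_train=df_train,
    output_path=output_path,
    adjustment=True,  # assumed flag name
)
```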
\n", 745 | "" 746 | ], 747 | "text/plain": [ 748 | " dmr_label fi\n", 749 | "0 0 1.0\n", 750 | "1 1 1.0\n", 751 | "2 2 1.0\n", 752 | "3 3 1.0\n", 753 | "4 4 1.0" 754 | ] 755 | }, 756 | "execution_count": 9, 757 | "metadata": {}, 758 | "output_type": "execute_result" 759 | } 760 | ], 761 | "source": [ 762 | "pd.read_csv(os.path.join(output_path, \"FI.csv\"), sep=\"\\t\").head()" 763 | ] 764 | } 765 | ], 766 | "metadata": { 767 | "kernelspec": { 768 | "display_name": "dnabert", 769 | "language": "python", 770 | "name": "dnabert" 771 | }, 772 | "language_info": { 773 | "codemirror_mode": { 774 | "name": "ipython", 775 | "version": 3 776 | }, 777 | "file_extension": ".py", 778 | "mimetype": "text/x-python", 779 | "name": "python", 780 | "nbconvert_exporter": "python", 781 | "pygments_lexer": "ipython3", 782 | "version": "3.6.13" 783 | } 784 | }, 785 | "nbformat": 4, 786 | "nbformat_minor": 5 787 | } 788 | --------------------------------------------------------------------------------