├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── conda ├── GPURTX.yml ├── README.md ├── create-conda-env.sh └── gpurtx-2.yml ├── data ├── .gitattributes ├── JASPAR │ └── JASPAR2020_CORE_vertebrates_non-redundant_pfms_meme.txt ├── figs │ └── ExplaiNN.png ├── test │ ├── MAX_JUND_FOXA1_tomtom.tsv │ ├── explainn_filters.meme │ ├── figs │ │ ├── example_train.png │ │ ├── importance_TF.png │ │ └── weights_TF.png │ ├── tf_peaks_TEST_sparse_Remap.h5.gz │ └── weights │ │ └── model_epoch_best_4.pth └── tutorial │ ├── README.md │ └── slides.pdf ├── explainn ├── __init__.py ├── interpretation │ ├── __init__.py │ └── interpretation.py ├── models │ ├── __init__.py │ └── networks.py ├── train │ ├── __init__.py │ ├── test.py │ └── train.py └── utils │ ├── __init__.py │ └── tools.py ├── notebooks ├── Loss.png ├── paper │ ├── ExplaiNN_DeepSTARR_interactions.ipynb │ ├── ExplaiNN_TL_with_DanQ_(Figure_5).ipynb │ ├── ExplaiNN_TL_with_JASPAR_(Figure_5).ipynb │ ├── ExplaiNN_for_TF_binding_(Figure_1).ipynb │ ├── ExplaiNN_for_immune_accessibility_(Figure_3).ipynb │ ├── ExplaiNN_for_single_cell_(Figure_4).ipynb │ ├── Single_cell_pancreas_PeakVI.ipynb │ └── checkpoints │ │ └── model_epoch_best_5_.pth └── test.ipynb ├── requirements.txt ├── scripts ├── DanQ │ ├── fasta2predictions.py │ ├── finetune.py │ ├── interpret.py │ ├── test.py │ ├── train.py │ └── tsv2predictions.py ├── fasta2predictions.py ├── finetune.py ├── interpret.py ├── parsers │ ├── GRECO-BIT │ │ ├── afs+hts2explainn.py │ │ ├── chs2explainn.py │ │ ├── hts2explainn.py │ │ ├── pbm2explainn.py │ │ └── sms2explainn.py │ ├── de-novo │ │ ├── hts2explainn.py │ │ ├── matrix2explainn.py │ │ ├── pbm2explainn.py │ │ └── sms2explainn.py │ ├── fasta2explainn.py │ ├── fastq2explainn.py │ ├── json2explainn.py │ └── pbm2explainn.py ├── test.py ├── train.py ├── tsv2predictions.py └── utils │ ├── __init__.py │ ├── fonts │ └── Arial.ttf │ ├── jaspar2logo.py │ ├── match-seqs-by-gc.py │ ├── meme2clusters.py │ ├── meme2logo.py │ ├── meme2scores.py │ ├── pwm2scores.py │ ├── resize.py │ ├── subsample-seqs-by-gc.py │ └── tomtom.py └── setup.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.h5 filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Gherman Novakovsky 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /conda/README.md: -------------------------------------------------------------------------------- 1 | ```bash 2 | # Create conda environment 3 | conda create -y -n explainn -c pytorch -c conda-forge -c bioconda \ 4 | bedops \ 5 | biasaway \ 6 | biopython \ 7 | click click-option-group \ 8 | cudatoolkit=11.0.3 pytorch=1.11 torchaudio=0.12.1 torchvision=0.13.1 \ 9 | fastcluster \ 10 | genomepy \ 11 | h5py \ 12 | joblib=1.1.0 \ 13 | jupyterlab \ 14 | logomaker \ 15 | matplotlib \ 16 | numpy \ 17 | pandas \ 18 | parallel-fastq-dump \ 19 | pybedtools \ 20 | python=3.9.12 \ 21 | scikit-learn \ 22 | sra-tools=3.0.0 \ 23 | tqdm 24 | ``` -------------------------------------------------------------------------------- /conda/create-conda-env.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Create conda environment 4 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 5 | conda env create -f ${SCRIPT_DIR}/${HOSTNAME}.yml 6 | -------------------------------------------------------------------------------- /data/.gitattributes: -------------------------------------------------------------------------------- 1 | .h5 filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /data/figs/ExplaiNN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wassermanlab/ExplaiNN/4ed332abc610499f4761307eb7ac283115ca7314/data/figs/ExplaiNN.png -------------------------------------------------------------------------------- /data/test/figs/example_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wassermanlab/ExplaiNN/4ed332abc610499f4761307eb7ac283115ca7314/data/test/figs/example_train.png -------------------------------------------------------------------------------- /data/test/figs/importance_TF.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wassermanlab/ExplaiNN/4ed332abc610499f4761307eb7ac283115ca7314/data/test/figs/importance_TF.png -------------------------------------------------------------------------------- /data/test/figs/weights_TF.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wassermanlab/ExplaiNN/4ed332abc610499f4761307eb7ac283115ca7314/data/test/figs/weights_TF.png -------------------------------------------------------------------------------- /data/test/tf_peaks_TEST_sparse_Remap.h5.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wassermanlab/ExplaiNN/4ed332abc610499f4761307eb7ac283115ca7314/data/test/tf_peaks_TEST_sparse_Remap.h5.gz -------------------------------------------------------------------------------- /data/test/weights/model_epoch_best_4.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wassermanlab/ExplaiNN/4ed332abc610499f4761307eb7ac283115ca7314/data/test/weights/model_epoch_best_4.pth -------------------------------------------------------------------------------- /data/tutorial/README.md: -------------------------------------------------------------------------------- 1 | Slides are also available on [Google Docs](https://docs.google.com/presentation/d/1KMWp0oMGo6dhn0pWaz-TFLeLBwtuWmUUD6ldEQ0C4Do/edit?usp=sharing). -------------------------------------------------------------------------------- /data/tutorial/slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wassermanlab/ExplaiNN/4ed332abc610499f4761307eb7ac283115ca7314/data/tutorial/slides.pdf -------------------------------------------------------------------------------- /explainn/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ExplaiNN 3 | interpretable and transparent neural networks for genomics 4 | """ 5 | 6 | from .utils import tools 7 | from .interpretation import interpretation 8 | from .train import train, test 9 | from .models import networks 10 | 11 | __version__ = "0.1.5" -------------------------------------------------------------------------------- /explainn/interpretation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wassermanlab/ExplaiNN/4ed332abc610499f4761307eb7ac283115ca7314/explainn/interpretation/__init__.py -------------------------------------------------------------------------------- /explainn/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wassermanlab/ExplaiNN/4ed332abc610499f4761307eb7ac283115ca7314/explainn/models/__init__.py -------------------------------------------------------------------------------- /explainn/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wassermanlab/ExplaiNN/4ed332abc610499f4761307eb7ac283115ca7314/explainn/train/__init__.py -------------------------------------------------------------------------------- /explainn/train/test.py: -------------------------------------------------------------------------------- 1 | # ============================================================================= 2 | # IMPORTS 3 | # 
============================================================================= 4 | import torch.nn as nn 5 | import torch 6 | import numpy as np 7 | 8 | # ============================================================================= 9 | # FUNCTIONS 10 | # ============================================================================= 11 | 12 | def run_test(model, data_loader, device, isSigmoid=False): 13 | """ 14 | 15 | :param model: ExplaiNN model 16 | :param data_loader: torch DataLoader, data loader with the sequences of interest 17 | :param device: current available device ('cuda:0' or 'cpu') 18 | :param isSigmoid: boolean, True if the model output is binary 19 | :return: 20 | """ 21 | running_outputs = [] 22 | running_labels = [] 23 | sigmoid = nn.Sigmoid() 24 | with torch.no_grad(): 25 | for seq, lbl in data_loader: 26 | seq = seq.to(device) 27 | out = model(seq) 28 | out = out.detach().cpu() 29 | if isSigmoid: 30 | out = sigmoid(out) 31 | running_outputs.extend(out.numpy()) 32 | running_labels.extend(lbl.numpy()) 33 | return np.array(running_labels), np.array(running_outputs) 34 | -------------------------------------------------------------------------------- /explainn/train/train.py: -------------------------------------------------------------------------------- 1 | # ============================================================================= 2 | # IMPORTS 3 | # ============================================================================= 4 | import copy 5 | import os 6 | import torch 7 | from torch.nn import CosineSimilarity 8 | 9 | # ============================================================================= 10 | # FUNCTIONS 11 | # ============================================================================= 12 | 13 | def train_explainn(train_loader, test_loader, model, device, criterion, 14 | optimizer, num_epochs=100, weights_folder="./", 15 | name_ind=None, verbose=False, trim_weights=False, 16 | checkpoint=1, patience=0): 17 | """ 18 | Function to train the ExplaiNN model 19 | 20 | :param train_loader: pytorch DataLoader, train data 21 | :param test_loader: pytorch DataLoader, validation data 22 | :param model: ExplaiNN model 23 | :param device: current available device ("cuda:0" or "cpu") 24 | :param criterion: objective (loss) function to use (e.g. MSELoss) 25 | :param optimizer: pytorch Optimizer (e.g. SGD) 26 | :param num_epochs: int, number of epochs to train the model 27 | :param weights_folder: string, folder where to save checkpoints 28 | :param name_ind: string, suffix name of the checkpoints 29 | :param verbose: boolean, if False, does not print the progress 30 | :param trim_weights: boolean, if True, makes output layer weights non-negative 31 | :param checkpoint: int, how often to save checkpoints (e.g. 
1 means that the model will be saved after each epoch; 32 | 0 that only the best model will be saved) 33 | :param patience: int, number of epochs to wait before stopping training if validation loss does not improve; 34 | if 0, this parameter is ignored 35 | :return: tuple: 36 | trained ExplaiNN model, 37 | list, train losses, 38 | list, test losses 39 | """ 40 | 41 | train_error = [] 42 | test_error = [] 43 | 44 | best_model_wts = copy.deepcopy(model.state_dict()) 45 | # if save_optimizer: 46 | # best_optimizer_wts = copy.deepcopy(optimizer.state_dict()) 47 | best_loss_valid = float('inf') 48 | best_epoch = 1 49 | 50 | for epoch in range(1, num_epochs+1): 51 | 52 | running_loss = 0.0 53 | 54 | model.train() 55 | for seqs, labels in train_loader: 56 | x = seqs.to(device) 57 | labels = labels.to(device) 58 | 59 | optimizer.zero_grad() 60 | 61 | outputs = model(x) 62 | loss = criterion(outputs, labels) 63 | 64 | # backward and optimize 65 | loss.backward() 66 | 67 | optimizer.step() 68 | 69 | # to clip the weights (constrain them to be non-negative) 70 | if trim_weights: 71 | model.final.weight.data.clamp_(0) 72 | 73 | running_loss += loss.item() 74 | 75 | # save training loss to file 76 | epoch_loss = running_loss / len(train_loader) 77 | train_error.append(epoch_loss) 78 | 79 | # calculate test (validation) loss for epoch 80 | test_loss = 0.0 81 | 82 | with torch.no_grad(): # we don't train and don't save gradients here 83 | model.eval() # we set forward module to change dropout and batch normalization techniques 84 | for seqs, labels in test_loader: 85 | x = seqs.to(device) 86 | y = labels.to(device) 87 | outputs = model(x) 88 | loss = criterion(outputs, y) 89 | test_loss += loss.item() 90 | 91 | test_loss = test_loss / len(test_loader) 92 | test_error.append(test_loss) 93 | 94 | if verbose: 95 | print('Epoch [{}], Current Train Loss: {:.5f}, Current Val Loss: {:.5f}' 96 | .format(epoch, epoch_loss, test_loss)) 97 | 98 | if test_loss < best_loss_valid: 99 | best_loss_valid = test_loss 100 | best_epoch = epoch 101 | best_model_wts = copy.deepcopy(model.state_dict()) 102 | # if save_optimizer: 103 | # best_optimizer_wts = copy.deepcopy(optimizer.state_dict()) 104 | 105 | if checkpoint: 106 | if epoch % checkpoint == 0: 107 | model.load_state_dict(best_model_wts) 108 | if name_ind: 109 | f = f"model_epoch_{epoch}_{name_ind}.pth" 110 | else: 111 | f = f"model_epoch_{epoch}.pth" 112 | torch.save(best_model_wts, os.path.join(weights_folder, f)) 113 | # if save_optimizer: 114 | # optimizer.load_state_dict(best_optimizer_wts) 115 | # if name_ind: 116 | # f = f"optimizer_epoch_{epoch}_{name_ind}.pth" 117 | # else: 118 | # f = f"optimizer_epoch_{epoch}.pth" 119 | # torch.save(best_optimizer_wts, os.path.join(weights_folder, f)) 120 | 121 | if patience: 122 | if epoch >= best_epoch + patience: # at last, we lost our patience! 
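# Early stopping: `best_epoch` is the last epoch at which the validation loss
# improved, so this condition fires once `patience` epochs have passed without
# any further improvement (e.g. with best_epoch=5 and patience=10, training
# stops at epoch 15); the best weights kept in `best_model_wts` are restored
# and saved to disk right after the loop.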
123 | if verbose:
124 | print('Early stopping, Current Epoch {}, Best Epoch: {}, Patience: {}'
125 | .format(epoch, best_epoch, patience))
126 | break
127 | 
128 | model.load_state_dict(best_model_wts)
129 | if name_ind:
130 | f = f"model_epoch_best_{best_epoch}_{name_ind}.pth"
131 | else:
132 | f = f"model_epoch_best_{best_epoch}.pth"
133 | torch.save(best_model_wts, os.path.join(weights_folder, f))
134 | # if save_optimizer:
135 | # optimizer.load_state_dict(best_optimizer_wts)
136 | # if name_ind:
137 | # f = f"optimizer_epoch_best_{best_epoch}_{name_ind}.pth"
138 | # else:
139 | # f = f"optimizer_epoch_best_{best_epoch}.pth"
140 | # torch.save(best_optimizer_wts, os.path.join(weights_folder, f))
141 | 
142 | return model, train_error, test_error
143 | 
-------------------------------------------------------------------------------- /explainn/utils/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wassermanlab/ExplaiNN/4ed332abc610499f4761307eb7ac283115ca7314/explainn/utils/__init__.py
-------------------------------------------------------------------------------- /explainn/utils/tools.py: --------------------------------------------------------------------------------
1 | # =============================================================================
2 | # IMPORTS
3 | # =============================================================================
4 | import torch
5 | import numpy as np
6 | import h5py
7 | from tqdm import tqdm
8 | import torch.nn as nn
9 | import matplotlib.pyplot as plt
10 | 
11 | # =============================================================================
12 | # FUNCTIONS
13 | # =============================================================================
14 | def count_parameters(model):
15 | """
16 | Calculates the number of parameters in the model
17 | 
18 | :param model: pytorch model
19 | :return: int, number of parameters
20 | """
21 | 
22 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
23 | 
24 | 
25 | def showPlot(points, points2, title, save=None):
26 | """
27 | Function to plot the training and validation loss curves
28 | and, optionally, save the figure to a file
29 | :param points: list, training loss
30 | :param points2: list, validation loss
31 | :param title: string, optional title
32 | :param save: string or None, path where to save the figure; if None, the figure is shown instead
33 | :return:
34 | """
35 | 
36 | plt.figure()
37 | fig, ax = plt.subplots()
38 | # this locator puts ticks at regular intervals
39 | # loc = ticker.MultipleLocator(base=0.2)
40 | # ax.yaxis.set_major_locator(loc)
41 | plt.plot(points)
42 | plt.plot(points2)
43 | plt.ylabel("Loss")
44 | plt.legend(['train', 'validation'], loc='upper right')
45 | plt.title(title)
46 | 
47 | if save:
48 | plt.savefig(save)
49 | else:
50 | plt.show()
51 | 
52 | 
53 | def load_datas(path_h5, batch_size, num_workers, shuffle):
54 | """
55 | Loads the dataset from an h5 file
56 | 
57 | :param path_h5: string, path to the file
58 | :param batch_size: int, batch size to use
59 | :param num_workers: int, number of workers for DataLoader
60 | :param shuffle: boolean, shuffle parameter in DataLoader
61 | :return: tuple:
62 | a dictionary with train, validation, and test DataLoaders;
63 | a list with the names of the output nodes;
64 | an array with true values for the train samples;
65 | """
66 | 
67 | data = h5py.File(path_h5, 'r')
68 | dataset = {}
69 | dataloaders = {}
70 | # Train data
71 | dataset['train'] = torch.utils.data.TensorDataset(torch.Tensor(np.array(data['train_in'])),
72 | 
torch.Tensor(np.array(data['train_out']))) 73 | dataloaders['train'] = torch.utils.data.DataLoader(dataset['train'], 74 | batch_size=batch_size, shuffle=shuffle, 75 | num_workers=num_workers) 76 | 77 | # Validation data 78 | dataset['valid'] = torch.utils.data.TensorDataset(torch.Tensor(np.array(data['valid_in'])), 79 | torch.Tensor(np.array(data['valid_out']))) 80 | dataloaders['valid'] = torch.utils.data.DataLoader(dataset['valid'], 81 | batch_size=batch_size, shuffle=shuffle, 82 | num_workers=num_workers) 83 | 84 | # Test data 85 | dataset['test'] = torch.utils.data.TensorDataset(torch.Tensor(np.array(data['test_in'])), 86 | torch.Tensor(np.array(data['test_out']))) 87 | dataloaders['test'] = torch.utils.data.DataLoader(dataset['test'], 88 | batch_size=batch_size, shuffle=shuffle, 89 | num_workers=num_workers) 90 | print('Dataset Loaded') 91 | target_labels = list(data['target_labels']) 92 | train_out = np.array(data['train_out']) 93 | return dataloaders, target_labels, train_out 94 | 95 | 96 | def load_single_data(path_h5, batch_size, num_workers, shuffle): 97 | """ 98 | Load h5 file as a single dataset. Has to have the following fields: 99 | train_in : input train data 100 | valid_in : input validation data 101 | test_in : input test data 102 | train_out : output train labels 103 | valid_out : output validation labels 104 | test_out : output test labels 105 | :param path_h5: string, path to h5 file with the data 106 | :param batch_size: int, size of the batch in the dataloader 107 | :param num_workers: int, number of workers for DataLoader 108 | :param shuffle: boolean, shuffle parameter in DataLoader 109 | :return: tuple: 110 | DataLoader, a dataset where all train, validation, and test samples are put together 111 | torch.tensor, all input data 112 | torch.tensor, all labels of input data 113 | """ 114 | 115 | data = h5py.File(path_h5, 'r') 116 | 117 | x = torch.Tensor(np.array(data['train_in'])) 118 | y = torch.Tensor(np.array(data['valid_in'])) 119 | z = torch.Tensor(np.array(data['test_in'])) 120 | 121 | x_lab = torch.Tensor(np.array(data['train_out'])) 122 | y_lab = torch.Tensor(np.array(data['valid_out'])) 123 | z_lab = torch.Tensor(np.array(data['test_out'])) 124 | 125 | res = torch.cat((x, y, z), dim=0) 126 | res_lab = torch.cat((x_lab, y_lab, z_lab), dim=0) 127 | 128 | all_dataset = torch.utils.data.TensorDataset(res, res_lab) 129 | dataloader = torch.utils.data.DataLoader(all_dataset, 130 | batch_size=128, shuffle=False, 131 | num_workers=0) 132 | 133 | return dataloader, res, res_lab 134 | 135 | 136 | def dna_one_hot(seq, seq_len=None, flatten=False): 137 | """ 138 | Converts an input dna sequence to one hot encoded representation, with (A:0,C:1,G:2,T:3) alphabet 139 | 140 | :param seq: string, input dna sequence 141 | :param seq_len: int, optional, length of the string 142 | :param flatten: boolean, if true, makes a 1 column vector 143 | :return: numpy.array, one-hot encoded matrix of size (4, L), where L - the length of the input sequence 144 | """ 145 | 146 | if seq_len == None: 147 | seq_len = len(seq) 148 | seq_start = 0 149 | else: 150 | if seq_len <= len(seq): 151 | # trim the sequence 152 | seq_trim = (len(seq) - seq_len) // 2 153 | seq = seq[seq_trim:seq_trim + seq_len] 154 | seq_start = 0 155 | else: 156 | seq_start = (seq_len - len(seq)) // 2 157 | 158 | seq = seq.upper() 159 | 160 | seq = seq.replace("A", "0") 161 | seq = seq.replace("C", "1") 162 | seq = seq.replace("G", "2") 163 | seq = seq.replace("T", "3") 164 | 165 | # map nt's to a matrix 4 x len(seq) of 
0's and 1's. 166 | # dtype="int8" fails for N's 167 | seq_code = np.zeros((4, seq_len), dtype="float16") 168 | for i in range(seq_len): 169 | if i < seq_start: 170 | seq_code[:, i] = 0. 171 | else: 172 | try: 173 | seq_code[int(seq[i - seq_start]), i] = 1. 174 | except: 175 | seq_code[:, i] = 0. 176 | 177 | # flatten and make a column vector 1 x len(seq) 178 | if flatten: 179 | seq_code = seq_code.flatten()[None, :] 180 | 181 | return seq_code 182 | 183 | 184 | def convert_one_hot_back_to_seq(dataloader): 185 | """ 186 | Converts one-hot encoded matrices back to DNA sequences 187 | :param dataloader: pytorch, DataLoader 188 | :return: list of strings, DNA sequences 189 | """ 190 | 191 | sequences = [] 192 | code = list("ACGT") 193 | for seqs, labels in tqdm(dataloader, total=len(dataloader)): 194 | x = seqs.permute(0, 1, 3, 2) 195 | x = x.squeeze(-1) 196 | for i in range(x.shape[0]): 197 | seq = "" 198 | for j in range(x.shape[-1]): 199 | try: 200 | seq = seq + code[int(np.where(x[i, :, j] == 1)[0])] 201 | except: 202 | print("error") 203 | print(x[i, :, j]) 204 | print(np.where(x[i, :, j] == 1)) 205 | break 206 | sequences.append(seq) 207 | return sequences 208 | 209 | 210 | def _flip(x, dim): 211 | """ 212 | Adapted from Selene: 213 | https://github.com/FunctionLab/selene/blob/master/selene_sdk/utils/non_strand_specific_module.py 214 | 215 | Reverses the elements in a given dimension `dim` of the Tensor. 216 | source: https://github.com/pytorch/pytorch/issues/229 217 | """ 218 | 219 | xsize = x.size() 220 | dim = x.dim() + dim if dim < 0 else dim 221 | x = x.contiguous() 222 | x = x.view(-1, *xsize[dim:]) 223 | x = x.view( 224 | x.size(0), x.size(1), -1)[:, getattr( 225 | torch.arange(x.size(1) - 1, -1, -1), 226 | ("cpu", "cuda")[x.is_cuda])().long(), :] 227 | 228 | return x.view(xsize) 229 | 230 | 231 | def pearson_loss(x, y): 232 | """ 233 | Loss that is based on Pearson correlation/objective function 234 | :param x: torch, input data 235 | :param y: torch, output labels 236 | :return: torch, pearson loss per sample 237 | """ 238 | 239 | mx = torch.mean(x, dim=1, keepdim=True) 240 | my = torch.mean(y, dim=1, keepdim=True) 241 | xm, ym = x - mx, y - my 242 | 243 | cos = nn.CosineSimilarity(dim=1, eps=1e-6) 244 | loss = torch.sum(1-cos(xm,ym)) 245 | return loss 246 | 247 | 248 | def change_grad_filters(model, val, index): 249 | """ 250 | Function to modify the gradient of filters of interest by the specified value 251 | :param model: ExplaiNN model 252 | :param val: float or int, value by which gradients will be multiplied 253 | :param index: list, gradients of which unit filters to modify 254 | :return: ExplaiNN model 255 | """ 256 | 257 | def replace_val(grad, val, index): 258 | grad[index, :, :] *= val 259 | 260 | return grad 261 | 262 | model.linears[0].weight.register_hook(lambda grad: replace_val(grad, val, index)) 263 | 264 | return model 265 | 266 | 267 | def _PWM_to_filter_weights(pwm, filter_size=19): 268 | """ 269 | Function to convert a given pwm into convolutional filter 270 | :param pwm: list of length L (size of the PWM) with lists of size 4 (nucleotide values) 271 | :param filter_size: int, the size of the pwm 272 | :return: numpy.array, of shape (L, 4) 273 | """ 274 | 275 | # Initialize 276 | lpop = 0 277 | rpop = 0 278 | 279 | pwm = [[.25,.25,.25,.25]]*filter_size+pwm+[[.25,.25,.25,.25]]*filter_size 280 | 281 | while len(pwm) > filter_size: 282 | if max(pwm[0]) < max(pwm[-1]): 283 | pwm.pop(0) 284 | lpop += 1 285 | elif max(pwm[-1]) < max(pwm[0]): 286 | pwm.pop(-1) 287 
| rpop += 1
288 | else:
289 | if lpop > rpop:
290 | pwm.pop(-1)
291 | rpop += 1
292 | else:
293 | pwm.pop(0)
294 | lpop += 1
295 | 
296 | return(np.array(pwm) - .25)
297 | 
298 | 
299 | def read_meme(meme_file):
300 | """
301 | Function to read the motifs from the input meme file
302 | Right now, it works only if all motifs are of the same length
303 | :param meme_file: string, path to the meme file
304 | :return: tuple:
305 | numpy.array, matrix of pwms of size (N, L, 4), where N is the number of motifs, L is the motif length,
306 | and 4 the number of nucleotides; numpy.array, names of the motifs
307 | """
308 | 
309 | with open(meme_file) as fp:
310 | line = fp.readline()
311 | motifs = []
312 | motif_names = []
313 | while line:
314 | # determine length of next motif
315 | if line.split(" ")[0] == 'MOTIF':
316 | line = line.strip()
317 | # add motif number to separate array
318 | motif_names.append(line.split(" ")[1])
319 | # get length of motif
320 | line2 = fp.readline().split(" ")
321 | motif_length = int(float(line2[5]))
322 | # read in motif
323 | current_motif = np.zeros((motif_length, 4)) # Edited pad shorter ones with 0
324 | for i in range(motif_length):
325 | current_motif[i, :] = fp.readline().strip().split()
326 | motifs.append(current_motif)
327 | line = fp.readline()
328 | motifs = np.stack(motifs)
329 | motif_names = np.stack(motif_names)
330 | 
331 | return motifs, motif_names
332 | 
-------------------------------------------------------------------------------- /notebooks/Loss.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wassermanlab/ExplaiNN/4ed332abc610499f4761307eb7ac283115ca7314/notebooks/Loss.png
-------------------------------------------------------------------------------- /notebooks/paper/checkpoints/model_epoch_best_5_.pth: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wassermanlab/ExplaiNN/4ed332abc610499f4761307eb7ac283115ca7314/notebooks/paper/checkpoints/model_epoch_best_5_.pth
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | numpy==1.21.6
2 | h5py==3.6.0
3 | tqdm==4.64.0
4 | pandas==1.3.5
5 | matplotlib==3.5.2
-------------------------------------------------------------------------------- /scripts/DanQ/fasta2predictions.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | from Bio import SeqIO
4 | import click
5 | from click_option_group import optgroup
6 | import json
7 | import numpy as np
8 | import os
9 | import pandas as pd
10 | import sys
11 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])),
12 | os.pardir))
13 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])),
14 | os.pardir,
15 | os.pardir))
16 | import time
17 | import torch
18 | from tqdm import tqdm
19 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}"
20 | 
21 | from explainn.models.networks import DanQ
22 | from utils import (_dna_one_hot_many, _resize_sequence, get_file_handle,
23 | get_data_loader, get_device)
24 | 
25 | CONTEXT_SETTINGS = {
26 | "help_option_names": ["-h", "--help"],
27 | }
28 | 
29 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS)
30 | @click.argument(
31 | "model_file",
32 | type=click.Path(exists=True, resolve_path=True)
33 | )
34 | @click.argument(
35 | "training_parameters_file",
36 | 
type=click.Path(exists=True, resolve_path=True), 37 | ) 38 | @click.argument( 39 | "fasta_file", 40 | type=click.Path(exists=True, resolve_path=True), 41 | nargs=-1 42 | ) 43 | @click.option( 44 | "-c", "--cpu-threads", 45 | help="Number of CPU threads to use.", 46 | type=int, 47 | default=1, 48 | show_default=True, 49 | ) 50 | @click.option( 51 | "-o", "--output-dir", 52 | help="Output directory.", 53 | type=click.Path(resolve_path=True), 54 | default="./", 55 | show_default=True, 56 | ) 57 | @click.option( 58 | "-t", "--time", 59 | help="Return the program's running execution time in seconds.", 60 | is_flag=True, 61 | ) 62 | @optgroup.group("Predict") 63 | @optgroup.option( 64 | "--apply-sigmoid", 65 | help="Apply the logistic sigmoid function to outputs.", 66 | is_flag=True, 67 | ) 68 | @optgroup.option( 69 | "--batch-size", 70 | help="Batch size.", 71 | type=int, 72 | default=100, 73 | show_default=True, 74 | ) 75 | 76 | def cli(**args): 77 | 78 | # Start execution 79 | start_time = time.time() 80 | 81 | # Initialize 82 | if not os.path.exists(args["output_dir"]): 83 | os.makedirs(args["output_dir"]) 84 | 85 | # Save exec. parameters as JSON 86 | json_file = os.path.join(args["output_dir"], 87 | f"parameters-{os.path.basename(__file__)}.json") 88 | handle = get_file_handle(json_file, "wt") 89 | handle.write(json.dumps(args, indent=4, sort_keys=True)) 90 | handle.close() 91 | 92 | ############## 93 | # Load Data # 94 | ############## 95 | 96 | # Load training parameters 97 | handle = get_file_handle(args["training_parameters_file"], "rt") 98 | train_args = json.load(handle) 99 | handle.close() 100 | 101 | # Get sequences and ids 102 | seqs, ids = [], [] 103 | for fasta_file in args["fasta_file"]: 104 | fh = get_file_handle(fasta_file, "rt") 105 | records = list(SeqIO.parse(fh, "fasta")) 106 | fh.close() 107 | for record in records: 108 | seqs.append( 109 | _resize_sequence(str(record.seq), train_args["input_length"]) 110 | ) 111 | ids.append(record.id) 112 | seqs = _dna_one_hot_many(seqs) 113 | 114 | ############## 115 | # Predict # 116 | ############## 117 | 118 | # Get device 119 | device = get_device() 120 | 121 | # Get model 122 | state_dict = torch.load(args["model_file"]) 123 | for k in reversed(state_dict.keys()): 124 | num_classes = state_dict[k].shape[0] 125 | break 126 | 127 | # Get model 128 | m = DanQ(train_args["input_length"], num_classes, args["model_file"]) 129 | 130 | # Test 131 | _predict(seqs, ids, num_classes, m, device, args["output_dir"], 132 | args["apply_sigmoid"], args["batch_size"]) 133 | 134 | # Finish execution 135 | seconds = format(time.time() - start_time, ".2f") 136 | if args["time"]: 137 | f = os.path.join(args["output_dir"], 138 | f"time-{os.path.basename(__file__)}.txt") 139 | handle = get_file_handle(f, "wt") 140 | handle.write(f"{seconds} seconds") 141 | handle.close() 142 | print(f"Execution time {seconds} seconds") 143 | 144 | def _predict(seqs, ids, num_classes, model, device, output_dir="./", 145 | apply_sigmoid=False, batch_size=100): 146 | 147 | # Initialize 148 | idx = 0 149 | predictions = np.empty((len(seqs), num_classes, 4)) 150 | model.to(device) 151 | model.eval() 152 | 153 | # Get training DataLoader 154 | data_loader = get_data_loader( 155 | seqs, 156 | np.array([s[::-1, ::-1] for s in seqs]), 157 | batch_size 158 | ) 159 | 160 | with torch.no_grad(): 161 | 162 | for fwd, rev in tqdm(iter(data_loader), total=len(data_loader), 163 | bar_format=bar_format): 164 | 165 | # Get strand-specific predictions 166 | fwd = 
np.expand_dims(model(fwd.to(device)).cpu().numpy(), axis=2) 167 | rev = np.expand_dims(model(rev.to(device)).cpu().numpy(), axis=2) 168 | 169 | # Combine predictions from both strands 170 | fwd_rev = np.concatenate((fwd, rev), axis=2) 171 | mean_fwd_rev = np.expand_dims(np.mean(fwd_rev, axis=2), axis=2) 172 | max_fwd_rev = np.expand_dims(np.max(fwd_rev, axis=2), axis=2) 173 | 174 | # Concatenate predictions for this batch 175 | p = np.concatenate((fwd, rev, mean_fwd_rev, max_fwd_rev), axis=2) 176 | predictions[idx:idx+fwd.shape[0]] = p 177 | 178 | # Index increase 179 | idx += fwd.shape[0] 180 | 181 | # Apply sigmoid 182 | if apply_sigmoid: 183 | predictions = torch.sigmoid(torch.Tensor(predictions)).numpy() 184 | 185 | # Get predictions 186 | tsv_file = os.path.join(output_dir, "predictions.tsv.gz") 187 | if not os.path.exists(tsv_file): 188 | dfs = [] 189 | for i in range(num_classes): 190 | p = predictions[:, i, :] 191 | df = pd.DataFrame(p, columns=["Fwd", "Rev", "Mean", "Max"]) 192 | df["SeqId"] = ids 193 | df["Class"] = i 194 | dfs.append(df) 195 | df = pd.concat(dfs)[["SeqId", "Class", "Fwd", "Rev", "Mean", "Max"]] 196 | df.reset_index(drop=True, inplace=True) 197 | df.to_csv(tsv_file, sep="\t", index=False) 198 | 199 | if __name__ == "__main__": 200 | cli() -------------------------------------------------------------------------------- /scripts/DanQ/finetune.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from click_option_group import optgroup 5 | import json 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | import shutil 10 | import sys 11 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 12 | os.pardir)) 13 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 14 | os.pardir, 15 | os.pardir)) 16 | import time 17 | import torch 18 | 19 | from explainn.train.train import train_explainn 20 | from explainn.utils.tools import pearson_loss 21 | from explainn.models.networks import DanQ 22 | from utils import (get_file_handle, get_seqs_labels_ids, get_data_loader, 23 | get_device) 24 | 25 | CONTEXT_SETTINGS = { 26 | "help_option_names": ["-h", "--help"], 27 | } 28 | 29 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 30 | @click.argument( 31 | "model_file", 32 | type=click.Path(exists=True, resolve_path=True), 33 | ) 34 | @click.argument( 35 | "training_parameters_file", 36 | type=click.Path(exists=True, resolve_path=True), 37 | ) 38 | @click.argument( 39 | "training_file", 40 | type=click.Path(exists=True, resolve_path=True), 41 | ) 42 | @click.argument( 43 | "validation_file", 44 | type=click.Path(exists=True, resolve_path=True), 45 | ) 46 | @click.option( 47 | "-c", "--cpu-threads", 48 | help="Number of CPU threads to use.", 49 | type=int, 50 | default=1, 51 | show_default=True, 52 | ) 53 | @click.option( 54 | "-d", "--debugging", 55 | help="Debugging mode.", 56 | is_flag=True, 57 | ) 58 | @click.option( 59 | "-o", "--output-dir", 60 | help="Output directory.", 61 | type=click.Path(resolve_path=True), 62 | default="./", 63 | show_default=True, 64 | ) 65 | @click.option( 66 | "-t", "--time", 67 | help="Return the program's running execution time in seconds.", 68 | is_flag=True, 69 | ) 70 | @optgroup.group("Optimizer") 71 | @optgroup.option( 72 | "--criterion", 73 | help="Loss (objective) function to use. Select \"BCEWithLogits\" for binary or multi-class classification tasks (e.g. 
predict the binding of one or more TFs to a sequence), \"CrossEntropy\" for multi-class classification tasks wherein only one solution is possible (e.g. predict the species of origin of a sequence between human, mouse or zebrafish), \"MSE\" for regression tasks (e.g. predict probe intensity signals), \"Pearson\" also for regression tasks (e.g. modeling accessibility across 81 cell types), and \"PoissonNLL\" for modeling count data (e.g. total number of reads at ChIP-/ATAC-seq peaks).", 74 | type=click.Choice(["BCEWithLogits", "CrossEntropy", "MSE", "Pearson", "PoissonNLL"], case_sensitive=False), 75 | required=True 76 | ) 77 | @optgroup.option( 78 | "--lr", 79 | help="Learning rate.", 80 | type=float, 81 | default=5e-05, 82 | show_default=True, 83 | ) 84 | @optgroup.option( 85 | "--optimizer", 86 | help="`torch.optim.Optimizer` with which to minimize the loss during training.", 87 | type=click.Choice(["Adam", "SGD"], case_sensitive=False), 88 | default="Adam", 89 | show_default=True, 90 | ) 91 | @optgroup.group("Fine-tuning") 92 | @optgroup.option( 93 | "--batch-size", 94 | help="Batch size.", 95 | type=int, 96 | default=100, 97 | show_default=True, 98 | ) 99 | @optgroup.option( 100 | "--checkpoint", 101 | help="How often to save checkpoints (e.g. 1 means that the model will be saved after each epoch; by default, i.e. 0, only the best model will be saved).", 102 | type=int, 103 | default=0, 104 | show_default=True, 105 | ) 106 | @optgroup.option( 107 | "--freeze", 108 | help="Do not update the model weights during training.", 109 | is_flag=True, 110 | ) 111 | @optgroup.option( 112 | "--num-epochs", 113 | help="Number of epochs to train the model.", 114 | type=int, 115 | default=100, 116 | show_default=True, 117 | ) 118 | @optgroup.option( 119 | "--patience", 120 | help="Number of epochs to wait before stopping training if the validation loss does not improve.", 121 | type=int, 122 | default=10, 123 | show_default=True, 124 | ) 125 | @optgroup.option( 126 | "--rev-complement", 127 | help="Reverse and complement training sequences.", 128 | is_flag=True, 129 | ) 130 | @optgroup.option( 131 | "--trim-weights", 132 | help="Constrain output weights to be non-negative (i.e. to ease interpretation).", 133 | is_flag=True, 134 | ) 135 | 136 | def cli(**args): 137 | 138 | # Start execution 139 | start_time = time.time() 140 | 141 | # Initialize 142 | if not os.path.exists(args["output_dir"]): 143 | os.makedirs(args["output_dir"]) 144 | 145 | # Save exec. 
parameters as JSON 146 | json_file = os.path.join(args["output_dir"], 147 | f"parameters-{os.path.basename(__file__)}.json") 148 | handle = get_file_handle(json_file, "wt") 149 | handle.write(json.dumps(args, indent=4, sort_keys=True)) 150 | handle.close() 151 | 152 | ############## 153 | # Load Data # 154 | ############## 155 | 156 | # Load training parameters 157 | handle = get_file_handle(args["training_parameters_file"], "rt") 158 | train_args = json.load(handle) 159 | handle.close() 160 | 161 | # Get training/test sequences and labels 162 | train_seqs, train_labels, _ = get_seqs_labels_ids(args["training_file"], 163 | args["debugging"], 164 | args["rev_complement"], 165 | train_args["input_length"]) 166 | test_seqs, test_labels, _ = get_seqs_labels_ids(args["validation_file"], 167 | args["debugging"], 168 | args["rev_complement"], 169 | train_args["input_length"]) 170 | 171 | # Get training/test DataLoaders 172 | train_loader = get_data_loader(train_seqs, train_labels, 173 | args["batch_size"], shuffle=True) 174 | test_loader = get_data_loader(test_seqs, test_labels, 175 | args["batch_size"], shuffle=True) 176 | 177 | # Load pre-trained state dict 178 | state_dict_pretrain = torch.load(args["model_file"]) 179 | 180 | ############## 181 | # Fine-tune # 182 | ############## 183 | 184 | # Infer input length/type, and the number of classes 185 | # input_length = train_seqs[0].shape[1] 186 | num_classes = train_labels[0].shape[0] 187 | 188 | # Get device 189 | device = get_device() 190 | 191 | # Get criterion 192 | if args["criterion"].lower() == "bcewithlogits": 193 | criterion = torch.nn.BCEWithLogitsLoss() 194 | elif args["criterion"].lower() == "crossentropy": 195 | criterion = torch.nn.CrossEntropyLoss() 196 | elif args["criterion"].lower() == "mse": 197 | criterion = torch.nn.MSELoss() 198 | elif args["criterion"].lower() == "pearson": 199 | criterion = pearson_loss 200 | elif args["criterion"].lower() == "poissonnll": 201 | criterion = torch.nn.PoissonNLLLoss() 202 | 203 | # Get model 204 | m = DanQ(train_args["input_length"], num_classes) 205 | 206 | # Get optimizer 207 | o = _get_optimizer(args["optimizer"], m.parameters(), args["lr"]) 208 | 209 | # Transfer learning 210 | state_dict = m.state_dict() 211 | for k in state_dict: 212 | if not k.startswith("linear.2"): # do not transfer the final layer 213 | state_dict[k] = state_dict_pretrain[k] 214 | m.load_state_dict(state_dict) 215 | 216 | # Freeze weights 217 | if args["freeze"]: 218 | for n, p in m.named_parameters(): 219 | # Allow learning the weights of the final layer 220 | if not n.startswith("linear.2") and p.requires_grad: 221 | p.requires_grad = False 222 | 223 | # Fine-tune 224 | _finetune(train_loader, test_loader, m, device, criterion, o, 225 | args["num_epochs"], args["output_dir"], None, True, False, 226 | args["checkpoint"], args["patience"]) 227 | 228 | # Finish execution 229 | seconds = format(time.time() - start_time, ".2f") 230 | if args["time"]: 231 | f = os.path.join(args["output_dir"], 232 | f"time-{os.path.basename(__file__)}.txt") 233 | handle = get_file_handle(f, "wt") 234 | handle.write(f"{seconds} seconds") 235 | handle.close() 236 | print(f"Execution time {seconds} seconds") 237 | 238 | def _get_optimizer(optimizer, parameters, lr=5e-05): 239 | 240 | if optimizer.lower() == "adam": 241 | return torch.optim.Adam(parameters, lr=lr) 242 | elif optimizer.lower() == "sgd": 243 | return torch.optim.SGD(parameters, lr=lr) 244 | 245 | def _finetune(train_loader, test_loader, model, device, criterion, optimizer, 
246 | num_epochs=100, output_dir="./", name_ind=None, verbose=False, 247 | trim_weights=False, checkpoint=0, patience=0): 248 | 249 | # Initialize 250 | model.to(device) 251 | 252 | # Train 253 | _, train_error, test_error = train_explainn(train_loader, test_loader, 254 | model, device, criterion, 255 | optimizer, num_epochs, 256 | output_dir, name_ind, 257 | verbose, trim_weights, 258 | checkpoint, patience) 259 | 260 | # Save losses 261 | df = pd.DataFrame(list(zip(train_error, test_error)), 262 | columns=["Train loss", "Validation loss"]) 263 | df.index += 1 264 | df.index.rename("Epoch", inplace=True) 265 | df.to_csv(os.path.join(output_dir, "losses.tsv"), sep="\t") 266 | 267 | if __name__ == "__main__": 268 | cli() -------------------------------------------------------------------------------- /scripts/DanQ/interpret.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from click_option_group import optgroup, RequiredMutuallyExclusiveOptionGroup 5 | import copy 6 | import json 7 | import numpy as np 8 | import os 9 | import pandas as pd 10 | import pickle 11 | import shutil 12 | import sys 13 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 14 | os.pardir)) 15 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 16 | os.pardir, 17 | os.pardir)) 18 | import time 19 | import torch 20 | from tqdm import tqdm 21 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 22 | 23 | 24 | from explainn.models.networks import DanQ 25 | from explainn.interpretation.interpretation import (get_explainn_predictions, 26 | get_danq_activations, 27 | get_pwms_explainn, 28 | pwm_to_meme) 29 | 30 | from utils import (get_file_handle, get_seqs_labels_ids, get_data_loader, 31 | get_device) 32 | 33 | CONTEXT_SETTINGS = { 34 | "help_option_names": ["-h", "--help"], 35 | } 36 | 37 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 38 | @click.argument( 39 | "model_file", 40 | type=click.Path(exists=True, resolve_path=True), 41 | ) 42 | @click.argument( 43 | "training_parameters_file", 44 | type=click.Path(exists=True, resolve_path=True), 45 | ) 46 | @click.argument( 47 | "training_file", 48 | type=click.Path(exists=True, resolve_path=True), 49 | ) 50 | @click.option( 51 | "-c", "--cpu-threads", 52 | help="Number of CPU threads to use.", 53 | type=int, 54 | default=1, 55 | show_default=True, 56 | ) 57 | @click.option( 58 | "-d", "--debugging", 59 | help="Debugging mode.", 60 | is_flag=True, 61 | ) 62 | @click.option( 63 | "-o", "--output-dir", 64 | help="Output directory.", 65 | type=click.Path(resolve_path=True), 66 | default="./", 67 | show_default=True, 68 | ) 69 | @click.option( 70 | "-t", "--time", 71 | help="Return the program's running execution time in seconds.", 72 | is_flag=True, 73 | ) 74 | @optgroup.group("Interpretation") 75 | @optgroup.option( 76 | "--batch-size", 77 | help="Batch size.", 78 | type=int, 79 | default=100, 80 | show_default=True, 81 | ) 82 | @optgroup.option( 83 | "--num-well-pred-seqs", 84 | help="Number of well-predicted sequences to use. [default: use all sequences]", 85 | type=int, 86 | ) 87 | @optgroup.group("\n Define \"well-predicted\" sequences based on", cls=RequiredMutuallyExclusiveOptionGroup) 88 | @optgroup.option( 89 | "--correlation", 90 | help="If the correlation between the predicted and actual sequence values is >x (e.g. 
multi-class regression tasks).", 91 | type=click.FloatRange(0, 1, clamp=True), 92 | ) 93 | @optgroup.option( 94 | "--exact-match", 95 | help="If the predicted and actual sequence values are equal (e.g. binary or multi-class classification tasks).", 96 | is_flag=True 97 | ) 98 | @optgroup.option( 99 | "--percentile-bottom", 100 | help="If the predicted and actual sequence values are within the bottom x percentile (e.g. single-class regression tasks).", 101 | type=click.IntRange(1, 100, clamp=True), 102 | ) 103 | @optgroup.option( 104 | "--percentile-top", 105 | help="If the predicted and actual sequence values are within the top x percentile (e.g. single-class regression tasks).", 106 | type=click.IntRange(1, 100, clamp=True), 107 | ) 108 | 109 | def cli(**args): 110 | 111 | # Start execution 112 | start_time = time.time() 113 | 114 | # Initialize 115 | if not os.path.exists(args["output_dir"]): 116 | os.makedirs(args["output_dir"]) 117 | 118 | # Save exec. parameters as JSON 119 | json_file = os.path.join(args["output_dir"], 120 | f"parameters-{os.path.basename(__file__)}.json") 121 | handle = get_file_handle(json_file, "wt") 122 | handle.write(json.dumps(args, indent=4, sort_keys=True)) 123 | handle.close() 124 | 125 | ############## 126 | # Load Data # 127 | ############## 128 | 129 | # Load training parameters 130 | handle = get_file_handle(args["training_parameters_file"], "rt") 131 | train_args = json.load(handle) 132 | handle.close() 133 | 134 | # Get training sequences and labels 135 | seqs, labels, _ = get_seqs_labels_ids(args["training_file"], 136 | args["debugging"], 137 | False, 138 | train_args["input_length"]) 139 | 140 | ############## 141 | # Interpret # 142 | ############## 143 | 144 | # Infer input type, and the number of classes 145 | num_classes = labels[0].shape[0] 146 | if np.unique(labels[:, 0]).size == 2: 147 | input_type = "binary" 148 | else: 149 | input_type = "non-binary" 150 | 151 | # Get device 152 | device = get_device() 153 | 154 | # Get criterion/threshold (if applicable) for well-predicted sequences 155 | if args["correlation"]: 156 | criterion = "correlation" 157 | threshold = args["correlation"] 158 | elif args["exact_match"]: 159 | criterion = "exact_match" 160 | threshold = None 161 | elif args["percentile_bottom"]: 162 | criterion = "percentile_bottom" 163 | threshold = args["percentile_bottom"] 164 | elif args["percentile_top"]: 165 | criterion = "percentile_top" 166 | threshold = args["percentile_top"] 167 | 168 | # Get model 169 | m = DanQ(train_args["input_length"], num_classes, args["model_file"]) 170 | 171 | # Interpret 172 | _interpret(seqs, labels, m, device, input_type, criterion, threshold, 173 | train_args["filter_size"], train_args["rev_complement"], 174 | args["output_dir"], args["batch_size"], 175 | args["num_well_pred_seqs"]) 176 | 177 | # Finish execution 178 | seconds = format(time.time() - start_time, ".2f") 179 | if args["time"]: 180 | f = os.path.join(args["output_dir"], 181 | f"time-{os.path.basename(__file__)}.txt") 182 | handle = get_file_handle(f, "wt") 183 | handle.write(f"{seconds} seconds") 184 | handle.close() 185 | print(f"Execution time {seconds} seconds") 186 | 187 | def _interpret(seqs, labels, model, device, input_type, criterion, 188 | threshold, filter_size, rev_complement, output_dir="./", 189 | batch_size=100, num_well_pred_seqs=None): 190 | 191 | # Initialize 192 | activations = [] 193 | predictions = [] 194 | outputs = [] 195 | model.to(device) 196 | model.eval() 197 | 198 | # Get training DataLoader 199 | 
data_loader = get_data_loader(seqs, labels, batch_size) 200 | 201 | # Get rev. complement 202 | if rev_complement: 203 | rev_seqs = np.array([s[::-1, ::-1] for s in seqs]) 204 | rev_data_loader = get_data_loader(rev_seqs, labels, batch_size) 205 | else: 206 | rev_seqs = None 207 | rev_data_loader = None 208 | 209 | # Get predictions 210 | for dl in [data_loader, rev_data_loader]: 211 | if dl is None: # skip 212 | continue 213 | preds, labels = get_explainn_predictions(dl, model, device, 214 | isSigmoid=False) 215 | predictions.append(preds) 216 | 217 | # Avg. predictions from both strands 218 | if len(predictions) == 2: 219 | avg_predictions = np.empty(predictions[0].shape) 220 | for i in range(predictions[0].shape[1]): 221 | avg_predictions[:, i] = np.mean([predictions[0][:, i], 222 | predictions[1][:, i]], axis=0) 223 | else: 224 | avg_predictions = predictions[0] 225 | if input_type == "binary": 226 | for i in range(avg_predictions.shape[1]): 227 | avg_predictions[:, i] = \ 228 | torch.sigmoid(torch.from_numpy(avg_predictions[:, i])).numpy() 229 | 230 | # Get well-predicted sequences 231 | if criterion == "correlation": 232 | correlations = [] 233 | for i in range(len(avg_predictions)): 234 | x = np.corrcoef(labels[i, :], avg_predictions[i, :])[0, 1] 235 | correlations.append(x) 236 | idx = np.argwhere(np.asarray(correlations) > threshold).squeeze() 237 | elif criterion == "exact_match": 238 | arr = np.round(avg_predictions) # round predictions 239 | arr = np.equal(arr, labels) 240 | idx = np.argwhere(np.sum(arr, axis=1) == labels.shape[1]).squeeze() 241 | elif criterion == "percentile_bottom": 242 | threshold = threshold / 100. 243 | arr_1 = np.argsort(labels.flatten())[:int(max(labels.shape)*threshold)] 244 | arr_2 = np.argsort(avg_predictions.flatten())[:int(max(avg_predictions.shape)*threshold)] 245 | idx = np.intersect1d(arr_1, arr_2) 246 | elif criterion == "percentile_top": 247 | threshold = threshold / 100. 248 | arr_1 = np.argsort(-labels.flatten())[:int(max(labels.shape)*threshold)] 249 | arr_2 = np.argsort(-avg_predictions.flatten())[:int(max(avg_predictions.shape)*threshold)] 250 | idx = np.intersect1d(arr_1, arr_2) 251 | if num_well_pred_seqs: 252 | rng = np.random.default_rng() 253 | size = min(num_well_pred_seqs, len(idx)) 254 | idx = rng.choice(idx, size=size, replace=False) 255 | 256 | # Get training DataLoader 257 | seqs = seqs[idx] 258 | labels = labels[idx] 259 | data_loader = get_data_loader(seqs, labels, batch_size) 260 | 261 | # Get rev. complement 262 | if rev_complement: 263 | rev_seqs = np.array([s[::-1, ::-1] for s in seqs]) 264 | rev_data_loader = get_data_loader(rev_seqs, labels, batch_size) 265 | else: 266 | rev_seqs = None 267 | rev_data_loader = None 268 | 269 | # Get activations 270 | for dl in [data_loader, rev_data_loader]: 271 | if dl is None: # skip 272 | continue 273 | acts = get_danq_activations(dl, model, device) 274 | activations.append(acts) 275 | if rev_complement: 276 | seqs = np.concatenate((seqs, rev_seqs)) 277 | activations = np.concatenate(activations) 278 | else: 279 | activations = activations[0] 280 | 281 | # Get MEMEs 282 | meme_file = os.path.join(output_dir, "filters.meme") 283 | if not os.path.exists(meme_file): 284 | pwms = get_pwms_explainn(activations, seqs, filter_size) 285 | pwm_to_meme(pwms, meme_file) 286 | 287 | # Missing filter nullification step!!! 
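# The nullification step flagged above is not implemented in this script. The
# sketch below is NOT part of the original code: it shows one possible way to
# score filters by zeroing them out one at a time with a forward hook and
# measuring how much the predictions change. The convolutional module is
# passed in explicitly (e.g. the first Conv1d of the DanQ model) because its
# attribute name inside the model is not assumed here; the helper name is
# hypothetical and nothing in this script calls it.
def _nullify_filters_sketch(model, conv_layer, data_loader, device):
    """Hedged sketch: per-filter importance via single-filter nullification."""
    # Baseline predictions with all filters intact
    base_preds, _ = get_explainn_predictions(data_loader, model, device,
                                             isSigmoid=False)
    num_filters = conv_layer.out_channels
    importances = np.zeros((num_filters, base_preds.shape[1]))
    for i in range(num_filters):
        # Forward hook that zeroes the activations of filter i only
        def _zero_filter(module, inputs, output, idx=i):
            output = output.clone()
            output[:, idx, :] = 0.
            return output
        handle = conv_layer.register_forward_hook(_zero_filter)
        null_preds, _ = get_explainn_predictions(data_loader, model, device,
                                                 isSigmoid=False)
        handle.remove()
        # Importance of filter i = mean absolute change in predictions per class
        importances[i] = np.abs(base_preds - null_preds).mean(axis=0)
    return importances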
288 | 289 | if __name__ == "__main__": 290 | cli() -------------------------------------------------------------------------------- /scripts/DanQ/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from click_option_group import optgroup 5 | import json 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | from sklearn.metrics import average_precision_score, roc_auc_score 10 | import sys 11 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 12 | os.pardir)) 13 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 14 | os.pardir, 15 | os.pardir)) 16 | import time 17 | import torch 18 | 19 | from explainn.interpretation.interpretation import get_explainn_predictions 20 | from explainn.models.networks import DanQ 21 | from utils import (get_file_handle, get_seqs_labels_ids, get_data_loader, 22 | get_device) 23 | 24 | CONTEXT_SETTINGS = { 25 | "help_option_names": ["-h", "--help"], 26 | } 27 | 28 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 29 | @click.argument( 30 | "model_file", 31 | type=click.Path(exists=True, resolve_path=True), 32 | ) 33 | @click.argument( 34 | "training_parameters_file", 35 | type=click.Path(exists=True, resolve_path=True), 36 | ) 37 | @click.argument( 38 | "test_file", 39 | type=click.Path(exists=True, resolve_path=True), 40 | ) 41 | @click.option( 42 | "-c", "--cpu-threads", 43 | help="Number of CPU threads to use.", 44 | type=int, 45 | default=1, 46 | show_default=True, 47 | ) 48 | @click.option( 49 | "-d", "--debugging", 50 | help="Debugging mode.", 51 | is_flag=True, 52 | ) 53 | @click.option( 54 | "-o", "--output-dir", 55 | help="Output directory.", 56 | type=click.Path(resolve_path=True), 57 | default="./", 58 | show_default=True, 59 | ) 60 | @click.option( 61 | "-t", "--time", 62 | help="Return the program's running execution time in seconds.", 63 | is_flag=True, 64 | ) 65 | @optgroup.group("Test") 66 | @optgroup.option( 67 | "--batch-size", 68 | help="Batch size.", 69 | type=int, 70 | default=100, 71 | show_default=True, 72 | ) 73 | 74 | def cli(**args): 75 | 76 | # Start execution 77 | start_time = time.time() 78 | 79 | # Initialize 80 | if not os.path.exists(args["output_dir"]): 81 | os.makedirs(args["output_dir"]) 82 | 83 | # Save exec. parameters as JSON 84 | json_file = os.path.join(args["output_dir"], 85 | f"parameters-{os.path.basename(__file__)}.json") 86 | handle = get_file_handle(json_file, "wt") 87 | handle.write(json.dumps(args, indent=4, sort_keys=True)) 88 | handle.close() 89 | 90 | ############## 91 | # Load Data # 92 | ############## 93 | 94 | # Load training parameters 95 | handle = get_file_handle(args["training_parameters_file"], "rt") 96 | train_args = json.load(handle) 97 | handle.close() 98 | if "training_parameters_file" in train_args: # i.e. 
for fine-tuned models 99 | handle = get_file_handle(train_args["training_parameters_file"], "rt") 100 | train_args = json.load(handle) 101 | handle.close() 102 | 103 | # Get test sequences and labels 104 | seqs, labels, _ = get_seqs_labels_ids(args["test_file"], 105 | args["debugging"], 106 | False, 107 | train_args["input_length"]) 108 | 109 | ############## 110 | # Test # 111 | ############## 112 | 113 | # Infer input type, and the number of classes 114 | num_classes = labels[0].shape[0] 115 | if np.unique(labels[:, 0]).size == 2: 116 | input_type = "binary" 117 | else: 118 | input_type = "non-binary" 119 | 120 | # Get device 121 | device = get_device() 122 | 123 | # Get model 124 | m = DanQ(train_args["input_length"], num_classes, args["model_file"]) 125 | 126 | # Test 127 | _test(seqs, labels, m, device, input_type, train_args["rev_complement"], 128 | args["output_dir"], args["batch_size"]) 129 | 130 | # Finish execution 131 | seconds = format(time.time() - start_time, ".2f") 132 | if args["time"]: 133 | f = os.path.join(args["output_dir"], 134 | f"time-{os.path.basename(__file__)}.txt") 135 | handle = get_file_handle(f, "wt") 136 | handle.write(f"{seconds} seconds") 137 | handle.close() 138 | print(f"Execution time {seconds} seconds") 139 | 140 | def _test(seqs, labels, model, device, input_type, rev_complement, 141 | output_dir="./", batch_size=100): 142 | 143 | # Initialize 144 | predictions = [] 145 | model.to(device) 146 | model.eval() 147 | 148 | # Get training DataLoader 149 | data_loader = get_data_loader(seqs, labels, batch_size) 150 | 151 | # Get rev. complement 152 | if rev_complement: 153 | rev_seqs = np.array([s[::-1, ::-1] for s in seqs]) 154 | rev_data_loader = get_data_loader(rev_seqs, labels, batch_size) 155 | else: 156 | rev_seqs = None 157 | rev_data_loader = None 158 | 159 | for dl in [data_loader, rev_data_loader]: 160 | 161 | # Skip 162 | if dl is None: 163 | continue 164 | 165 | # Get predictions 166 | preds, labels = get_explainn_predictions(dl, model, device, 167 | isSigmoid=False) 168 | predictions.append(preds) 169 | 170 | # Avg. 
predictions from both strands 171 | if len(predictions) == 2: 172 | avg_predictions = np.empty(predictions[0].shape) 173 | for i in range(predictions[0].shape[1]): 174 | avg_predictions[:, i] = np.mean([predictions[0][:, i], 175 | predictions[1][:, i]], axis=0) 176 | else: 177 | avg_predictions = predictions[0] 178 | if input_type == "binary": 179 | for i in range(avg_predictions.shape[1]): 180 | avg_predictions[:, i] = \ 181 | torch.sigmoid(torch.from_numpy(avg_predictions[:, i])).numpy() 182 | 183 | # Get performance metrics 184 | metrics = __get_metrics(input_data=input_type) 185 | tsv_file = os.path.join(output_dir, "performance-metrics.tsv") 186 | if not os.path.exists(tsv_file): 187 | data = [] 188 | for m in metrics: 189 | data.append([m]) 190 | data[-1].append(metrics[m](labels, avg_predictions)) 191 | for i in range(labels.shape[1]): 192 | data[-1].append(metrics[m](labels[:, i], 193 | avg_predictions[:, i])) 194 | column_names = ["metric", "global"] + list(range(labels.shape[1])) 195 | df = pd.DataFrame(data, columns=column_names) 196 | df.to_csv(tsv_file, sep="\t", index=False) 197 | 198 | def __get_metrics(input_data="binary"): 199 | 200 | if input_data == "binary": 201 | return(dict(aucROC=roc_auc_score, aucPR=average_precision_score)) 202 | 203 | return(dict(Pearson=pearson_corrcoef)) 204 | 205 | def pearson_corrcoef(y_true, y_score): 206 | 207 | if y_true.ndim == 1: 208 | return np.corrcoef(y_true, y_score)[0, 1] 209 | else: 210 | if y_true.shape[1] == 1: 211 | return np.corrcoef(y_true, y_score)[0, 1] 212 | else: 213 | corrcoefs = [] 214 | for i in range(len(y_score)): 215 | x = np.corrcoef(y_true[i, :], y_score[i, :])[0, 1] 216 | corrcoefs.append(x) 217 | return np.mean(corrcoefs) 218 | 219 | if __name__ == "__main__": 220 | cli() -------------------------------------------------------------------------------- /scripts/DanQ/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from click_option_group import optgroup 5 | import json 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | import shutil 10 | import sys 11 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 12 | os.pardir)) 13 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 14 | os.pardir, 15 | os.pardir)) 16 | import time 17 | import torch 18 | 19 | from explainn.train.train import train_explainn 20 | from explainn.utils.tools import pearson_loss 21 | from explainn.models.networks import DanQ 22 | from utils import (get_file_handle, get_seqs_labels_ids, get_data_loader, 23 | get_device) 24 | 25 | CONTEXT_SETTINGS = { 26 | "help_option_names": ["-h", "--help"], 27 | } 28 | 29 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 30 | @click.argument( 31 | "training_file", 32 | type=click.Path(exists=True, resolve_path=True), 33 | ) 34 | @click.argument( 35 | "validation_file", 36 | type=click.Path(exists=True, resolve_path=True), 37 | ) 38 | @click.option( 39 | "-c", "--cpu-threads", 40 | help="Number of CPU threads to use.", 41 | type=int, 42 | default=1, 43 | show_default=True, 44 | ) 45 | @click.option( 46 | "-d", "--debugging", 47 | help="Debugging mode.", 48 | is_flag=True, 49 | ) 50 | @click.option( 51 | "-o", "--output-dir", 52 | help="Output directory.", 53 | type=click.Path(resolve_path=True), 54 | default="./", 55 | show_default=True, 56 | ) 57 | @click.option( 58 | "-t", "--time", 59 | help="Return the program's running 
execution time in seconds.", 60 | is_flag=True, 61 | ) 62 | @optgroup.group("DanQ") 63 | @optgroup.option( 64 | "--input-length", 65 | help="Input length (for longer and shorter sequences, trim or add padding, i.e. Ns, up to the specified length).", 66 | type=int, 67 | required=True, 68 | ) 69 | @optgroup.group("Optimizer") 70 | @optgroup.option( 71 | "--criterion", 72 | help="Loss (objective) function to use. Select \"BCEWithLogits\" for binary or multi-class classification tasks (e.g. predict the binding of one or more TFs to a sequence), \"CrossEntropy\" for multi-class classification tasks wherein only one solution is possible (e.g. predict the species of origin of a sequence between human, mouse or zebrafish), \"MSE\" for regression tasks (e.g. predict probe intensity signals), \"Pearson\" also for regression tasks (e.g. modeling accessibility across 81 cell types), and \"PoissonNLL\" for modeling count data (e.g. total number of reads at ChIP-/ATAC-seq peaks).", 73 | type=click.Choice(["BCEWithLogits", "CrossEntropy", "MSE", "Pearson", "PoissonNLL"], case_sensitive=False), 74 | required=True 75 | ) 76 | @optgroup.option( 77 | "--lr", 78 | help="Learning rate.", 79 | type=float, 80 | default=0.0005, 81 | show_default=True, 82 | ) 83 | @optgroup.option( 84 | "--optimizer", 85 | help="`torch.optim.Optimizer` with which to minimize the loss during training.", 86 | type=click.Choice(["Adam", "SGD"], case_sensitive=False), 87 | default="Adam", 88 | show_default=True, 89 | ) 90 | @optgroup.group("Training") 91 | @optgroup.option( 92 | "--batch-size", 93 | help="Batch size.", 94 | type=int, 95 | default=100, 96 | show_default=True, 97 | ) 98 | @optgroup.option( 99 | "--checkpoint", 100 | help="How often to save checkpoints (e.g. 1 means that the model will be saved after each epoch; by default, i.e. 0, only the best model will be saved).", 101 | type=int, 102 | default=0, 103 | show_default=True, 104 | ) 105 | @optgroup.option( 106 | "--num-epochs", 107 | help="Number of epochs to train the model.", 108 | type=int, 109 | default=100, 110 | show_default=True, 111 | ) 112 | @optgroup.option( 113 | "--patience", 114 | help="Number of epochs to wait before stopping training if the validation loss does not improve.", 115 | type=int, 116 | default=10, 117 | show_default=True, 118 | ) 119 | @optgroup.option( 120 | "--rev-complement", 121 | help="Reverse and complement training sequences.", 122 | is_flag=True, 123 | ) 124 | @optgroup.option( 125 | "--trim-weights", 126 | help="Constrain output weights to be non-negative (i.e. to ease interpretation).", 127 | is_flag=True, 128 | ) 129 | 130 | def cli(**args): 131 | 132 | # Start execution 133 | start_time = time.time() 134 | 135 | # Initialize 136 | if not os.path.exists(args["output_dir"]): 137 | os.makedirs(args["output_dir"]) 138 | 139 | # Save exec. 
parameters as JSON 140 | json_file = os.path.join(args["output_dir"], 141 | f"parameters-{os.path.basename(__file__)}.json") 142 | handle = get_file_handle(json_file, "wt") 143 | handle.write(json.dumps(args, indent=4, sort_keys=True)) 144 | handle.close() 145 | 146 | ############## 147 | # Load Data # 148 | ############## 149 | 150 | # Get training/test sequences and labels 151 | train_seqs, train_labels, _ = get_seqs_labels_ids(args["training_file"], 152 | args["debugging"], 153 | args["rev_complement"], 154 | args["input_length"]) 155 | test_seqs, test_labels, _ = get_seqs_labels_ids(args["validation_file"], 156 | args["debugging"], 157 | args["rev_complement"], 158 | args["input_length"]) 159 | 160 | # Get training/test DataLoaders 161 | train_loader = get_data_loader(train_seqs, train_labels, 162 | args["batch_size"], shuffle=True) 163 | test_loader = get_data_loader(test_seqs, test_labels, 164 | args["batch_size"], shuffle=True) 165 | 166 | ############## 167 | # Train # 168 | ############## 169 | 170 | # Infer input length/type, and the number of classes 171 | # input_length = train_seqs[0].shape[1] 172 | num_classes = train_labels[0].shape[0] 173 | 174 | # Get device 175 | device = get_device() 176 | 177 | # Get criterion 178 | if args["criterion"].lower() == "bcewithlogits": 179 | criterion = torch.nn.BCEWithLogitsLoss() 180 | elif args["criterion"].lower() == "crossentropy": 181 | criterion = torch.nn.CrossEntropyLoss() 182 | elif args["criterion"].lower() == "mse": 183 | criterion = torch.nn.MSELoss() 184 | elif args["criterion"].lower() == "pearson": 185 | criterion = pearson_loss 186 | elif args["criterion"].lower() == "poissonnll": 187 | criterion = torch.nn.PoissonNLLLoss() 188 | 189 | # Get model and optimizer 190 | m = DanQ(args["input_length"], num_classes) 191 | 192 | # Get optimizer 193 | o = _get_optimizer(args["optimizer"], m.parameters(), args["lr"]) 194 | 195 | # Train 196 | _train(train_loader, test_loader, m, device, criterion, o, 197 | args["num_epochs"], args["output_dir"], None, True, False, 198 | args["checkpoint"], args["patience"]) 199 | 200 | # Finish execution 201 | seconds = format(time.time() - start_time, ".2f") 202 | if args["time"]: 203 | f = os.path.join(args["output_dir"], 204 | f"time-{os.path.basename(__file__)}.txt") 205 | handle = get_file_handle(f, "wt") 206 | handle.write(f"{seconds} seconds") 207 | handle.close() 208 | print(f"Execution time {seconds} seconds") 209 | 210 | def _get_optimizer(optimizer, parameters, lr=0.0005): 211 | 212 | if optimizer.lower() == "adam": 213 | return torch.optim.Adam(parameters, lr=lr) 214 | elif optimizer.lower() == "sgd": 215 | return torch.optim.SGD(parameters, lr=lr) 216 | 217 | def _train(train_loader, test_loader, model, device, criterion, optimizer, 218 | num_epochs=100, output_dir="./", name_ind=None, verbose=False, 219 | trim_weights=False, checkpoint=0, patience=0): 220 | 221 | # Initialize 222 | model.to(device) 223 | 224 | # Train 225 | _, train_error, test_error = train_explainn(train_loader, test_loader, 226 | model, device, criterion, 227 | optimizer, num_epochs, 228 | output_dir, name_ind, 229 | verbose, trim_weights, 230 | checkpoint, patience) 231 | 232 | # Save losses 233 | df = pd.DataFrame(list(zip(train_error, test_error)), 234 | columns=["Train loss", "Validation loss"]) 235 | df.index += 1 236 | df.index.rename("Epoch", inplace=True) 237 | df.to_csv(os.path.join(output_dir, "losses.tsv"), sep="\t") 238 | 239 | if __name__ == "__main__": 240 | cli() 
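# The script above is a Click CLI; a typical run (illustrative only, file
# names and values are placeholders) would look like:
#   ./train.py --input-length 200 --criterion BCEWithLogits --num-epochs 50 \
#       -o ./DanQ_model train.tsv.gz validation.tsv.gz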
-------------------------------------------------------------------------------- /scripts/DanQ/tsv2predictions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from click_option_group import optgroup 5 | import json 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | import sys 10 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 11 | os.pardir)) 12 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 13 | os.pardir, 14 | os.pardir)) 15 | import time 16 | import torch 17 | from tqdm import tqdm 18 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 19 | 20 | from explainn.models.networks import DanQ 21 | from utils import (get_seqs_labels_ids, get_file_handle, get_data_loader, 22 | get_device) 23 | 24 | CONTEXT_SETTINGS = { 25 | "help_option_names": ["-h", "--help"], 26 | } 27 | 28 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 29 | @click.argument( 30 | "model_file", 31 | type=click.Path(exists=True, resolve_path=True) 32 | ) 33 | @click.argument( 34 | "training_parameters_file", 35 | type=click.Path(exists=True, resolve_path=True), 36 | ) 37 | @click.argument( 38 | "tsv_file", 39 | type=click.Path(exists=True, resolve_path=True), 40 | ) 41 | @click.option( 42 | "-c", "--cpu-threads", 43 | help="Number of CPU threads to use.", 44 | type=int, 45 | default=1, 46 | show_default=True, 47 | ) 48 | @click.option( 49 | "-d", "--debugging", 50 | help="Debugging mode.", 51 | is_flag=True, 52 | ) 53 | @click.option( 54 | "-o", "--output-dir", 55 | help="Output directory.", 56 | type=click.Path(resolve_path=True), 57 | default="./", 58 | show_default=True, 59 | ) 60 | @click.option( 61 | "-t", "--time", 62 | help="Return the program's running execution time in seconds.", 63 | is_flag=True, 64 | ) 65 | @optgroup.group("Predict") 66 | @optgroup.option( 67 | "--apply-sigmoid", 68 | help="Apply the logistic sigmoid function to outputs.", 69 | is_flag=True, 70 | ) 71 | @optgroup.option( 72 | "--batch-size", 73 | help="Batch size.", 74 | type=int, 75 | default=100, 76 | show_default=True, 77 | ) 78 | 79 | def cli(**args): 80 | 81 | # Start execution 82 | start_time = time.time() 83 | 84 | # Initialize 85 | if not os.path.exists(args["output_dir"]): 86 | os.makedirs(args["output_dir"]) 87 | 88 | # Save exec. parameters as JSON 89 | json_file = os.path.join(args["output_dir"], 90 | f"parameters-{os.path.basename(__file__)}.json") 91 | handle = get_file_handle(json_file, "wt") 92 | handle.write(json.dumps(args, indent=4, sort_keys=True)) 93 | handle.close() 94 | 95 | ############## 96 | # Load Data # 97 | ############## 98 | 99 | # Load training parameters 100 | handle = get_file_handle(args["training_parameters_file"], "rt") 101 | train_args = json.load(handle) 102 | handle.close() 103 | if "training_parameters_file" in train_args: # i.e. 
for fine-tuned models 104 | handle = get_file_handle(train_args["training_parameters_file"], "rt") 105 | train_args = json.load(handle) 106 | handle.close() 107 | 108 | # Get test sequences and labels 109 | seqs, _, ids = get_seqs_labels_ids(args["tsv_file"], 110 | args["debugging"], 111 | False, 112 | train_args["input_length"]) 113 | 114 | ############## 115 | # Predict # 116 | ############## 117 | 118 | # Get device 119 | device = get_device() 120 | 121 | # Get model 122 | state_dict = torch.load(args["model_file"]) 123 | for k in reversed(state_dict.keys()): 124 | num_classes = state_dict[k].shape[0] 125 | break 126 | 127 | # Get model 128 | m = DanQ(train_args["input_length"], num_classes, args["model_file"]) 129 | 130 | # Test 131 | _predict(seqs, ids, num_classes, m, device, args["output_dir"], 132 | args["apply_sigmoid"], args["batch_size"]) 133 | 134 | # Finish execution 135 | seconds = format(time.time() - start_time, ".2f") 136 | if args["time"]: 137 | f = os.path.join(args["output_dir"], 138 | f"time-{os.path.basename(__file__)}.txt") 139 | handle = get_file_handle(f, "wt") 140 | handle.write(f"{seconds} seconds") 141 | handle.close() 142 | print(f"Execution time {seconds} seconds") 143 | 144 | def _predict(seqs, ids, num_classes, model, device, output_dir="./", 145 | apply_sigmoid=False, batch_size=100): 146 | 147 | # Initialize 148 | idx = 0 149 | predictions = np.empty((len(seqs), num_classes, 4)) 150 | model.to(device) 151 | model.eval() 152 | 153 | # Get training DataLoader 154 | data_loader = get_data_loader( 155 | seqs, 156 | np.array([s[::-1, ::-1] for s in seqs]), 157 | batch_size 158 | ) 159 | 160 | with torch.no_grad(): 161 | 162 | for fwd, rev in tqdm(iter(data_loader), total=len(data_loader), 163 | bar_format=bar_format): 164 | 165 | # Get strand-specific predictions 166 | fwd = np.expand_dims(model(fwd.to(device)).cpu().numpy(), axis=2) 167 | rev = np.expand_dims(model(rev.to(device)).cpu().numpy(), axis=2) 168 | 169 | # Combine predictions from both strands 170 | fwd_rev = np.concatenate((fwd, rev), axis=2) 171 | mean_fwd_rev = np.expand_dims(np.mean(fwd_rev, axis=2), axis=2) 172 | max_fwd_rev = np.expand_dims(np.max(fwd_rev, axis=2), axis=2) 173 | 174 | # Concatenate predictions for this batch 175 | p = np.concatenate((fwd, rev, mean_fwd_rev, max_fwd_rev), axis=2) 176 | predictions[idx:idx+fwd.shape[0]] = p 177 | 178 | # Index increase 179 | idx += fwd.shape[0] 180 | 181 | # Apply sigmoid 182 | if apply_sigmoid: 183 | predictions = torch.sigmoid(torch.Tensor(predictions)).numpy() 184 | 185 | # Get predictions 186 | tsv_file = os.path.join(output_dir, "predictions.tsv.gz") 187 | if not os.path.exists(tsv_file): 188 | dfs = [] 189 | for i in range(num_classes): 190 | p = predictions[:, i, :] 191 | df = pd.DataFrame(p, columns=["Fwd", "Rev", "Mean", "Max"]) 192 | df["SeqId"] = ids 193 | df["Class"] = i 194 | dfs.append(df) 195 | df = pd.concat(dfs)[["SeqId", "Class", "Fwd", "Rev", "Mean", "Max"]] 196 | df.reset_index(drop=True, inplace=True) 197 | df.to_csv(tsv_file, sep="\t", index=False) 198 | 199 | if __name__ == "__main__": 200 | cli() -------------------------------------------------------------------------------- /scripts/fasta2predictions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from Bio import SeqIO 4 | import click 5 | from click_option_group import optgroup 6 | import json 7 | import numpy as np 8 | import os 9 | import pandas as pd 10 | import sys 11 | sys.path.insert(0, 
os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 12 | os.pardir)) 13 | import time 14 | import torch 15 | from tqdm import tqdm 16 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 17 | 18 | from explainn.models.networks import ExplaiNN 19 | from utils import (_dna_one_hot_many, _resize_sequence, get_file_handle, 20 | get_data_loader, get_device) 21 | 22 | CONTEXT_SETTINGS = { 23 | "help_option_names": ["-h", "--help"], 24 | } 25 | 26 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 27 | @click.argument( 28 | "model_file", 29 | type=click.Path(exists=True, resolve_path=True) 30 | ) 31 | @click.argument( 32 | "training_parameters_file", 33 | type=click.Path(exists=True, resolve_path=True), 34 | ) 35 | @click.argument( 36 | "fasta_file", 37 | type=click.Path(exists=True, resolve_path=True), 38 | nargs=-1 39 | ) 40 | @click.option( 41 | "-c", "--cpu-threads", 42 | help="Number of CPU threads to use.", 43 | type=int, 44 | default=1, 45 | show_default=True, 46 | ) 47 | @click.option( 48 | "-o", "--output-dir", 49 | help="Output directory.", 50 | type=click.Path(resolve_path=True), 51 | default="./", 52 | show_default=True, 53 | ) 54 | @click.option( 55 | "-t", "--time", 56 | help="Return the program's running execution time in seconds.", 57 | is_flag=True, 58 | ) 59 | @optgroup.group("Predict") 60 | @optgroup.option( 61 | "--apply-sigmoid", 62 | help="Apply the logistic sigmoid function to outputs.", 63 | is_flag=True, 64 | ) 65 | @optgroup.option( 66 | "--batch-size", 67 | help="Batch size.", 68 | type=int, 69 | default=100, 70 | show_default=True, 71 | ) 72 | 73 | def cli(**args): 74 | 75 | # Start execution 76 | start_time = time.time() 77 | 78 | # Initialize 79 | if not os.path.exists(args["output_dir"]): 80 | os.makedirs(args["output_dir"]) 81 | 82 | # Save exec. parameters as JSON 83 | json_file = os.path.join(args["output_dir"], 84 | f"parameters-{os.path.basename(__file__)}.json") 85 | handle = get_file_handle(json_file, "wt") 86 | handle.write(json.dumps(args, indent=4, sort_keys=True)) 87 | handle.close() 88 | 89 | ############## 90 | # Load Data # 91 | ############## 92 | 93 | # Load training parameters 94 | handle = get_file_handle(args["training_parameters_file"], "rt") 95 | train_args = json.load(handle) 96 | handle.close() 97 | if "training_parameters_file" in train_args: # i.e. 
for fine-tuned models 98 | handle = get_file_handle(train_args["training_parameters_file"], "rt") 99 | train_args = json.load(handle) 100 | handle.close() 101 | 102 | # Get sequences and ids 103 | seqs, ids = [], [] 104 | for fasta_file in args["fasta_file"]: 105 | fh = get_file_handle(fasta_file, "rt") 106 | records = list(SeqIO.parse(fh, "fasta")) 107 | fh.close() 108 | for record in records: 109 | seqs.append( 110 | _resize_sequence(str(record.seq), train_args["input_length"]) 111 | ) 112 | ids.append(record.id) 113 | seqs = _dna_one_hot_many(seqs) 114 | 115 | ############## 116 | # Predict # 117 | ############## 118 | 119 | # Get device 120 | device = get_device() 121 | 122 | # Get model 123 | state_dict = torch.load(args["model_file"]) 124 | for k in reversed(state_dict.keys()): 125 | num_classes = state_dict[k].shape[0] 126 | break 127 | 128 | # Get model 129 | m = ExplaiNN(train_args["num_units"], train_args["input_length"], 130 | num_classes, train_args["filter_size"], train_args["num_fc"], 131 | train_args["pool_size"], train_args["pool_stride"], 132 | args["model_file"]) 133 | 134 | # Test 135 | _predict(seqs, ids, num_classes, m, device, args["output_dir"], 136 | args["apply_sigmoid"], args["batch_size"]) 137 | 138 | # Finish execution 139 | seconds = format(time.time() - start_time, ".2f") 140 | if args["time"]: 141 | f = os.path.join(args["output_dir"], 142 | f"time-{os.path.basename(__file__)}.txt") 143 | handle = get_file_handle(f, "wt") 144 | handle.write(f"{seconds} seconds") 145 | handle.close() 146 | print(f"Execution time {seconds} seconds") 147 | 148 | def _predict(seqs, ids, num_classes, model, device, output_dir="./", 149 | apply_sigmoid=False, batch_size=100): 150 | 151 | # Initialize 152 | idx = 0 153 | predictions = np.empty((len(seqs), num_classes, 4)) 154 | model.to(device) 155 | model.eval() 156 | 157 | # Get training DataLoader 158 | data_loader = get_data_loader( 159 | seqs, 160 | np.array([s[::-1, ::-1] for s in seqs]), 161 | batch_size 162 | ) 163 | 164 | with torch.no_grad(): 165 | 166 | for fwd, rev in tqdm(iter(data_loader), total=len(data_loader), 167 | bar_format=bar_format): 168 | 169 | # Get strand-specific predictions 170 | fwd = np.expand_dims(model(fwd.to(device)).cpu().numpy(), axis=2) 171 | rev = np.expand_dims(model(rev.to(device)).cpu().numpy(), axis=2) 172 | 173 | # Combine predictions from both strands 174 | fwd_rev = np.concatenate((fwd, rev), axis=2) 175 | mean_fwd_rev = np.expand_dims(np.mean(fwd_rev, axis=2), axis=2) 176 | max_fwd_rev = np.expand_dims(np.max(fwd_rev, axis=2), axis=2) 177 | 178 | # Concatenate predictions for this batch 179 | p = np.concatenate((fwd, rev, mean_fwd_rev, max_fwd_rev), axis=2) 180 | predictions[idx:idx+fwd.shape[0]] = p 181 | 182 | # Index increase 183 | idx += fwd.shape[0] 184 | 185 | # Apply sigmoid 186 | if apply_sigmoid: 187 | predictions = torch.sigmoid(torch.Tensor(predictions)).numpy() 188 | 189 | # Get predictions 190 | tsv_file = os.path.join(output_dir, "predictions.tsv.gz") 191 | if not os.path.exists(tsv_file): 192 | dfs = [] 193 | for i in range(num_classes): 194 | p = predictions[:, i, :] 195 | df = pd.DataFrame(p, columns=["Fwd", "Rev", "Mean", "Max"]) 196 | df["SeqId"] = ids 197 | df["Class"] = i 198 | dfs.append(df) 199 | df = pd.concat(dfs)[["SeqId", "Class", "Fwd", "Rev", "Mean", "Max"]] 200 | df.reset_index(drop=True, inplace=True) 201 | df.to_csv(tsv_file, sep="\t", index=False) 202 | 203 | if __name__ == "__main__": 204 | cli() 
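# The predictions table written above has one row per (SeqId, Class) with
# strand-specific and combined scores; an illustrative way to load it
# (assuming pandas is available) is:
#   import pandas as pd
#   preds = pd.read_table("predictions.tsv.gz")
#   mean_scores = preds.pivot(index="SeqId", columns="Class", values="Mean")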
-------------------------------------------------------------------------------- /scripts/finetune.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from click_option_group import optgroup 5 | import json 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | import shutil 10 | import sys 11 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 12 | os.pardir)) 13 | import time 14 | import torch 15 | 16 | from explainn.train.train import train_explainn 17 | from explainn.utils.tools import pearson_loss 18 | from explainn.models.networks import ExplaiNN 19 | from utils import (get_file_handle, get_seqs_labels_ids, get_data_loader, 20 | get_device) 21 | 22 | CONTEXT_SETTINGS = { 23 | "help_option_names": ["-h", "--help"], 24 | } 25 | 26 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 27 | @click.argument( 28 | "model_file", 29 | type=click.Path(exists=True, resolve_path=True), 30 | ) 31 | @click.argument( 32 | "training_parameters_file", 33 | type=click.Path(exists=True, resolve_path=True), 34 | ) 35 | @click.argument( 36 | "training_file", 37 | type=click.Path(exists=True, resolve_path=True), 38 | ) 39 | @click.argument( 40 | "validation_file", 41 | type=click.Path(exists=True, resolve_path=True), 42 | ) 43 | @click.option( 44 | "-c", "--cpu-threads", 45 | help="Number of CPU threads to use.", 46 | type=int, 47 | default=1, 48 | show_default=True, 49 | ) 50 | @click.option( 51 | "-d", "--debugging", 52 | help="Debugging mode.", 53 | is_flag=True, 54 | ) 55 | @click.option( 56 | "-o", "--output-dir", 57 | help="Output directory.", 58 | type=click.Path(resolve_path=True), 59 | default="./", 60 | show_default=True, 61 | ) 62 | @click.option( 63 | "-t", "--time", 64 | help="Return the program's running execution time in seconds.", 65 | is_flag=True, 66 | ) 67 | @optgroup.group("Optimizer") 68 | @optgroup.option( 69 | "--criterion", 70 | help="Loss (objective) function to use. Select \"BCEWithLogits\" for binary or multi-class classification tasks (e.g. predict the binding of one or more TFs to a sequence), \"CrossEntropy\" for multi-class classification tasks wherein only one solution is possible (e.g. predict the species of origin of a sequence between human, mouse or zebrafish), \"MSE\" for regression tasks (e.g. predict probe intensity signals), \"Pearson\" also for regression tasks (e.g. modeling accessibility across 81 cell types), and \"PoissonNLL\" for modeling count data (e.g. total number of reads at ChIP-/ATAC-seq peaks).", 71 | type=click.Choice(["BCEWithLogits", "CrossEntropy", "MSE", "Pearson", "PoissonNLL"], case_sensitive=False), 72 | required=True 73 | ) 74 | @optgroup.option( 75 | "--lr", 76 | help="Learning rate.", 77 | type=float, 78 | default=5e-05, 79 | show_default=True, 80 | ) 81 | @optgroup.option( 82 | "--optimizer", 83 | help="`torch.optim.Optimizer` with which to minimize the loss during training.", 84 | type=click.Choice(["Adam", "SGD"], case_sensitive=False), 85 | default="Adam", 86 | show_default=True, 87 | ) 88 | @optgroup.group("Fine-tuning") 89 | @optgroup.option( 90 | "--batch-size", 91 | help="Batch size.", 92 | type=int, 93 | default=100, 94 | show_default=True, 95 | ) 96 | @optgroup.option( 97 | "--checkpoint", 98 | help="How often to save checkpoints (e.g. 1 means that the model will be saved after each epoch; by default, i.e. 
0, only the best model will be saved).", 99 | type=int, 100 | default=0, 101 | show_default=True, 102 | ) 103 | @optgroup.option( 104 | "--freeze", 105 | help="Do not update the model weights during training.", 106 | is_flag=True, 107 | ) 108 | @optgroup.option( 109 | "--num-epochs", 110 | help="Number of epochs to train the model.", 111 | type=int, 112 | default=100, 113 | show_default=True, 114 | ) 115 | @optgroup.option( 116 | "--patience", 117 | help="Number of epochs to wait before stopping training if the validation loss does not improve.", 118 | type=int, 119 | default=10, 120 | show_default=True, 121 | ) 122 | @optgroup.option( 123 | "--rev-complement", 124 | help="Reverse and complement training sequences.", 125 | is_flag=True, 126 | ) 127 | @optgroup.option( 128 | "--trim-weights", 129 | help="Constrain output weights to be non-negative (i.e. to ease interpretation).", 130 | is_flag=True, 131 | ) 132 | 133 | def cli(**args): 134 | 135 | # Start execution 136 | start_time = time.time() 137 | 138 | # Initialize 139 | if not os.path.exists(args["output_dir"]): 140 | os.makedirs(args["output_dir"]) 141 | 142 | # Save exec. parameters as JSON 143 | json_file = os.path.join(args["output_dir"], 144 | f"parameters-{os.path.basename(__file__)}.json") 145 | handle = get_file_handle(json_file, "wt") 146 | handle.write(json.dumps(args, indent=4, sort_keys=True)) 147 | handle.close() 148 | 149 | ############## 150 | # Load Data # 151 | ############## 152 | 153 | # Load training parameters 154 | handle = get_file_handle(args["training_parameters_file"], "rt") 155 | train_args = json.load(handle) 156 | handle.close() 157 | 158 | # Get training/test sequences and labels 159 | train_seqs, train_labels, _ = get_seqs_labels_ids(args["training_file"], 160 | args["debugging"], 161 | args["rev_complement"], 162 | train_args["input_length"]) 163 | test_seqs, test_labels, _ = get_seqs_labels_ids(args["validation_file"], 164 | args["debugging"], 165 | args["rev_complement"], 166 | train_args["input_length"]) 167 | 168 | # Get training/test DataLoaders 169 | train_loader = get_data_loader(train_seqs, train_labels, 170 | args["batch_size"], shuffle=True) 171 | test_loader = get_data_loader(test_seqs, test_labels, 172 | args["batch_size"], shuffle=True) 173 | 174 | # Load pre-trained state dict 175 | state_dict_pretrain = torch.load(args["model_file"]) 176 | 177 | ############## 178 | # Fine-tune # 179 | ############## 180 | 181 | # Infer input length/type, and the number of classes 182 | # input_length = train_seqs[0].shape[1] 183 | num_classes = train_labels[0].shape[0] 184 | 185 | # Get device 186 | device = get_device() 187 | 188 | # Get criterion 189 | if args["criterion"].lower() == "bcewithlogits": 190 | criterion = torch.nn.BCEWithLogitsLoss() 191 | elif args["criterion"].lower() == "crossentropy": 192 | criterion = torch.nn.CrossEntropyLoss() 193 | elif args["criterion"].lower() == "mse": 194 | criterion = torch.nn.MSELoss() 195 | elif args["criterion"].lower() == "pearson": 196 | criterion = pearson_loss 197 | elif args["criterion"].lower() == "poissonnll": 198 | criterion = torch.nn.PoissonNLLLoss() 199 | 200 | # Get model 201 | m = ExplaiNN(train_args["num_units"], train_args["input_length"], 202 | num_classes, train_args["filter_size"], 203 | train_args["num_fc"], train_args["pool_size"], 204 | train_args["pool_stride"]) 205 | 206 | # Get optimizer 207 | o = _get_optimizer(args["optimizer"], m.parameters(), args["lr"]) 208 | 209 | # Transfer learning 210 | state_dict = m.state_dict() 211 | for 
k in state_dict: 212 | if not k.startswith("final"): # do not transfer the final layer 213 | state_dict[k] = state_dict_pretrain[k] 214 | m.load_state_dict(state_dict) 215 | 216 | # Freeze weights 217 | if args["freeze"]: 218 | for n, p in m.named_parameters(): 219 | # Allow learning the weights of the final layer 220 | if not n.startswith("final") and p.requires_grad: 221 | p.requires_grad = False 222 | 223 | # Fine-tune 224 | _finetune(train_loader, test_loader, m, device, criterion, o, 225 | args["num_epochs"], args["output_dir"], None, True, False, 226 | args["checkpoint"], args["patience"]) 227 | 228 | # Finish execution 229 | seconds = format(time.time() - start_time, ".2f") 230 | if args["time"]: 231 | f = os.path.join(args["output_dir"], 232 | f"time-{os.path.basename(__file__)}.txt") 233 | handle = get_file_handle(f, "wt") 234 | handle.write(f"{seconds} seconds") 235 | handle.close() 236 | print(f"Execution time {seconds} seconds") 237 | 238 | def _get_optimizer(optimizer, parameters, lr=5e-05): 239 | 240 | if optimizer.lower() == "adam": 241 | return torch.optim.Adam(parameters, lr=lr) 242 | elif optimizer.lower() == "sgd": 243 | return torch.optim.SGD(parameters, lr=lr) 244 | 245 | def _finetune(train_loader, test_loader, model, device, criterion, optimizer, 246 | num_epochs=100, output_dir="./", name_ind=None, verbose=False, 247 | trim_weights=False, checkpoint=0, patience=0): 248 | 249 | # Initialize 250 | model.to(device) 251 | 252 | # Train 253 | _, train_error, test_error = train_explainn(train_loader, test_loader, 254 | model, device, criterion, 255 | optimizer, num_epochs, 256 | output_dir, name_ind, 257 | verbose, trim_weights, 258 | checkpoint, patience) 259 | 260 | # Save losses 261 | df = pd.DataFrame(list(zip(train_error, test_error)), 262 | columns=["Train loss", "Validation loss"]) 263 | df.index += 1 264 | df.index.rename("Epoch", inplace=True) 265 | df.to_csv(os.path.join(output_dir, "losses.tsv"), sep="\t") 266 | 267 | if __name__ == "__main__": 268 | cli() 269 | -------------------------------------------------------------------------------- /scripts/parsers/GRECO-BIT/afs+hts2explainn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from functools import partial 5 | from multiprocessing import Pool 6 | import os 7 | import pandas as pd 8 | import re 9 | import subprocess as sp 10 | import sys 11 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 12 | os.pardir, 13 | os.pardir)) 14 | from tqdm import tqdm 15 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 16 | 17 | from parsers.fastq2explainn import _to_ExplaiNN 18 | 19 | CONTEXT_SETTINGS = { 20 | "help_option_names": ["-h", "--help"], 21 | } 22 | 23 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 24 | @click.argument( 25 | "reads_dir", type=click.Path(exists=True, resolve_path=True) 26 | ) 27 | @click.option( 28 | "-c", "--cpu-threads", 29 | help="Number of CPU threads to use.", 30 | type=int, 31 | default=1, 32 | show_default=True, 33 | ) 34 | @click.option( 35 | "-o", "--output-dir", 36 | help="Output directory.", 37 | type=click.Path(resolve_path=True), 38 | default="./", 39 | show_default=True 40 | ) 41 | 42 | def main(**args): 43 | 44 | # Create output dir 45 | if not os.path.exists(args["output_dir"]): 46 | os.makedirs(args["output_dir"]) 47 | 48 | # Get reads files 49 | reads_files = [] 50 | for reads_file in os.listdir(args["reads_dir"]): 51 | 
reads_files.append(os.path.join(args["reads_dir"], reads_file)) 52 | 53 | # Group reads files from different cycles 54 | grouped_reads_files = {} 55 | for reads_file in sorted(reads_files): 56 | m = re.search("^(\S+@\S+@\S+)\.C\d\.5\w+\.3\w+@\S+", 57 | os.path.basename(reads_file)) 58 | grouped_reads_files.setdefault(m.group(1), list()) 59 | grouped_reads_files[m.group(1)].append(reads_file) 60 | grouped_reads_files = list(grouped_reads_files.values()) 61 | 62 | # Get ExplaiNN files 63 | kwargs = {"total": len(grouped_reads_files), "bar_format": bar_format} 64 | pool = Pool(args["cpu_threads"]) 65 | p = partial(_get_ExplaiNN_files, output_dir=args["output_dir"]) 66 | for _ in tqdm(pool.imap(p, grouped_reads_files), **kwargs): 67 | pass 68 | 69 | def _get_ExplaiNN_files(reads_files, output_dir="./"): 70 | """ 71 | Naming format: 72 | AKAP8L@silly-willy+clever-peter@Megaman.HOMER@motif-id123.pcm 73 | """ 74 | 75 | # Initialize 76 | prefix = None 77 | suffixes = [] 78 | 79 | # For each read file... 80 | for i, reads_file in enumerate(sorted(reads_files)): 81 | 82 | # Get prefix and sufix 83 | m = re.search("^(\S+@\S+@\S+)\.C\d\.5\w+\.3\w+@\S+\.(\S+)\.\S+\.fastq", 84 | os.path.basename(reads_file)) 85 | if prefix is None: 86 | prefix = m.group(1) 87 | suffixes.append(m.group(2)) 88 | 89 | # Get prefix 90 | prefix += "@%s" % "+".join(suffixes) 91 | 92 | # Create train and validation sets 93 | _to_ExplaiNN(clip_left=None, clip_right=None, dummy_dir="/tmp/", 94 | fastq_1=reads_files, fastq_2=[], output_dir=output_dir, 95 | prefix=prefix, random_seed=1714, splits=[80, 20, 0]) 96 | 97 | # Create set for PWM scoring 98 | validation_file = os.path.join(output_dir, f"{prefix}.validation.tsv.gz") 99 | test_file = os.path.join(output_dir, f"{prefix}.pwm-scoring.tsv.gz") 100 | df = pd.read_table(validation_file, header=None) 101 | zeros = [1.] + [0. for i in range(len(df.columns) - 3)] 102 | ones = [0. for i in range(len(df.columns) - 3)] + [1.] 103 | sub_df = df.iloc[:, 2:] 104 | zeros = sub_df[sub_df == zeros].dropna() 105 | ones = sub_df[sub_df == ones].dropna() 106 | zeros = df.loc[zeros.index].iloc[:, :2] 107 | zeros[2] = 0. 108 | ones = df.loc[ones.index].iloc[:, :2] 109 | ones[2] = 1. 
110 | df = pd.concat((zeros, ones)).reset_index(drop=True) 111 | df.to_csv(test_file, sep="\t", header=False, index=False, 112 | compression="gzip") 113 | 114 | if __name__ == "__main__": 115 | main() -------------------------------------------------------------------------------- /scripts/parsers/GRECO-BIT/chs2explainn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from functools import partial 5 | from multiprocessing import Pool 6 | import os 7 | import pandas as pd 8 | from pybedtools import BedTool 9 | import re 10 | import subprocess as sp 11 | import sys 12 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 13 | os.pardir, 14 | os.pardir)) 15 | from tqdm import tqdm 16 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 17 | 18 | from parsers.bed2explainn import _to_ExplaiNN 19 | from utils import shuffle_string 20 | 21 | CONTEXT_SETTINGS = { 22 | "help_option_names": ["-h", "--help"], 23 | } 24 | 25 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 26 | @click.argument( 27 | "intervals_dir", type=click.Path(exists=True, resolve_path=True) 28 | ) 29 | @click.argument( 30 | "genome_file", type=click.Path(exists=True, resolve_path=True) 31 | ) 32 | @click.option( 33 | "-c", "--cpu-threads", 34 | help="Number of CPU threads to use.", 35 | type=int, 36 | default=1, 37 | show_default=True, 38 | ) 39 | @click.option( 40 | "-d", "--dummy-dir", 41 | help="Dummy directory.", 42 | type=click.Path(resolve_path=True), 43 | default="/tmp/", 44 | show_default=True 45 | ) 46 | @click.option( 47 | "-o", "--output-dir", 48 | help="Output directory.", 49 | type=click.Path(resolve_path=True), 50 | default="./", 51 | show_default=True 52 | ) 53 | 54 | def main(**args): 55 | 56 | # Create output dir 57 | if not os.path.exists(args["output_dir"]): 58 | os.makedirs(args["output_dir"]) 59 | 60 | # Get intervals files 61 | intervals_files = [] 62 | for intervals_file in os.listdir(args["intervals_dir"]): 63 | intervals_files.append(os.path.join(args["intervals_dir"], 64 | intervals_file)) 65 | 66 | # Get ExplaiNN files 67 | kwargs = {"total": len(intervals_files), "bar_format": bar_format} 68 | pool = Pool(args["cpu_threads"]) 69 | p = partial(_get_ExplaiNN_file, genome_file=args["genome_file"], 70 | dummy_dir=args["dummy_dir"], output_dir=args["output_dir"]) 71 | for _ in tqdm(pool.imap(p, intervals_files), **kwargs): 72 | pass 73 | 74 | def _get_ExplaiNN_file(intervals_file, genome_file, dummy_dir="/tmp/", 75 | output_dir="./"): 76 | """ 77 | Naming format: 78 | AKAP8L@silly-willy+clever-peter@Megaman.HOMER@motif-id123.pcm 79 | """ 80 | 81 | # Initialize 82 | prefix = None 83 | 84 | # Get prefix 85 | m = re.search("^(\S+@\S+@\S+)@\S+\.(\S+)\.(\S+)\.peaks$", 86 | os.path.basename(intervals_file)) 87 | prefix = m.group(1) + "@" + m.group(2) 88 | 89 | # Get DataFrame 90 | df = pd.read_table(intervals_file, header=0) 91 | df["START"] = df["abs_summit"] - 100 - 1 92 | df["END"] = df["abs_summit"] + 100 93 | 94 | # Get BED file 95 | intervals = [] 96 | bed_file = os.path.join(dummy_dir, 97 | "%s+%s+%s.bed" % (os.path.split(__file__)[1], 98 | str(os.getpid()), prefix)) 99 | for _, row in df.iterrows(): 100 | i = [row[0], row[1], row[2], row[8], row[4], "."] 101 | intervals.append("\t".join(map(str, i))) 102 | a = BedTool("\n".join([i for i in intervals]), from_string=True) 103 | a.saveas(bed_file) 104 | 105 | # To ExplaiNN 106 | _to_ExplaiNN([bed_file], genome_file, 
dummy_dir="/tmp/", 107 | output_dir=output_dir, prefix=prefix, random_seed=1714, 108 | splits=[80, 20, 0]) 109 | 110 | # Remove BED file 111 | os.remove(bed_file) 112 | 113 | # Create set for PWM scoring 114 | validation_file = os.path.join(output_dir, f"{prefix}.validation.tsv.gz") 115 | test_file = os.path.join(output_dir, f"{prefix}.pwm-scoring.tsv.gz") 116 | ones = pd.read_table(validation_file, header=None) 117 | zeros = ones.copy(deep=True) 118 | ones[2] = 1. 119 | zeros[0] += "_shuffled" 120 | zeros[1] = [shuffle_string(s, k=2, random_seed=1714) for s in ones[1].tolist()] 121 | zeros[2] = 0. 122 | df = pd.concat((ones, zeros)).reset_index(drop=True) 123 | df.to_csv(test_file, sep="\t", header=False, index=False, 124 | compression="gzip") 125 | 126 | if __name__ == "__main__": 127 | main() -------------------------------------------------------------------------------- /scripts/parsers/GRECO-BIT/hts2explainn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from functools import partial 5 | from multiprocessing import Pool 6 | import os 7 | import re 8 | import subprocess as sp 9 | from tqdm import tqdm 10 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 11 | 12 | from fastq2explainn import _to_ExplaiNN 13 | 14 | CONTEXT_SETTINGS = { 15 | "help_option_names": ["-h", "--help"], 16 | } 17 | 18 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 19 | @click.argument( 20 | "reads_dir", type=click.Path(exists=True, resolve_path=True) 21 | ) 22 | @click.option( 23 | "-c", "--cpu-threads", 24 | help="Number of CPU threads to use.", 25 | type=int, 26 | default=1, 27 | show_default=True, 28 | ) 29 | @click.option( 30 | "-o", "--output-dir", 31 | help="Output directory.", 32 | type=click.Path(resolve_path=True), 33 | default="./", 34 | show_default=True 35 | ) 36 | 37 | def main(**args): 38 | 39 | # Create output dir 40 | if not os.path.exists(args["output_dir"]): 41 | os.makedirs(args["output_dir"]) 42 | 43 | # Get reads files 44 | reads_files = [] 45 | for reads_file in os.listdir(args["reads_dir"]): 46 | reads_files.append(os.path.join(args["reads_dir"], reads_file)) 47 | 48 | # Group reads files from different cycles 49 | grouped_reads_files = {} 50 | for reads_file in sorted(reads_files): 51 | m = re.search("^(\S+@\S+@\S+)\.C\d\.5\w+\.3\w+@\S+", 52 | os.path.basename(reads_file)) 53 | grouped_reads_files.setdefault(m.group(1), list()) 54 | grouped_reads_files[m.group(1)].append(reads_file) 55 | grouped_reads_files = list(grouped_reads_files.values()) 56 | 57 | # Get ExplaiNN files 58 | kwargs = {"total": len(grouped_reads_files), "bar_format": bar_format} 59 | pool = Pool(args["cpu_threads"]) 60 | p = partial(_get_ExplaiNN_files, output_dir=args["output_dir"]) 61 | for _ in tqdm(pool.imap(p, grouped_reads_files), **kwargs): 62 | pass 63 | 64 | def _get_ExplaiNN_files(reads_files, output_dir="./"): 65 | """ 66 | Naming format: 67 | AKAP8L@silly-willy+clever-peter@Megaman.HOMER@motif-id123.pcm 68 | """ 69 | 70 | # Initialize 71 | prefix = None 72 | splits = [] 73 | prefixes = [] 74 | 75 | # For each read file... 
76 | for i, reads_file in enumerate(sorted(reads_files)): 77 | 78 | # Get data splits, prefixes 79 | m = re.search("^(\S+@\S+@\S+)\.C\d\.5\w+\.3\w+@\S+\.(\S+)\.(\S+)\.fastq.gz$", 80 | os.path.basename(reads_file)) 81 | if prefix is None: 82 | prefix = m.group(1) 83 | if m.group(3) == "Train": 84 | splits = [100, 0, 0] 85 | elif m.group(3) == "Val": 86 | splits = [0, 100, 0] 87 | prefixes.append(m.group(2)) 88 | 89 | # To ExplaiNN 90 | prefix += "@%s" % "+".join(prefixes) 91 | _to_ExplaiNN(clip_left=None, clip_right=None, dummy_dir="/tmp/", 92 | fastq_1=reads_files, fastq_2=[], output_dir=output_dir, 93 | prefix=prefix, random_seed=1714, splits=splits) 94 | 95 | if __name__ == "__main__": 96 | main() -------------------------------------------------------------------------------- /scripts/parsers/GRECO-BIT/pbm2explainn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from functools import partial 5 | import gzip 6 | from multiprocessing import Pool 7 | import os 8 | import pandas as pd 9 | import re 10 | from sklearn.preprocessing import quantile_transform 11 | import subprocess as sp 12 | import sys 13 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 14 | os.pardir, 15 | os.pardir)) 16 | from tqdm import tqdm 17 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 18 | 19 | from utils import get_data_splits 20 | 21 | CONTEXT_SETTINGS = { 22 | "help_option_names": ["-h", "--help"], 23 | } 24 | 25 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 26 | @click.argument( 27 | "intensities_dir", type=click.Path(exists=True, resolve_path=True) 28 | ) 29 | @click.option( 30 | "-c", "--cpu-threads", 31 | help="Number of CPU threads to use.", 32 | type=int, 33 | default=1, 34 | show_default=True, 35 | ) 36 | @click.option( 37 | "-n", "--no-linker", 38 | help="Exclude the linker sequence.", 39 | is_flag=True, 40 | ) 41 | @click.option( 42 | "-o", "--output-dir", 43 | help="Output directory.", 44 | type=click.Path(resolve_path=True), 45 | default="./", 46 | show_default=True 47 | ) 48 | @click.option( 49 | "-q", "--quantile-normalize", 50 | help="Quantile normalize signal intensities.", 51 | is_flag=True 52 | ) 53 | 54 | def main(**args): 55 | 56 | # Create output dir 57 | if not os.path.exists(args["output_dir"]): 58 | os.makedirs(args["output_dir"]) 59 | 60 | # Get intensity files 61 | intensity_files = [] 62 | for intensity_file in os.listdir(args["intensities_dir"]): 63 | intensity_files.append(os.path.join(args["intensities_dir"], 64 | intensity_file)) 65 | 66 | # Get ExplaiNN files 67 | kwargs = {"total": len(intensity_files), "bar_format": bar_format} 68 | pool = Pool(args["cpu_threads"]) 69 | p = partial(_to_ExplaiNN, no_linker=args["no_linker"], 70 | output_dir=args["output_dir"], 71 | quantile_normalize=args["quantile_normalize"]) 72 | for _ in tqdm(pool.imap(p, intensity_files), **kwargs): 73 | pass 74 | 75 | def _to_ExplaiNN(intensity_file, no_linker=False, output_dir="./", 76 | quantile_normalize=False): 77 | 78 | # Initialize 79 | prefix = None 80 | 81 | # Get prefix 82 | m = re.search("^(\S+@\S+@\S+)\.5\w+@\S+\.(\S+)\.(\S+)\.tsv$", 83 | os.path.basename(intensity_file)) 84 | prefix = m.group(1) + "@" + m.group(2) 85 | 86 | # Get DataFrame 87 | df = pd.read_table(intensity_file, header=0) 88 | if quantile_normalize: 89 | df.iloc[:, 7] = quantile_transform(df.iloc[:, 7].to_numpy().reshape(-1, 1), 90 | n_quantiles=10, random_state=0, copy=True) 91 | 
if not no_linker: 92 | df["pbm_sequence"] += df["linker_sequence"] 93 | df = df[["id_probe", "pbm_sequence", "mean_signal_intensity"]] 94 | 95 | # Get data splits 96 | train, validation, _ = get_data_splits(df, [80, 20, 0], 1714) 97 | 98 | # Save TSV files 99 | if train is not None: 100 | if prefix is None: 101 | tsv_file = os.path.join(output_dir, "train.tsv.gz") 102 | else: 103 | tsv_file = os.path.join(output_dir, f"{prefix}.train.tsv.gz") 104 | train.to_csv(tsv_file, sep="\t", header=False, index=False, 105 | compression="gzip") 106 | if validation is not None: 107 | if prefix is None: 108 | tsv_file = os.path.join(output_dir, "validation.tsv.gz") 109 | else: 110 | tsv_file = os.path.join(output_dir, f"{prefix}.validation.tsv.gz") 111 | validation.to_csv(tsv_file, sep="\t", header=False, index=False, 112 | compression="gzip") 113 | 114 | # Create set for PWM scoring 115 | validation_file = os.path.join(output_dir, f"{prefix}.validation.tsv.gz") 116 | test_file = os.path.join(output_dir, f"{prefix}.pwm-scoring.tsv.gz") 117 | df = pd.read_table(validation_file, header=None) 118 | ones = df.nlargest(int(len(df) * .05), 2) 119 | ones[2] = 1. 120 | zeros = df.nsmallest(int(len(df) * .05), 2) 121 | zeros[2] = 0. 122 | df = pd.concat((ones, zeros)).reset_index(drop=True) 123 | df.to_csv(test_file, sep="\t", header=False, index=False, 124 | compression="gzip") 125 | 126 | if __name__ == "__main__": 127 | main() -------------------------------------------------------------------------------- /scripts/parsers/GRECO-BIT/sms2explainn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from functools import partial 5 | from multiprocessing import Pool 6 | import os 7 | import re 8 | import subprocess as sp 9 | import sys 10 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 11 | os.pardir, 12 | os.pardir)) 13 | from tqdm import tqdm 14 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 15 | 16 | from parsers.fastq2explainn import _to_ExplaiNN 17 | 18 | CONTEXT_SETTINGS = { 19 | "help_option_names": ["-h", "--help"], 20 | } 21 | 22 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 23 | @click.argument( 24 | "reads_dir", type=click.Path(exists=True, resolve_path=True) 25 | ) 26 | @click.option( 27 | "-c", "--cpu-threads", 28 | help="Number of CPU threads to use.", 29 | type=int, 30 | default=1, 31 | show_default=True, 32 | ) 33 | @click.option( 34 | "-o", "--output-dir", 35 | help="Output directory.", 36 | type=click.Path(resolve_path=True), 37 | default="./", 38 | show_default=True 39 | ) 40 | 41 | def main(**args): 42 | 43 | # Create output dir 44 | if not os.path.exists(args["output_dir"]): 45 | os.makedirs(args["output_dir"]) 46 | 47 | # Get reads files 48 | reads_files = [] 49 | for reads_file in os.listdir(args["reads_dir"]): 50 | reads_files.append(os.path.join(args["reads_dir"], reads_file)) 51 | 52 | # Get ExplaiNN files 53 | kwargs = {"total": len(reads_files), "bar_format": bar_format} 54 | pool = Pool(args["cpu_threads"]) 55 | p = partial(_get_ExplaiNN_files, output_dir=args["output_dir"]) 56 | for _ in tqdm(pool.imap(p, reads_files), **kwargs): 57 | pass 58 | 59 | def _get_ExplaiNN_files(reads_file, output_dir="./"): 60 | """ 61 | Naming format: 62 | AKAP8L@silly-willy+clever-peter@Megaman.HOMER@motif-id123.pcm 63 | """ 64 | 65 | # Initialize 66 | prefix = None 67 | 68 | # Get prefix 69 | m = 
re.search("^(\S+@\S+@\S+).5\w+\.3\w+@\S+\.(\S+)\.\S+\.fastq", 70 | os.path.basename(reads_file)) 71 | prefix = m.group(1) + "@" + m.group(2) 72 | 73 | # To ExplaiNN 74 | _to_ExplaiNN(clip_left=None, clip_right=None, dummy_dir="/tmp/", 75 | fastq_1=[reads_file], fastq_2=[], output_dir=output_dir, 76 | prefix=prefix, random_seed=1714, splits=[80, 20, 0]) 77 | 78 | # Create set for PWM scoring 79 | validation_file = os.path.join(output_dir, f"{prefix}.validation.tsv.gz") 80 | test_file = os.path.join(output_dir, f"{prefix}.pwm-scoring.tsv.gz") 81 | os.symlink(validation_file, test_file) 82 | 83 | if __name__ == "__main__": 84 | main() -------------------------------------------------------------------------------- /scripts/parsers/de-novo/hts2explainn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from functools import partial 5 | from multiprocessing import Pool 6 | import os 7 | import pandas as pd 8 | import subprocess as sp 9 | from tqdm import tqdm 10 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 11 | 12 | # Globals 13 | scripts_dir = os.path.dirname(os.path.realpath(__file__)) 14 | 15 | CONTEXT_SETTINGS = { 16 | "help_option_names": ["-h", "--help"], 17 | } 18 | 19 | from utils import get_file_handle 20 | 21 | # Globals 22 | scripts_dir = os.path.dirname(os.path.realpath(__file__)) 23 | 24 | CONTEXT_SETTINGS = { 25 | "help_option_names": ["-h", "--help"], 26 | } 27 | 28 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 29 | @click.argument( 30 | "reads_dir", type=click.Path(exists=True, resolve_path=True) 31 | ) 32 | @click.argument( 33 | "tsv_file", type=click.Path(exists=True, resolve_path=True) 34 | ) 35 | @click.option( 36 | "-o", "--output-dir", 37 | help="Output directory.", 38 | type=click.Path(resolve_path=True), 39 | default="./", 40 | show_default=True 41 | ) 42 | @click.option( 43 | "-t", "--threads", 44 | help="Threads to use.", 45 | type=int, 46 | default=1, 47 | show_default=True 48 | ) 49 | 50 | def main(**args): 51 | 52 | # Create output dir 53 | if not os.path.exists(args["output_dir"]): 54 | os.makedirs(args["output_dir"]) 55 | fastq_dir = os.path.join(args["output_dir"], "FASTQ") 56 | if not os.path.exists(fastq_dir): 57 | os.makedirs(fastq_dir) 58 | 59 | # Get reads files 60 | tfs_reads_files = [] 61 | df = pd.read_csv(args["tsv_file"], sep="\t") 62 | df = df.groupby("TF").first().reset_index() 63 | for _, row in df.iterrows(): 64 | if row["SRA"] is not None: 65 | tfs_reads_files.append([row["TF"], []]) 66 | for sra in row["SRA"].split(";"): 67 | reads_file = os.path.join(args["reads_dir"], f"{sra}.fastq.gz") 68 | fastq_file = os.path.join(fastq_dir, f"{sra}.fastq.gz") 69 | if not os.path.exists(fastq_file): 70 | cmd = f"fastp -i {reads_file} -o {fastq_file} -A -G -w 8" 71 | _ = sp.run([cmd], shell=True, cwd=scripts_dir, 72 | stdout=sp.DEVNULL, stderr=sp.DEVNULL) 73 | tfs_reads_files[-1][-1].append(fastq_file) 74 | 75 | # Get FASTA sequences 76 | kwargs = {"total": len(tfs_reads_files), "bar_format": bar_format} 77 | pool = Pool(args["threads"]) 78 | p = partial(_to_ExplaiNN, output_dir=args["output_dir"]) 79 | for _ in tqdm(pool.imap(p, tfs_reads_files), **kwargs): 80 | pass 81 | 82 | def _to_ExplaiNN(tf_reads_files, output_dir="./"): 83 | 84 | # Initialize 85 | base_dir = os.path.split(os.path.realpath(__file__))[0] 86 | tf, reads_files = tf_reads_files 87 | 88 | # To ExplaiNN 89 | cmd = "%s/fastq2explainn.py -o %s -p %s %s" % \ 90 | (base_dir, 
output_dir, tf, " ".join(reads_files)) 91 | _ = sp.run([cmd], shell=True, cwd=scripts_dir, stdout=sp.DEVNULL, 92 | stderr=sp.DEVNULL) 93 | 94 | if __name__ == "__main__": 95 | main() -------------------------------------------------------------------------------- /scripts/parsers/de-novo/matrix2explainn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from functools import partial 5 | from multiprocessing import Pool 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | import pickle 10 | from pybedtools import BedTool 11 | from pybedtools.helpers import cleanup 12 | import re 13 | import subprocess as sp 14 | from tqdm import tqdm 15 | import warnings 16 | warnings.filterwarnings("ignore") 17 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 18 | 19 | from utils import get_file_handle 20 | 21 | # Globals 22 | scripts_dir = os.path.dirname(os.path.realpath(__file__)) 23 | 24 | CONTEXT_SETTINGS = { 25 | "help_option_names": ["-h", "--help"], 26 | } 27 | 28 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 29 | @click.argument( 30 | "matrix_dir", type=click.Path(exists=True, resolve_path=True) 31 | ) 32 | @click.argument( 33 | "genome_file", type=click.Path(exists=True, resolve_path=True) 34 | ) 35 | @click.argument( 36 | "regions_idx", type=click.Path(exists=True, resolve_path=True) 37 | ) 38 | @click.option( 39 | "-d", "--dummy-dir", 40 | help="Dummy directory.", 41 | type=click.Path(resolve_path=True), 42 | default="/tmp/", 43 | show_default=True 44 | ) 45 | @click.option( 46 | "-o", "--output-dir", 47 | help="Output directory.", 48 | type=click.Path(resolve_path=True), 49 | default="./", 50 | show_default=True 51 | ) 52 | @click.option( 53 | "-t", "--threads", 54 | help="Threads to use.", 55 | type=int, 56 | default=1, 57 | show_default=True 58 | ) 59 | 60 | def main(**args): 61 | 62 | # Create output dir 63 | if not os.path.exists(args["output_dir"]): 64 | os.makedirs(args["output_dir"]) 65 | 66 | # Get already processed TFs 67 | tfs = set() 68 | for tsv_file in os.listdir(args["output_dir"]): 69 | m = re.search("^(\S+).(train|validation|test).tsv.gz$", tsv_file) 70 | tfs.add(m.group(1)) 71 | 72 | # Get matrix files 73 | matrix_files = [] 74 | for matrix_file in os.listdir(args["matrix_dir"]): 75 | m = re.search("^matrix2d.(\S+).ReMap.sparse.npz$", matrix_file) 76 | if m.group(1) not in tfs: 77 | matrix_files.append(os.path.join(args["matrix_dir"], matrix_file)) 78 | 79 | # Get regions idx 80 | handle = get_file_handle(args["regions_idx"], mode="rb") 81 | regions_idx = pickle.load(handle) 82 | handle.close() 83 | idx_regions = {v: k for k, v in regions_idx.items()} 84 | 85 | # Get FASTA sequences 86 | kwargs = {"total": len(matrix_files), "bar_format": bar_format} 87 | pool = Pool(args["threads"]) 88 | p = partial(_to_ExplaiNN, genome_file=args["genome_file"], 89 | idx_regions=idx_regions, dummy_dir=args["dummy_dir"], 90 | output_dir=args["output_dir"]) 91 | for _ in tqdm(pool.imap(p, matrix_files), **kwargs): 92 | pass 93 | 94 | def _to_ExplaiNN(matrix_file, genome_file, idx_regions, 95 | dummy_dir="/tmp/", output_dir="./"): 96 | 97 | # Initialize 98 | prefix = re.search("^matrix2d.(\S+).ReMap.sparse.npz$", 99 | os.path.split(matrix_file)[1]).group(1) 100 | 101 | # Load matrix 2D as numpy array 102 | matrix2d = np.load(matrix_file)["arr_0"] 103 | 104 | # Get ones and zeros 105 | matrix1d = np.nanmax(matrix2d, axis=0) 106 | ones = np.where(matrix1d == 1.)[0] 107 | 
zeros = np.where(matrix1d == 0.)[0] 108 | 109 | # Get BedTool objects (i.e. positive/negative sequences) 110 | b = BedTool("\n".join(["\t".join(map(str, idx_regions[i])) \ 111 | for i in ones]), from_string=True).sort() 112 | b.sequence(fi=genome_file) 113 | positive_file = os.path.join(dummy_dir, "%s_pos.fa" % prefix) 114 | b.save_seqs(positive_file) 115 | b = BedTool("\n".join(["\t".join(map(str, idx_regions[i])) \ 116 | for i in zeros]), from_string=True).sort() 117 | b.sequence(fi=genome_file) 118 | negative_file = os.path.join(dummy_dir, "%s_neg.fa" % prefix) 119 | b.save_seqs(negative_file) 120 | 121 | # Subsample negative sequences by %GC 122 | json_file = os.path.join(dummy_dir, "%s.json" % prefix) 123 | cmd = "./utils/match-seqs-by-gc.py -f -o %s %s %s" % \ 124 | (json_file, negative_file, positive_file) 125 | _ = sp.run([cmd], shell=True, cwd=scripts_dir, stdout=sp.DEVNULL, 126 | stderr=sp.DEVNULL) 127 | 128 | # To ExplaiNN 129 | cmd = "./json2explainn.py -o %s -p %s --test %s" % \ 130 | (output_dir, prefix, json_file) 131 | _ = sp.run([cmd], shell=True, cwd=scripts_dir, stdout=sp.DEVNULL, 132 | stderr=sp.DEVNULL) 133 | 134 | # Delete tmp files 135 | cleanup() 136 | os.remove(positive_file) 137 | os.remove(negative_file) 138 | os.remove(json_file) 139 | 140 | if __name__ == "__main__": 141 | main() -------------------------------------------------------------------------------- /scripts/parsers/de-novo/pbm2explainn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from functools import partial 5 | from multiprocessing import Pool 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | from sklearn.preprocessing import quantile_transform 10 | import subprocess as sp 11 | from tqdm import tqdm 12 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 13 | 14 | # Globals 15 | scripts_dir = os.path.dirname(os.path.realpath(__file__)) 16 | 17 | CONTEXT_SETTINGS = { 18 | "help_option_names": ["-h", "--help"], 19 | } 20 | 21 | from utils import get_file_handle 22 | 23 | # Globals 24 | scripts_dir = os.path.dirname(os.path.realpath(__file__)) 25 | 26 | CONTEXT_SETTINGS = { 27 | "help_option_names": ["-h", "--help"], 28 | } 29 | 30 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 31 | @click.argument( 32 | "signal_intensities_dir", type=click.Path(exists=True, resolve_path=True) 33 | ) 34 | @click.argument( 35 | "tsv_file", type=click.Path(exists=True, resolve_path=True) 36 | ) 37 | @click.option( 38 | "-o", "--output-dir", 39 | help="Output directory.", 40 | type=click.Path(resolve_path=True), 41 | default="./", 42 | show_default=True 43 | ) 44 | @click.option( 45 | "-t", "--threads", 46 | help="Threads to use.", 47 | type=int, 48 | default=1, 49 | show_default=True 50 | ) 51 | 52 | def main(**args): 53 | 54 | # Create output dir 55 | if not os.path.exists(args["output_dir"]): 56 | os.makedirs(args["output_dir"]) 57 | 58 | # Get signal intensity files 59 | tfs_sig_int_files = [] 60 | df = pd.read_csv(args["tsv_file"], sep="\t") 61 | df = df.groupby("TF").first().reset_index() 62 | sig_int_files = {} 63 | for sig_int_file in os.listdir(args["signal_intensities_dir"]): 64 | tf = sig_int_file.split("_")[0] 65 | sig_int_files.setdefault(tf, []) 66 | sig_int_files[tf].append( 67 | os.path.join(args["signal_intensities_dir"], sig_int_file) 68 | ) 69 | for _, row in df.iterrows(): 70 | if row["UniPROBE"] is not None: 71 | tfs_sig_int_files.append([row["TF"], 72 | 
[sorted(sig_int_files[tf])]]) 73 | 74 | # Get FASTA sequences 75 | kwargs = {"total": len(tfs_sig_int_files), "bar_format": bar_format} 76 | pool = Pool(args["threads"]) 77 | p = partial(_to_ExplaiNN, output_dir=args["output_dir"]) 78 | for _ in tqdm(pool.imap(p, tfs_sig_int_files), **kwargs): 79 | pass 80 | 81 | def _to_ExplaiNN(tfs_sig_int_files, output_dir="./"): 82 | 83 | # Initialize 84 | data_splits = ["train", "validation"] 85 | tf, sig_int_files = tfs_sig_int_files 86 | rng = np.random.RandomState(0) 87 | 88 | # For each data split, signal intensity file... 89 | for data_split, sign_int_file in zip(data_splits, sig_int_files[0]): 90 | 91 | # Read signal intensities 92 | df = pd.read_csv(sign_int_file, header=None, sep="\t") 93 | 94 | # Quantile normalize intensity signals 95 | df.iloc[:, 0] = quantile_transform(df.iloc[:, 0].to_numpy().reshape(-1, 1), 96 | n_quantiles=10, random_state=0, copy=True) 97 | df["Index"] = df.index 98 | df = df.iloc[:, ::-1] 99 | 100 | # Save sequences 101 | tsv_file = os.path.join(output_dir, f"{tf}.{data_split}.tsv.gz") 102 | df = df.sample(frac=1) 103 | df.to_csv(tsv_file, sep="\t", header=False, index=False, 104 | compression="gzip") 105 | 106 | if __name__ == "__main__": 107 | main() -------------------------------------------------------------------------------- /scripts/parsers/de-novo/sms2explainn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from functools import partial 5 | from multiprocessing import Pool 6 | import os 7 | import pandas as pd 8 | import subprocess as sp 9 | from tqdm import tqdm 10 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 11 | 12 | # Globals 13 | scripts_dir = os.path.dirname(os.path.realpath(__file__)) 14 | 15 | CONTEXT_SETTINGS = { 16 | "help_option_names": ["-h", "--help"], 17 | } 18 | 19 | from utils import get_file_handle 20 | 21 | # Globals 22 | scripts_dir = os.path.dirname(os.path.realpath(__file__)) 23 | 24 | CONTEXT_SETTINGS = { 25 | "help_option_names": ["-h", "--help"], 26 | } 27 | 28 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 29 | @click.argument( 30 | "reads_dir", type=click.Path(exists=True, resolve_path=True) 31 | ) 32 | @click.argument( 33 | "tsv_file", type=click.Path(exists=True, resolve_path=True) 34 | ) 35 | @click.option( 36 | "-o", "--output-dir", 37 | help="Output directory.", 38 | type=click.Path(resolve_path=True), 39 | default="./", 40 | show_default=True 41 | ) 42 | @click.option( 43 | "-t", "--threads", 44 | help="Threads to use.", 45 | type=int, 46 | default=1, 47 | show_default=True 48 | ) 49 | 50 | def main(**args): 51 | 52 | # Create output dir 53 | if not os.path.exists(args["output_dir"]): 54 | os.makedirs(args["output_dir"]) 55 | fastq_dir = os.path.join(args["output_dir"], "FASTQ") 56 | if not os.path.exists(fastq_dir): 57 | os.makedirs(fastq_dir) 58 | 59 | # Get reads files 60 | tfs_reads_files = [] 61 | df = pd.read_csv(args["tsv_file"], sep="\t") 62 | df = df.groupby("TF").first().reset_index() 63 | for _, row in df.iterrows(): 64 | if row["SRA"] is not None: 65 | tfs_reads_files.append([row["TF"], []]) 66 | for sra in row["SRA"].split(";"): 67 | reads_file = os.path.join(args["reads_dir"], f"{sra}.fastq.gz") 68 | fastq_file = os.path.join(fastq_dir, f"{sra}.fastq.gz") 69 | if not os.path.exists(fastq_file): 70 | cmd = f"fastp -i {reads_file} -o {fastq_file} -A -G -w 8" 71 | _ = sp.run([cmd], shell=True, cwd=scripts_dir, 72 | stdout=sp.DEVNULL, 
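# --- Illustrative note (added, not part of the original script) ---
# For reference: in fastp, -A disables adapter trimming, -G disables poly-G
# trimming and -w sets the number of worker threads, so the command above acts
# essentially as a quality filter that rewrites each SRA FASTQ into fastq_dir
# before the reads are handed to fastq2explainn.py below.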
stderr=sp.DEVNULL) 73 | tfs_reads_files[-1][-1].append(fastq_file) 74 | 75 | # Get FASTA sequences 76 | kwargs = {"total": len(tfs_reads_files), "bar_format": bar_format} 77 | pool = Pool(args["threads"]) 78 | p = partial(_to_ExplaiNN, output_dir=args["output_dir"]) 79 | for _ in tqdm(pool.imap(p, tfs_reads_files), **kwargs): 80 | pass 81 | 82 | def _to_ExplaiNN(tf_reads_files, output_dir="./"): 83 | 84 | # Initialize 85 | base_dir = os.path.split(os.path.realpath(__file__))[0] 86 | tf, reads_files = tf_reads_files 87 | 88 | # To ExplaiNN 89 | script = f"{base_dir}/fastq2explainn.py" 90 | cmd = "%s --clip-left 7 --clip-right 64 -o %s -p %s %s" % (script, 91 | output_dir, tf, " ".join(reads_files)) 92 | _ = sp.run([cmd], shell=True, cwd=scripts_dir, stdout=sp.DEVNULL, 93 | stderr=sp.DEVNULL) 94 | 95 | if __name__ == "__main__": 96 | main() -------------------------------------------------------------------------------- /scripts/parsers/fasta2explainn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from Bio import SeqIO 4 | import click 5 | import copy 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | import random 10 | import re 11 | from sklearn.preprocessing import OneHotEncoder 12 | import sys 13 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 14 | os.pardir)) 15 | import subprocess as sp 16 | 17 | from utils import click_validator, get_file_handle, get_data_splits 18 | 19 | CONTEXT_SETTINGS = { 20 | "help_option_names": ["-h", "--help"], 21 | } 22 | 23 | def validate_click_options(context): 24 | 25 | # Check that the data splits add to 100 26 | v = sum(context.params["splits"]) 27 | if v != 100: 28 | raise click.BadParameter(f"data splits do not add to 100: {v}.") 29 | 30 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS, 31 | cls=click_validator(validate_click_options)) 32 | @click.argument( 33 | "fasta_file", 34 | type=click.Path(exists=True, resolve_path=True), 35 | nargs=-1, 36 | ) 37 | @click.option( 38 | "-d", "--dummy-dir", 39 | help="Dummy directory.", 40 | type=click.Path(resolve_path=True), 41 | default="/tmp/", 42 | show_default=True 43 | ) 44 | @click.option( 45 | "-n", "--non-standard", 46 | type=click.Choice(["skip", "shuffle", "mask"]), 47 | help="Skip, shuffle, or mask (i.e. convert to Ns) non-standard (i.e. 
non A, C, G, T) DNA, including lowercase nucleotides.", 48 | show_default=True 49 | ) 50 | @click.option( 51 | "-o", "--output-dir", 52 | help="Output directory.", 53 | type=click.Path(resolve_path=True), 54 | default="./", 55 | show_default=True 56 | ) 57 | @click.option( 58 | "-p", "--prefix", 59 | help="Output prefix.", 60 | type=str 61 | ) 62 | @click.option( 63 | "-r", "--random-seed", 64 | help="Random seed.", 65 | type=int, 66 | default=1714, 67 | show_default=True 68 | ) 69 | @click.option( 70 | "-s", "--splits", 71 | help="Training, validation and test data splits.", 72 | nargs=3, 73 | type=click.IntRange(0, 100), 74 | default=[80, 10, 10], 75 | show_default=True 76 | ) 77 | 78 | def cli(**args): 79 | 80 | # Create output dir 81 | if not os.path.exists(args["output_dir"]): 82 | os.makedirs(args["output_dir"]) 83 | 84 | # Get TSV files for ExplaiNN 85 | _to_ExplaiNN(args["fasta_file"], args["dummy_dir"], args["non_standard"], 86 | args["output_dir"], args["prefix"], args["random_seed"], args["splits"]) 87 | 88 | def _to_ExplaiNN(fasta_files, dummy_dir="/tmp/", non_standard=None, 89 | output_dir="./", prefix=None, random_seed=1714, 90 | splits=[80, 10, 10]): 91 | 92 | # Initialize 93 | data = [] 94 | regexp = re.compile(r"[^ACGT]+") 95 | 96 | # Ys 97 | enc = OneHotEncoder() 98 | arr = np.array(list(range(len(fasta_files)))).reshape(-1, 1) 99 | enc.fit(arr) 100 | ys = enc.transform(arr).toarray().tolist() 101 | 102 | # Get DataFrame 103 | for i, fasta_file in enumerate(fasta_files): 104 | handle = get_file_handle(fasta_file, "rt") 105 | for record in SeqIO.parse(handle, "fasta"): 106 | s = str(record.seq) 107 | y = ys[i] 108 | # Skip non-standard/lowercase 109 | if non_standard == "skip": 110 | if re.search(regexp, s): 111 | continue 112 | # Shuffle/mask non-standard/lowercase 113 | elif non_standard is not None: 114 | # 1) extract blocks of non-standard/lowercase nucleotides; 115 | # 2) either shuffle the nucleotides or create string of Ns; and 116 | # 3) put the nucleotides back 117 | l = list(s) 118 | for m in re.finditer(regexp, s): 119 | if non_standard == "shuffle": 120 | sublist = l[m.start():m.end()] 121 | random.shuffle(sublist) 122 | l[m.start():m.end()] = copy.copy(sublist) 123 | else: 124 | l[m.start():m.end()] = "N" * (m.end() - m.start()) 125 | s = "".join(l) 126 | data.append([record.id, s] + y) 127 | handle.close() 128 | df = pd.DataFrame(data, columns=list(range(len(data[0])))) 129 | df = df.groupby(1).max().reset_index() 130 | df = df.reindex(sorted(df.columns), axis=1) 131 | 132 | # Generate negative sequences by dinucleotide shuffling 133 | if df.shape[1] == 3: # i.e. only one class 134 | data = [] 135 | cwd = os.path.dirname(os.path.realpath(__file__)) 136 | dummy_file = os.path.join(dummy_dir, "%s+%s+%s.fa" % 137 | (os.path.split(__file__)[1], str(os.getpid()), prefix)) 138 | with open(dummy_file, "wt") as handle: 139 | for z in zip(df.iloc[:, 0]. tolist(), df.iloc[:, 1]. 
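# --- Illustrative note (added sketch, not part of the original script) ---
# When a single FASTA file (i.e. one class) is given, negatives are generated
# with BiasAway's k-mer shuffling module below: "biasaway k -k 2" shuffles
# each positive sequence in dinucleotide blocks, which keeps its nucleotide
# and, approximately, dinucleotide (and hence %GC) composition; "-e 1" is
# presumably one shuffled background per input sequence. The shuffled records
# are labelled 0. and concatenated to the positives before splitting.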
tolist()): 140 | handle.write(f">{z[0]}\n{z[1]}\n") 141 | cmd = "biasaway k -f %s -k 2 -e 1 > %s.biasaway" % (dummy_file, 142 | dummy_file) 143 | _ = sp.run([cmd], shell=True, cwd=cwd, stderr=sp.DEVNULL) 144 | for s in SeqIO.parse("%s.biasaway" % dummy_file, "fasta"): 145 | header = "%s::shuf" % s.description.split(" ")[-1] 146 | data.append([header, str(s.seq), 0.]) 147 | df2 = pd.DataFrame(data, columns=list(range(len(data[0])))) 148 | df = pd.concat((df, df2)) 149 | os.remove(dummy_file) 150 | os.remove("%s.biasaway" % dummy_file) 151 | 152 | # Get data splits 153 | train, validation, test = get_data_splits(df, splits, random_seed) 154 | 155 | # Save TSV files 156 | if train is not None: 157 | if prefix is None: 158 | tsv_file = os.path.join(output_dir, "train.tsv.gz") 159 | else: 160 | tsv_file = os.path.join(output_dir, f"{prefix}.train.tsv.gz") 161 | train.to_csv(tsv_file, sep="\t", header=False, index=False, 162 | compression="gzip") 163 | if validation is not None: 164 | if prefix is None: 165 | tsv_file = os.path.join(output_dir, "validation.tsv.gz") 166 | else: 167 | tsv_file = os.path.join(output_dir, f"{prefix}.validation.tsv.gz") 168 | validation.to_csv(tsv_file, sep="\t", header=False, index=False, 169 | compression="gzip") 170 | if test is not None: 171 | if prefix is None: 172 | tsv_file = os.path.join(output_dir, "test.tsv.gz") 173 | else: 174 | tsv_file = os.path.join(output_dir, f"{prefix}.test.tsv.gz") 175 | test.to_csv(tsv_file, sep="\t", header=False, index=False, 176 | compression="gzip") 177 | 178 | if __name__ == "__main__": 179 | cli() -------------------------------------------------------------------------------- /scripts/parsers/fastq2explainn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from Bio import SeqIO 4 | import click 5 | import copy 6 | from itertools import zip_longest 7 | import numpy as np 8 | import os 9 | import pandas as pd 10 | import random 11 | import re 12 | from sklearn.preprocessing import OneHotEncoder 13 | import sys 14 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 15 | os.pardir)) 16 | 17 | from utils import (click_validator, get_file_handle, get_data_splits, 18 | shuffle_string) 19 | 20 | CONTEXT_SETTINGS = { 21 | "help_option_names": ["-h", "--help"], 22 | } 23 | 24 | def validate_click_options(context): 25 | 26 | # Check that the data splits add to 100 27 | v = sum(context.params["splits"]) 28 | if v != 100: 29 | raise click.BadParameter(f"data splits do not add to 100: {v}.") 30 | 31 | # Check that the FASTQ files are paired 32 | f1 = len(context.params["fastq_1"]) 33 | f2 = len(context.params["fastq_2"]) 34 | if f2 > 0 and f1 != f2: 35 | raise click.BadParameter("the FASTQ files are not paired: " + \ 36 | f"{f1} Read 1 files vs. 
" + \ 37 | f"{f2} Read 2 files.") 38 | 39 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS, 40 | cls=click_validator(validate_click_options)) 41 | @click.option( 42 | "--clip-left", 43 | help="Trim the leftmost n bases.", 44 | type=int, 45 | default=None, 46 | show_default=True 47 | ) 48 | @click.option( 49 | "--clip-right", 50 | help="Trim the rightmost n bases.", 51 | type=int, 52 | default=None, 53 | show_default=True 54 | ) 55 | @click.option( 56 | "-f1", "--fastq-1", 57 | help="Read 1 FASTQ file.", 58 | type=click.Path(exists=True, resolve_path=True), 59 | multiple=True, 60 | required=True 61 | ) 62 | @click.option( 63 | "-f2", "--fastq-2", 64 | help="Read 2 FASTQ file", 65 | type=click.Path(exists=True, resolve_path=True), 66 | multiple=True 67 | ) 68 | @click.option( 69 | "-n", "--non-standard", 70 | type=click.Choice(["skip", "shuffle", "mask"]), 71 | help="Skip, shuffle, or mask (i.e. convert to Ns) non-standard (i.e. non A, C, G, T) DNA, including lowercase nucleotides.", 72 | show_default=True 73 | ) 74 | @click.option( 75 | "-o", "--output-dir", 76 | help="Output directory.", 77 | type=click.Path(resolve_path=True), 78 | default="./", 79 | show_default=True 80 | ) 81 | @click.option( 82 | "-p", "--prefix", 83 | help="Output prefix.", 84 | type=str 85 | ) 86 | @click.option( 87 | "-r", "--random-seed", 88 | help="Random seed.", 89 | type=int, 90 | default=1714, 91 | show_default=True 92 | ) 93 | @click.option( 94 | "-s", "--splits", 95 | help="Training, validation and test data splits.", 96 | nargs=3, 97 | type=click.IntRange(0, 100), 98 | default=[80, 10, 10], 99 | show_default=True 100 | ) 101 | 102 | def cli(**args): 103 | 104 | # Create output dir 105 | if not os.path.exists(args["output_dir"]): 106 | os.makedirs(args["output_dir"]) 107 | 108 | # Get TSV files for ExplaiNN 109 | _to_ExplaiNN(args["clip_left"], args["clip_right"], args["fastq_1"], 110 | args["fastq_2"], args["non_standard"], args["output_dir"], 111 | args["prefix"], args["random_seed"], args["splits"]) 112 | 113 | def _to_ExplaiNN(clip_left=None, clip_right=None, fastq_1=[], fastq_2=[], 114 | non_standard=None, output_dir="./", prefix=None, 115 | random_seed=1714, splits=[80, 10, 10]): 116 | 117 | # Initialize 118 | data = [] 119 | if clip_right is not None: 120 | clip_right = -clip_right 121 | regexp = re.compile(r"[^ACGT]+") 122 | 123 | # Ys 124 | enc = OneHotEncoder() 125 | arr = np.array(list(range(len(fastq_1)))).reshape(-1, 1) 126 | enc.fit(arr) 127 | ys = enc.transform(arr).toarray().tolist() 128 | 129 | # Get DataFrame 130 | for i, fastq_files in enumerate(zip_longest(fastq_1, fastq_2)): 131 | for fastq_file in fastq_files: 132 | if fastq_file is None: 133 | continue 134 | handle = get_file_handle(fastq_file, "rt") 135 | for record in SeqIO.parse(handle, "fastq"): 136 | s = str(record.seq)[clip_left:clip_right] 137 | y = ys[i] 138 | # Skip non-standard/lowercase 139 | if non_standard == "skip": 140 | if re.search(regexp, s): 141 | continue 142 | # Shuffle/mask non-standard/lowercase 143 | elif non_standard is not None: 144 | # 1) extract blocks of non-standard/lowercase nucleotides; 145 | # 2) either shuffle the nucleotides or create string of Ns; and 146 | # 3) put the nucleotides back 147 | l = list(s) 148 | for m in re.finditer(regexp, s): 149 | if non_standard == "shuffle": 150 | sublist = l[m.start():m.end()] 151 | random.shuffle(sublist) 152 | l[m.start():m.end()] = copy.copy(sublist) 153 | else: 154 | l[m.start():m.end()] = "N" * (m.end() - m.start()) 155 | s = 
"".join(l) 156 | data.append([record.id, s] + y) 157 | handle.close() 158 | df = pd.DataFrame(data, columns=list(range(len(data[0])))) 159 | df = df.groupby(1).max().reset_index() 160 | df = df.reindex(sorted(df.columns), axis=1) 161 | 162 | # Generate negative sequences by dinucleotide shuffling 163 | if df.shape[1] == 3: # i.e. only one class 164 | data = [] 165 | for z in zip(df.iloc[:, 0].tolist(), df.iloc[:, 1].tolist()): 166 | s = shuffle_string(z[1], random_seed=random_seed) 167 | data.append([f"{z[0]}_shuff", s, 0.]) 168 | df2 = pd.DataFrame(data, columns=list(range(len(data[0])))) 169 | df = pd.concat((df, df2)) 170 | 171 | # Get data splits 172 | train, validation, test = get_data_splits(df, splits, random_seed) 173 | 174 | # Save TSV files 175 | if train is not None: 176 | if prefix is None: 177 | tsv_file = os.path.join(output_dir, "train.tsv.gz") 178 | else: 179 | tsv_file = os.path.join(output_dir, f"{prefix}.train.tsv.gz") 180 | train.to_csv(tsv_file, sep="\t", header=False, index=False, 181 | compression="gzip") 182 | if validation is not None: 183 | if prefix is None: 184 | tsv_file = os.path.join(output_dir, "validation.tsv.gz") 185 | else: 186 | tsv_file = os.path.join(output_dir, f"{prefix}.validation.tsv.gz") 187 | validation.to_csv(tsv_file, sep="\t", header=False, index=False, 188 | compression="gzip") 189 | if test is not None: 190 | if prefix is None: 191 | tsv_file = os.path.join(output_dir, "test.tsv.gz") 192 | else: 193 | tsv_file = os.path.join(output_dir, f"{prefix}.test.tsv.gz") 194 | test.to_csv(tsv_file, sep="\t", header=False, index=False, 195 | compression="gzip") 196 | 197 | if __name__ == "__main__": 198 | cli() -------------------------------------------------------------------------------- /scripts/parsers/json2explainn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | import copy 5 | import json 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | import re 10 | from sklearn.preprocessing import OneHotEncoder 11 | import sys 12 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 13 | os.pardir)) 14 | 15 | from utils import click_validator, get_file_handle, get_data_splits 16 | 17 | CONTEXT_SETTINGS = { 18 | "help_option_names": ["-h", "--help"], 19 | } 20 | 21 | def validate_click_options(context): 22 | 23 | # Check that the data splits add to 100 24 | v = sum(context.params["splits"]) 25 | if v != 100: 26 | raise click.BadParameter(f"data splits do not add to 100: {v}.") 27 | 28 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS, 29 | cls=click_validator(validate_click_options)) 30 | @click.argument( 31 | "json_file", 32 | type=click.Path(exists=True, resolve_path=True) 33 | ) 34 | @click.option( 35 | "-n", "--non-standard", 36 | type=click.Choice(["skip", "shuffle", "mask"]), 37 | help="Skip, shuffle, or mask (i.e. convert to Ns) non-standard (i.e. 
non A, C, G, T) DNA, including lowercase nucleotides.", 38 | show_default=True 39 | ) 40 | @click.option( 41 | "-o", "--output-dir", 42 | help="Output directory.", 43 | type=click.Path(resolve_path=True), 44 | default="./", 45 | show_default=True 46 | ) 47 | @click.option( 48 | "-p", "--prefix", 49 | help="Output prefix.", 50 | type=str 51 | ) 52 | @click.option( 53 | "-r", "--random-seed", 54 | help="Random seed.", 55 | type=int, 56 | default=1714, 57 | show_default=True 58 | ) 59 | @click.option( 60 | "-s", "--splits", 61 | help="Training, validation and test data splits.", 62 | nargs=3, 63 | type=click.IntRange(0, 100), 64 | default=[80, 10, 10], 65 | show_default=True 66 | ) 67 | 68 | def main(**args): 69 | # Create output dir 70 | if not os.path.exists(args["output_dir"]): 71 | os.makedirs(args["output_dir"]) 72 | 73 | # Get TSV files for ExplaiNN 74 | _to_ExplaiNN(args["json_file"], args["non_standard"], args["output_dir"], 75 | args["prefix"], args["random_seed"], args["splits"]) 76 | 77 | def _to_ExplaiNN(json_file, non_standard=None, output_dir="./", prefix=None, 78 | random_seed=1714, splits=[80, 10, 10]): 79 | 80 | # Initialize 81 | regexp = re.compile(r"[^ACGT]+") 82 | 83 | # Load JSON 84 | handle = get_file_handle(json_file, "rt") 85 | sequences = json.load(handle) 86 | handle.close() 87 | sequences.pop(0) 88 | 89 | # Ys 90 | enc = OneHotEncoder() 91 | arr = np.array(list(range(len(sequences[0]) - 1))).reshape(-1, 1) 92 | enc.fit(arr) 93 | ys = enc.transform(arr).toarray().tolist() 94 | 95 | # Get DataFrame 96 | data = [] 97 | for i in range(len(sequences)): 98 | for j in range(1, len(sequences[i])): 99 | s = sequences[i][j][1] 100 | # Skip non-standard/lowercase 101 | if non_standard == "skip": 102 | if re.search(regexp, s): 103 | continue 104 | # Shuffle/mask non-standard/lowercase 105 | elif non_standard is not None: 106 | # 1) extract blocks of non-standard/lowercase nucleotides; 107 | # 2) either shuffle the nucleotides or create string of Ns; and 108 | # 3) put the nucleotides back 109 | l = list(s) 110 | for m in re.finditer(regexp, s): 111 | if non_standard == "shuffle": 112 | sublist = l[m.start():m.end()] 113 | random.shuffle(sublist) 114 | l[m.start():m.end()] = copy.copy(sublist) 115 | else: 116 | l[m.start():m.end()] = "N" * (m.end() - m.start()) 117 | s = "".join(l) 118 | data.append([sequences[i][j][0], s] + ys[j - 1]) 119 | df = pd.DataFrame(data, columns=list(range(len(data[0])))) 120 | df = df.groupby(1).max().reset_index() 121 | df = df.reindex(sorted(df.columns), axis=1) 122 | 123 | # Get data splits 124 | train, validation, test = get_data_splits(df, splits, random_seed) 125 | 126 | # Save TSV files 127 | if train is not None: 128 | if prefix is None: 129 | tsv_file = os.path.join(output_dir, "train.tsv.gz") 130 | else: 131 | tsv_file = os.path.join(output_dir, f"{prefix}.train.tsv.gz") 132 | train.to_csv(tsv_file, sep="\t", header=False, index=False, 133 | compression="gzip") 134 | if validation is not None: 135 | if prefix is None: 136 | tsv_file = os.path.join(output_dir, "validation.tsv.gz") 137 | else: 138 | tsv_file = os.path.join(output_dir, f"{prefix}.validation.tsv.gz") 139 | validation.to_csv(tsv_file, sep="\t", header=False, index=False, 140 | compression="gzip") 141 | if test is not None: 142 | if prefix is None: 143 | tsv_file = os.path.join(output_dir, "test.tsv.gz") 144 | else: 145 | tsv_file = os.path.join(output_dir, f"{prefix}.test.tsv.gz") 146 | test.to_csv(tsv_file, sep="\t", header=False, index=False, 147 | compression="gzip") 148 | 
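# --- Illustrative note (added sketch, not part of the original script) ---
# The JSON consumed here is the output of utils/match-seqs-by-gc.py: its first
# element is a header, ["labels", <fasta_file_1>, <fasta_file_2>, ...], which
# is popped above, and every remaining element is a %GC-matched row of the form
# [gc, [id_1, seq_1], [id_2, seq_2], ...] with one [id, seq] pair per input
# FASTA file; column j therefore defines class j - 1, which is why ys one-hot
# encodes len(sequences[0]) - 1 classes and row entry j gets label ys[j - 1].
# A minimal (hypothetical) example with two input files:
#   [["labels", "pos.fa", "neg.fa"],
#    [41, ["chr1:100-300", "ACGT..."], ["chr2:500-700::shuf", "TTGA..."]],
#    [42, ["chr3:10-210",  "GGCA..."], ["chr4:900-1100::shuf", "ACCA..."]]]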
149 | if __name__ == "__main__": 150 | main() -------------------------------------------------------------------------------- /scripts/parsers/pbm2explainn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from functools import partial 5 | import gzip 6 | from multiprocessing import Pool 7 | import os 8 | import pandas as pd 9 | import re 10 | from sklearn.preprocessing import quantile_transform 11 | import subprocess as sp 12 | import sys 13 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 14 | os.pardir)) 15 | 16 | from utils import click_validator, get_data_splits 17 | 18 | CONTEXT_SETTINGS = { 19 | "help_option_names": ["-h", "--help"], 20 | } 21 | 22 | def validate_click_options(context): 23 | 24 | # Check that the data splits add to 100 25 | v = sum(context.params["splits"]) 26 | if v != 100: 27 | raise click.BadParameter(f"data splits do not add to 100: {v}.") 28 | 29 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS, 30 | cls=click_validator(validate_click_options)) 31 | @click.argument( 32 | "intensity_file", 33 | type=click.Path(exists=True, resolve_path=True), 34 | ) 35 | @click.option( 36 | "-n", "--no-linker", 37 | help="Exclude the linker sequence.", 38 | is_flag=True, 39 | ) 40 | @click.option( 41 | "-o", "--output-dir", 42 | help="Output directory.", 43 | type=click.Path(resolve_path=True), 44 | default="./", 45 | show_default=True 46 | ) 47 | @click.option( 48 | "-p", "--prefix", 49 | help="Output prefix.", 50 | type=str 51 | ) 52 | @click.option( 53 | "-q", "--quantile-normalize", 54 | help="Quantile normalize signal intensities.", 55 | is_flag=True 56 | ) 57 | @click.option( 58 | "-r", "--random-seed", 59 | help="Random seed.", 60 | type=int, 61 | default=1714, 62 | show_default=True 63 | ) 64 | @click.option( 65 | "-s", "--splits", 66 | help="Training, validation and test data splits.", 67 | nargs=3, 68 | type=click.IntRange(0, 100), 69 | default=[80, 10, 10], 70 | show_default=True 71 | ) 72 | 73 | def main(**args): 74 | 75 | # Create output dir 76 | if not os.path.exists(args["output_dir"]): 77 | os.makedirs(args["output_dir"]) 78 | 79 | # Get TSV files for ExplaiNN 80 | _to_ExplaiNN(args["intensity_file"], args["no_linker"], args["output_dir"], 81 | args["prefix"], args["quantile_normalize"], args["random_seed"], 82 | args["splits"]) 83 | 84 | def _to_ExplaiNN(intensity_file, no_linker=False, output_dir="./", prefix=None, 85 | quantile_normalize=False, random_seed=1714, 86 | splits=[80, 10, 10]): 87 | 88 | # Initialize 89 | data = [] 90 | 91 | # Get DataFrame 92 | df = pd.read_table(intensity_file, header=0) 93 | if quantile_normalize: 94 | df.iloc[:, 7] = quantile_transform(df.iloc[:, 7].to_numpy().reshape(-1, 1), 95 | n_quantiles=10, random_state=0, copy=True) 96 | if not no_linker: 97 | df["pbm_sequence"] += df["linker_sequence"] 98 | df = df[["id_probe", "pbm_sequence", "mean_signal_intensity"]].dropna() 99 | 100 | # Get data splits 101 | train, validation, test = get_data_splits(df, splits, random_seed) 102 | 103 | # Save TSV files 104 | if train is not None: 105 | if prefix is None: 106 | tsv_file = os.path.join(output_dir, "train.tsv.gz") 107 | else: 108 | tsv_file = os.path.join(output_dir, f"{prefix}.train.tsv.gz") 109 | train.to_csv(tsv_file, sep="\t", header=False, index=False, 110 | compression="gzip") 111 | if validation is not None: 112 | if prefix is None: 113 | tsv_file = os.path.join(output_dir, 
"validation.tsv.gz") 114 | else: 115 | tsv_file = os.path.join(output_dir, f"{prefix}.validation.tsv.gz") 116 | validation.to_csv(tsv_file, sep="\t", header=False, index=False, 117 | compression="gzip") 118 | if test is not None: 119 | if prefix is None: 120 | tsv_file = os.path.join(output_dir, "test.tsv.gz") 121 | else: 122 | tsv_file = os.path.join(output_dir, f"{prefix}.test.tsv.gz") 123 | test.to_csv(tsv_file, sep="\t", header=False, index=False, 124 | compression="gzip") 125 | 126 | if __name__ == "__main__": 127 | main() -------------------------------------------------------------------------------- /scripts/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from click_option_group import optgroup 5 | import json 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | from sklearn.metrics import average_precision_score, roc_auc_score 10 | import sys 11 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 12 | os.pardir)) 13 | import time 14 | import torch 15 | 16 | from explainn.interpretation.interpretation import get_explainn_predictions 17 | from explainn.models.networks import ExplaiNN 18 | from utils import (get_file_handle, get_seqs_labels_ids, get_data_loader, 19 | get_device) 20 | 21 | CONTEXT_SETTINGS = { 22 | "help_option_names": ["-h", "--help"], 23 | } 24 | 25 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 26 | @click.argument( 27 | "model_file", 28 | type=click.Path(exists=True, resolve_path=True), 29 | ) 30 | @click.argument( 31 | "training_parameters_file", 32 | type=click.Path(exists=True, resolve_path=True), 33 | ) 34 | @click.argument( 35 | "test_file", 36 | type=click.Path(exists=True, resolve_path=True), 37 | ) 38 | @click.option( 39 | "-c", "--cpu-threads", 40 | help="Number of CPU threads to use.", 41 | type=int, 42 | default=1, 43 | show_default=True, 44 | ) 45 | @click.option( 46 | "-d", "--debugging", 47 | help="Debugging mode.", 48 | is_flag=True, 49 | ) 50 | @click.option( 51 | "-o", "--output-dir", 52 | help="Output directory.", 53 | type=click.Path(resolve_path=True), 54 | default="./", 55 | show_default=True, 56 | ) 57 | @click.option( 58 | "-t", "--time", 59 | help="Return the program's running execution time in seconds.", 60 | is_flag=True, 61 | ) 62 | @optgroup.group("Test") 63 | @optgroup.option( 64 | "--batch-size", 65 | help="Batch size.", 66 | type=int, 67 | default=100, 68 | show_default=True, 69 | ) 70 | 71 | def cli(**args): 72 | 73 | # Start execution 74 | start_time = time.time() 75 | 76 | # Initialize 77 | if not os.path.exists(args["output_dir"]): 78 | os.makedirs(args["output_dir"]) 79 | 80 | # Save exec. parameters as JSON 81 | json_file = os.path.join(args["output_dir"], 82 | f"parameters-{os.path.basename(__file__)}.json") 83 | handle = get_file_handle(json_file, "wt") 84 | handle.write(json.dumps(args, indent=4, sort_keys=True)) 85 | handle.close() 86 | 87 | ############## 88 | # Load Data # 89 | ############## 90 | 91 | # Load training parameters 92 | handle = get_file_handle(args["training_parameters_file"], "rt") 93 | train_args = json.load(handle) 94 | handle.close() 95 | if "training_parameters_file" in train_args: # i.e. 
for fine-tuned models 96 | handle = get_file_handle(train_args["training_parameters_file"], "rt") 97 | train_args = json.load(handle) 98 | handle.close() 99 | 100 | # Get test sequences and labels 101 | seqs, labels, _ = get_seqs_labels_ids(args["test_file"], 102 | args["debugging"], 103 | False, 104 | train_args["input_length"]) 105 | 106 | ############## 107 | # Test # 108 | ############## 109 | 110 | # Infer input type, and the number of classes 111 | num_classes = labels[0].shape[0] 112 | if np.unique(labels[:, 0]).size == 2: 113 | input_type = "binary" 114 | else: 115 | input_type = "non-binary" 116 | 117 | # Get device 118 | device = get_device() 119 | 120 | # Get model 121 | m = ExplaiNN(train_args["num_units"], train_args["input_length"], 122 | num_classes, train_args["filter_size"], train_args["num_fc"], 123 | train_args["pool_size"], train_args["pool_stride"], 124 | args["model_file"]) 125 | 126 | # Test 127 | _test(seqs, labels, m, device, input_type, train_args["rev_complement"], 128 | args["output_dir"], args["batch_size"]) 129 | 130 | # Finish execution 131 | seconds = format(time.time() - start_time, ".2f") 132 | if args["time"]: 133 | f = os.path.join(args["output_dir"], 134 | f"time-{os.path.basename(__file__)}.txt") 135 | handle = get_file_handle(f, "wt") 136 | handle.write(f"{seconds} seconds") 137 | handle.close() 138 | print(f"Execution time {seconds} seconds") 139 | 140 | def _test(seqs, labels, model, device, input_type, rev_complement, 141 | output_dir="./", batch_size=100): 142 | 143 | # Initialize 144 | predictions = [] 145 | model.to(device) 146 | model.eval() 147 | 148 | # Get training DataLoader 149 | data_loader = get_data_loader(seqs, labels, batch_size) 150 | 151 | # Get rev. complement 152 | if rev_complement: 153 | rev_seqs = np.array([s[::-1, ::-1] for s in seqs]) 154 | rev_data_loader = get_data_loader(rev_seqs, labels, batch_size) 155 | else: 156 | rev_seqs = None 157 | rev_data_loader = None 158 | 159 | for dl in [data_loader, rev_data_loader]: 160 | 161 | # Skip 162 | if dl is None: 163 | continue 164 | 165 | # Get predictions 166 | preds, labels = get_explainn_predictions(dl, model, device, 167 | isSigmoid=False) 168 | predictions.append(preds) 169 | 170 | # Avg. 
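# --- Illustrative note (added sketch, not part of the original script) ---
# When --rev-complement was used at training time, predictions for the forward
# and reverse-complement strands are averaged per class below; because the
# model returns raw logits here (isSigmoid=False), a sigmoid is applied after
# averaging for binary tasks before computing the metrics, i.e. roughly:
#   avg = (preds_fwd + preds_rev) / 2        # or preds_fwd alone
#   probs = 1 / (1 + np.exp(-avg))           # binary input only
# Metrics are aucROC/aucPR for binary labels and Pearson r otherwise.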
predictions from both strands 171 | if len(predictions) == 2: 172 | avg_predictions = np.empty(predictions[0].shape) 173 | for i in range(predictions[0].shape[1]): 174 | avg_predictions[:, i] = np.mean([predictions[0][:, i], 175 | predictions[1][:, i]], axis=0) 176 | else: 177 | avg_predictions = predictions[0] 178 | if input_type == "binary": 179 | for i in range(avg_predictions.shape[1]): 180 | avg_predictions[:, i] = \ 181 | torch.sigmoid(torch.from_numpy(avg_predictions[:, i])).numpy() 182 | 183 | # Get performance metrics 184 | metrics = __get_metrics(input_data=input_type) 185 | tsv_file = os.path.join(output_dir, "performance-metrics.tsv") 186 | if not os.path.exists(tsv_file): 187 | data = [] 188 | column_names = ["metric"] 189 | for m in metrics: 190 | data.append([m]) 191 | if labels.shape[1] > 1: 192 | data[-1].append(metrics[m](labels, avg_predictions)) 193 | column_names = column_names + ["global"] 194 | for i in range(labels.shape[1]): 195 | data[-1].append(metrics[m](labels[:, i], 196 | avg_predictions[:, i])) 197 | if labels.shape[1] > 1: 198 | column_names = ["metric", "global"] + list(range(labels.shape[1])) 199 | else: 200 | column_names = ["metric"] + list(range(labels.shape[1])) 201 | df = pd.DataFrame(data, columns=column_names) 202 | df.to_csv(tsv_file, sep="\t", index=False) 203 | 204 | def __get_metrics(input_data="binary"): 205 | 206 | if input_data == "binary": 207 | return(dict(aucROC=roc_auc_score, aucPR=average_precision_score)) 208 | 209 | return(dict(Pearson=pearson_corrcoef)) 210 | 211 | def pearson_corrcoef(y_true, y_score): 212 | 213 | if y_true.ndim == 1: 214 | return np.corrcoef(y_true, y_score)[0, 1] 215 | else: 216 | if y_true.shape[1] == 1: 217 | return np.corrcoef(y_true, y_score)[0, 1] 218 | else: 219 | corrcoefs = [] 220 | for i in range(len(y_score)): 221 | x = np.corrcoef(y_true[i, :], y_score[i, :])[0, 1] 222 | corrcoefs.append(x) 223 | return np.mean(corrcoefs) 224 | 225 | if __name__ == "__main__": 226 | cli() -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from click_option_group import optgroup 5 | import json 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | import shutil 10 | import sys 11 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 12 | os.pardir)) 13 | import time 14 | import torch 15 | 16 | from explainn.train.train import train_explainn 17 | from explainn.utils.tools import pearson_loss 18 | from explainn.models.networks import ExplaiNN 19 | from utils import (get_file_handle, get_seqs_labels_ids, get_data_loader, 20 | get_device) 21 | 22 | CONTEXT_SETTINGS = { 23 | "help_option_names": ["-h", "--help"], 24 | } 25 | 26 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 27 | @click.argument( 28 | "training_file", 29 | type=click.Path(exists=True, resolve_path=True), 30 | ) 31 | @click.argument( 32 | "validation_file", 33 | type=click.Path(exists=True, resolve_path=True), 34 | ) 35 | @click.option( 36 | "-c", "--cpu-threads", 37 | help="Number of CPU threads to use.", 38 | type=int, 39 | default=1, 40 | show_default=True, 41 | ) 42 | @click.option( 43 | "-d", "--debugging", 44 | help="Debugging mode.", 45 | is_flag=True, 46 | ) 47 | @click.option( 48 | "-o", "--output-dir", 49 | help="Output directory.", 50 | type=click.Path(resolve_path=True), 51 | default="./", 52 | 
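# --- Illustrative usage (added sketch; file names are hypothetical) ---
#   ./train.py --input-length 200 --criterion BCEWithLogits \
#       --num-units 100 --num-epochs 100 -o ./model_dir \
#       MAX.train.tsv.gz MAX.validation.tsv.gz
# Two positional TSVs (training and validation, as produced by the parsers
# above) are required, together with --input-length and --criterion; the
# remaining ExplaiNN/Optimizer/Training options keep the defaults declared
# in this script's decorators.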
show_default=True, 53 | ) 54 | @click.option( 55 | "-t", "--time", 56 | help="Return the program's running execution time in seconds.", 57 | is_flag=True, 58 | ) 59 | @optgroup.group("ExplaiNN") 60 | @optgroup.option( 61 | "--filter-size", 62 | help="Size of each unit's filter.", 63 | type=int, 64 | default=19, 65 | show_default=True, 66 | ) 67 | @optgroup.option( 68 | "--input-length", 69 | help="Input length (for longer and shorter sequences, trim or add padding, i.e. Ns, up to the specified length).", 70 | type=int, 71 | required=True, 72 | ) 73 | @optgroup.option( 74 | "--num-fc", 75 | help="Number of fully connected layers in each unit.", 76 | type=click.IntRange(0, 8, clamp=True), 77 | default=2, 78 | show_default=True, 79 | ) 80 | @optgroup.option( 81 | "--num-units", 82 | help="Number of independent units.", 83 | type=int, 84 | default=100, 85 | show_default=True, 86 | ) 87 | @optgroup.option( 88 | "--pool-size", 89 | help="Size of each unit's maxpooling layer.", 90 | type=int, 91 | default=7, 92 | show_default=True, 93 | ) 94 | @optgroup.option( 95 | "--pool-stride", 96 | help="Stride of each unit's maxpooling layer.", 97 | type=int, 98 | default=7, 99 | show_default=True, 100 | ) 101 | @optgroup.group("Optimizer") 102 | @optgroup.option( 103 | "--criterion", 104 | help="Loss (objective) function to use. Select \"BCEWithLogits\" for binary or multi-class classification tasks (e.g. predict the binding of one or more TFs to a sequence), \"CrossEntropy\" for multi-class classification tasks wherein only one solution is possible (e.g. predict the species of origin of a sequence between human, mouse or zebrafish), \"MSE\" for regression tasks (e.g. predict probe intensity signals), \"Pearson\" also for regression tasks (e.g. modeling accessibility across 81 cell types), and \"PoissonNLL\" for modeling count data (e.g. total number of reads at ChIP-/ATAC-seq peaks).", 105 | type=click.Choice(["BCEWithLogits", "CrossEntropy", "MSE", "Pearson", "PoissonNLL"], case_sensitive=False), 106 | required=True 107 | ) 108 | @optgroup.option( 109 | "--lr", 110 | help="Learning rate.", 111 | type=float, 112 | default=0.003, 113 | show_default=True, 114 | ) 115 | @optgroup.option( 116 | "--optimizer", 117 | help="`torch.optim.Optimizer` with which to minimize the loss during training.", 118 | type=click.Choice(["Adam", "SGD"], case_sensitive=False), 119 | default="Adam", 120 | show_default=True, 121 | ) 122 | @optgroup.group("Training") 123 | @optgroup.option( 124 | "--batch-size", 125 | help="Batch size.", 126 | type=int, 127 | default=100, 128 | show_default=True, 129 | ) 130 | @optgroup.option( 131 | "--checkpoint", 132 | help="How often to save checkpoints (e.g. 1 means that the model will be saved after each epoch; by default, i.e. 0, only the best model will be saved).", 133 | type=int, 134 | default=0, 135 | show_default=True, 136 | ) 137 | @optgroup.option( 138 | "--num-epochs", 139 | help="Number of epochs to train the model.", 140 | type=int, 141 | default=100, 142 | show_default=True, 143 | ) 144 | @optgroup.option( 145 | "--patience", 146 | help="Number of epochs to wait before stopping training if the validation loss does not improve.", 147 | type=int, 148 | default=10, 149 | show_default=True, 150 | ) 151 | @optgroup.option( 152 | "--rev-complement", 153 | help="Reverse and complement training sequences.", 154 | is_flag=True, 155 | ) 156 | @optgroup.option( 157 | "--trim-weights", 158 | help="Constrain output weights to be non-negative (i.e. 
to ease interpretation).", 159 | is_flag=True, 160 | ) 161 | 162 | def cli(**args): 163 | 164 | # Start execution 165 | start_time = time.time() 166 | 167 | # Initialize 168 | if not os.path.exists(args["output_dir"]): 169 | os.makedirs(args["output_dir"]) 170 | 171 | # Save exec. parameters as JSON 172 | json_file = os.path.join(args["output_dir"], 173 | f"parameters-{os.path.basename(__file__)}.json") 174 | handle = get_file_handle(json_file, "wt") 175 | handle.write(json.dumps(args, indent=4, sort_keys=True)) 176 | handle.close() 177 | 178 | ############## 179 | # Load Data # 180 | ############## 181 | 182 | # Get training/test sequences and labels 183 | train_seqs, train_labels, _ = get_seqs_labels_ids(args["training_file"], 184 | args["debugging"], 185 | args["rev_complement"], 186 | args["input_length"]) 187 | test_seqs, test_labels, _ = get_seqs_labels_ids(args["validation_file"], 188 | args["debugging"], 189 | args["rev_complement"], 190 | args["input_length"]) 191 | 192 | # Get training/test DataLoaders 193 | train_loader = get_data_loader(train_seqs, train_labels, 194 | args["batch_size"], shuffle=True) 195 | test_loader = get_data_loader(test_seqs, test_labels, 196 | args["batch_size"], shuffle=True) 197 | 198 | ############## 199 | # Train # 200 | ############## 201 | 202 | # Infer input length/type, and the number of classes 203 | # input_length = train_seqs[0].shape[1] 204 | num_classes = train_labels[0].shape[0] 205 | 206 | # Get device 207 | device = get_device() 208 | 209 | # Get criterion 210 | if args["criterion"].lower() == "bcewithlogits": 211 | criterion = torch.nn.BCEWithLogitsLoss() 212 | elif args["criterion"].lower() == "crossentropy": 213 | criterion = torch.nn.CrossEntropyLoss() 214 | elif args["criterion"].lower() == "mse": 215 | criterion = torch.nn.MSELoss() 216 | elif args["criterion"].lower() == "pearson": 217 | criterion = pearson_loss 218 | elif args["criterion"].lower() == "poissonnll": 219 | criterion = torch.nn.PoissonNLLLoss() 220 | 221 | # Get model and optimizer 222 | m = ExplaiNN(args["num_units"], args["input_length"], num_classes, 223 | args["filter_size"], args["num_fc"], args["pool_size"], 224 | args["pool_stride"]) 225 | 226 | # Get optimizer 227 | o = _get_optimizer(args["optimizer"], m.parameters(), args["lr"]) 228 | 229 | # Train 230 | _train(train_loader, test_loader, m, device, criterion, o, 231 | args["num_epochs"], args["output_dir"], None, True, False, 232 | args["checkpoint"], args["patience"]) 233 | 234 | # Finish execution 235 | seconds = format(time.time() - start_time, ".2f") 236 | if args["time"]: 237 | f = os.path.join(args["output_dir"], 238 | f"time-{os.path.basename(__file__)}.txt") 239 | handle = get_file_handle(f, "wt") 240 | handle.write(f"{seconds} seconds") 241 | handle.close() 242 | print(f"Execution time {seconds} seconds") 243 | 244 | def _get_optimizer(optimizer, parameters, lr=0.0005): 245 | 246 | if optimizer.lower() == "adam": 247 | return torch.optim.Adam(parameters, lr=lr) 248 | elif optimizer.lower() == "sgd": 249 | return torch.optim.SGD(parameters, lr=lr) 250 | 251 | def _train(train_loader, test_loader, model, device, criterion, optimizer, 252 | num_epochs=100, output_dir="./", name_ind=None, verbose=False, 253 | trim_weights=False, checkpoint=0, patience=0): 254 | 255 | # Initialize 256 | model.to(device) 257 | 258 | # Train 259 | _, train_error, test_error = train_explainn(train_loader, test_loader, 260 | model, device, criterion, 261 | optimizer, num_epochs, 262 | output_dir, name_ind, 263 | verbose, 
trim_weights, 264 | checkpoint, patience) 265 | 266 | # Save losses 267 | df = pd.DataFrame(list(zip(train_error, test_error)), 268 | columns=["Train loss", "Validation loss"]) 269 | df.index += 1 270 | df.index.rename("Epoch", inplace=True) 271 | df.to_csv(os.path.join(output_dir, "losses.tsv"), sep="\t") 272 | 273 | if __name__ == "__main__": 274 | cli() 275 | -------------------------------------------------------------------------------- /scripts/tsv2predictions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from click_option_group import optgroup 5 | import json 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | import sys 10 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 11 | os.pardir)) 12 | import time 13 | import torch 14 | from tqdm import tqdm 15 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 16 | 17 | from explainn.models.networks import ExplaiNN 18 | from utils import (get_seqs_labels_ids, get_file_handle, get_data_loader, 19 | get_device) 20 | 21 | CONTEXT_SETTINGS = { 22 | "help_option_names": ["-h", "--help"], 23 | } 24 | 25 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 26 | @click.argument( 27 | "model_file", 28 | type=click.Path(exists=True, resolve_path=True) 29 | ) 30 | @click.argument( 31 | "training_parameters_file", 32 | type=click.Path(exists=True, resolve_path=True), 33 | ) 34 | @click.argument( 35 | "tsv_file", 36 | type=click.Path(exists=True, resolve_path=True), 37 | ) 38 | @click.option( 39 | "-c", "--cpu-threads", 40 | help="Number of CPU threads to use.", 41 | type=int, 42 | default=1, 43 | show_default=True, 44 | ) 45 | @click.option( 46 | "-d", "--debugging", 47 | help="Debugging mode.", 48 | is_flag=True, 49 | ) 50 | @click.option( 51 | "-o", "--output-dir", 52 | help="Output directory.", 53 | type=click.Path(resolve_path=True), 54 | default="./", 55 | show_default=True, 56 | ) 57 | @click.option( 58 | "-t", "--time", 59 | help="Return the program's running execution time in seconds.", 60 | is_flag=True, 61 | ) 62 | @optgroup.group("Predict") 63 | @optgroup.option( 64 | "--apply-sigmoid", 65 | help="Apply the logistic sigmoid function to outputs.", 66 | is_flag=True, 67 | ) 68 | @optgroup.option( 69 | "--batch-size", 70 | help="Batch size.", 71 | type=int, 72 | default=100, 73 | show_default=True, 74 | ) 75 | 76 | def cli(**args): 77 | 78 | # Start execution 79 | start_time = time.time() 80 | 81 | # Initialize 82 | if not os.path.exists(args["output_dir"]): 83 | os.makedirs(args["output_dir"]) 84 | 85 | # Save exec. parameters as JSON 86 | json_file = os.path.join(args["output_dir"], 87 | f"parameters-{os.path.basename(__file__)}.json") 88 | handle = get_file_handle(json_file, "wt") 89 | handle.write(json.dumps(args, indent=4, sort_keys=True)) 90 | handle.close() 91 | 92 | ############## 93 | # Load Data # 94 | ############## 95 | 96 | # Load training parameters 97 | handle = get_file_handle(args["training_parameters_file"], "rt") 98 | train_args = json.load(handle) 99 | handle.close() 100 | if "training_parameters_file" in train_args: # i.e. 
for fine-tuned models 101 | handle = get_file_handle(train_args["training_parameters_file"], "rt") 102 | train_args = json.load(handle) 103 | handle.close() 104 | 105 | # Get test sequences and labels 106 | seqs, _, ids = get_seqs_labels_ids(args["tsv_file"], 107 | args["debugging"], 108 | False, 109 | train_args["input_length"]) 110 | 111 | ############## 112 | # Predict # 113 | ############## 114 | 115 | # Get device 116 | device = get_device() 117 | 118 | # Get model 119 | state_dict = torch.load(args["model_file"]) 120 | for k in reversed(state_dict.keys()): 121 | num_classes = state_dict[k].shape[0] 122 | break 123 | 124 | # Get model 125 | m = ExplaiNN(train_args["num_units"], train_args["input_length"], 126 | num_classes, train_args["filter_size"], train_args["num_fc"], 127 | train_args["pool_size"], train_args["pool_stride"], 128 | args["model_file"]) 129 | 130 | # Test 131 | _predict(seqs, ids, num_classes, m, device, args["output_dir"], 132 | args["apply_sigmoid"], args["batch_size"]) 133 | 134 | # Finish execution 135 | seconds = format(time.time() - start_time, ".2f") 136 | if args["time"]: 137 | f = os.path.join(args["output_dir"], 138 | f"time-{os.path.basename(__file__)}.txt") 139 | handle = get_file_handle(f, "wt") 140 | handle.write(f"{seconds} seconds") 141 | handle.close() 142 | print(f"Execution time {seconds} seconds") 143 | 144 | def _predict(seqs, ids, num_classes, model, device, output_dir="./", 145 | apply_sigmoid=False, batch_size=100): 146 | 147 | # Initialize 148 | idx = 0 149 | predictions = np.empty((len(seqs), num_classes, 4)) 150 | model.to(device) 151 | model.eval() 152 | 153 | # Get training DataLoader 154 | data_loader = get_data_loader( 155 | seqs, 156 | np.array([s[::-1, ::-1] for s in seqs]), 157 | batch_size 158 | ) 159 | 160 | with torch.no_grad(): 161 | 162 | for fwd, rev in tqdm(iter(data_loader), total=len(data_loader), 163 | bar_format=bar_format): 164 | 165 | # Get strand-specific predictions 166 | fwd = np.expand_dims(model(fwd.to(device)).cpu().numpy(), axis=2) 167 | rev = np.expand_dims(model(rev.to(device)).cpu().numpy(), axis=2) 168 | 169 | # Combine predictions from both strands 170 | fwd_rev = np.concatenate((fwd, rev), axis=2) 171 | mean_fwd_rev = np.expand_dims(np.mean(fwd_rev, axis=2), axis=2) 172 | max_fwd_rev = np.expand_dims(np.max(fwd_rev, axis=2), axis=2) 173 | 174 | # Concatenate predictions for this batch 175 | p = np.concatenate((fwd, rev, mean_fwd_rev, max_fwd_rev), axis=2) 176 | predictions[idx:idx+fwd.shape[0]] = p 177 | 178 | # Index increase 179 | idx += fwd.shape[0] 180 | 181 | # Apply sigmoid 182 | if apply_sigmoid: 183 | predictions = torch.sigmoid(torch.Tensor(predictions)).numpy() 184 | 185 | # Get predictions 186 | tsv_file = os.path.join(output_dir, "predictions.tsv.gz") 187 | if not os.path.exists(tsv_file): 188 | dfs = [] 189 | for i in range(num_classes): 190 | p = predictions[:, i, :] 191 | df = pd.DataFrame(p, columns=["Fwd", "Rev", "Mean", "Max"]) 192 | df["SeqId"] = ids 193 | df["Class"] = i 194 | dfs.append(df) 195 | df = pd.concat(dfs)[["SeqId", "Class", "Fwd", "Rev", "Mean", "Max"]] 196 | df.reset_index(drop=True, inplace=True) 197 | df.to_csv(tsv_file, sep="\t", index=False) 198 | 199 | if __name__ == "__main__": 200 | cli() -------------------------------------------------------------------------------- /scripts/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import click 2 | import gzip 3 | from functools import partial 4 | import numpy as np 5 | 
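# --- Illustrative note (added sketch, not part of the original module) ---
# Helpers shared by the train/test/predict scripts above: get_seqs_labels_ids()
# reads a parser TSV (id, sequence, label columns) into one-hot encoded
# sequences, a label matrix and the sequence ids, while get_data_splits()
# chains sklearn's train_test_split so that, e.g., splits=[80, 10, 10] first
# carves off 20% of the rows and then halves that 20%, returning
# [train, validation, test] (an entry is None whenever its split is 0).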
import pandas as pd 6 | import random 7 | from sklearn.model_selection import train_test_split 8 | from torch import cuda, Tensor 9 | from torch.utils.data import DataLoader, TensorDataset 10 | 11 | from explainn.utils.tools import dna_one_hot 12 | 13 | ############################################################################### 14 | def click_validator(validator): 15 | 16 | class Validator(click.Command): 17 | 18 | def make_context(self, *args, **kwargs): 19 | context = super(Validator, self).make_context(*args, **kwargs) 20 | validator(context) 21 | 22 | return(context) 23 | 24 | return Validator 25 | 26 | 27 | ############################################################################### 28 | def get_file_handle(file_name, mode): 29 | 30 | if file_name.endswith(".gz"): 31 | handle = gzip.open(file_name, mode) 32 | else: 33 | handle = open(file_name, mode) 34 | 35 | return handle 36 | 37 | 38 | ############################################################################### 39 | def get_data_splits(data, splits=[80, 10, 10], random_seed=1714): 40 | 41 | # Initialize 42 | data_splits = [None, None, None] 43 | train = splits[0] / 100. 44 | validation = splits[1] / 100. 45 | test = splits[2] / 100. 46 | 47 | if train == 1.: 48 | data_splits[0] = data 49 | elif validation == 1.: 50 | data_splits[1] = data 51 | elif test == 1.: 52 | data_splits[2] = data 53 | else: 54 | # Initialize 55 | p = partial(train_test_split, random_state=random_seed) 56 | if train == 0.: 57 | test_size = test 58 | data_splits[1], data_splits[2] = p(data, test_size=test_size) 59 | elif validation == 0.: 60 | test_size = test 61 | data_splits[0], data_splits[2] = p(data, test_size=test_size) 62 | elif test == 0.: 63 | test_size = validation 64 | data_splits[0], data_splits[1] = p(data, test_size=test_size) 65 | else: 66 | test_size = validation + test 67 | data_splits[0], data = p(data, test_size=test_size) 68 | test_size = test / (validation + test) 69 | data_splits[1], data_splits[2] = p(data, test_size=test_size) 70 | 71 | return data_splits 72 | 73 | 74 | ############################################################################### 75 | def get_seqs_labels_ids(tsv_file, debugging=False, rev_complement=False, 76 | input_length="infer from data"): 77 | 78 | # Sequences / labels / ids 79 | df = pd.read_table(tsv_file, header=None, comment="#") 80 | ids = df.pop(0).values 81 | if input_length != "infer from data": 82 | seqs = [_resize_sequence(s, input_length) for s in df.pop(1).values] 83 | else: 84 | seqs = df.pop(1).values 85 | seqs = _dna_one_hot_many(seqs) 86 | labels = df.values 87 | 88 | # Reverse complement 89 | if rev_complement: 90 | seqs = np.append(seqs, np.array([s[::-1, ::-1] for s in seqs]), axis=0) 91 | labels = np.append(labels, labels, axis=0) 92 | ids = np.append(ids, ids, axis=0) 93 | 94 | # Return 1,000 sequences 95 | if debugging: 96 | return seqs[:1000], labels[:1000], ids[:1000] 97 | 98 | return seqs, labels, ids 99 | 100 | def _resize_sequence(s, l): 101 | 102 | if len(s) < l: 103 | return s.center(l, "N") 104 | elif len(s) > l: 105 | start = (len(s)//2) - (l//2) 106 | return s[start:start+l] 107 | else: 108 | return s 109 | 110 | def _dna_one_hot_many(seqs): 111 | """One hot encodes a list of sequences.""" 112 | return(np.array([dna_one_hot(str(seq)) for seq in seqs])) 113 | 114 | 115 | ############################################################################### 116 | def get_data_loader(seqs, labels, batch_size=100, shuffle=False): 117 | 118 | # TensorDatasets 119 | dataset = 
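# --- Illustrative note (added, not part of the original module) ---
# The helper called a few lines below shrinks batch_size whenever
# len(dataset) % batch_size would leave a final batch of exactly one
# sequence, since a 1-sample batch makes BatchNorm fail with "Expected more
# than 1 value per channel when training"; e.g. 1001 sequences requested with
# batch_size=100 are instead served with batch_size=99.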
TensorDataset(Tensor(seqs), Tensor(labels)) 120 | 121 | # Avoid Error: Expected more than 1 value per channel when training 122 | batch_size = _avoid_expect_more_than_1_value_per_channel(len(dataset), 123 | batch_size) 124 | 125 | return DataLoader(dataset, batch_size, shuffle=shuffle) 126 | 127 | def _avoid_expect_more_than_1_value_per_channel(n, batch_size): 128 | 129 | if n % batch_size == 1: 130 | return _avoid_expect_more_than_1_value_per_channel(n, batch_size - 1) 131 | 132 | return batch_size 133 | 134 | 135 | ############################################################################### 136 | def get_device(): 137 | 138 | # Initialize 139 | device = "cpu" 140 | free_mems = {} 141 | 142 | # Assign the freest GPU to device 143 | if cuda.is_available(): 144 | for i in range(cuda.device_count()): 145 | free_mem = cuda.mem_get_info(f"cuda:{i}")[0] 146 | free_mems.setdefault(free_mem, []) 147 | free_mems[free_mem].append(i) 148 | max_free_mem = max(free_mems.keys()) 149 | random.shuffle(free_mems[max_free_mem]) 150 | device = f"cuda:{free_mems[max_free_mem][0]}" 151 | 152 | return device 153 | 154 | ############################################################################### 155 | def shuffle_string(s, k=2, random_seed=1714): 156 | 157 | # Shuffle 158 | l = [s[i-k:i] for i in range(k, len(s)+k, k)] 159 | random.Random(random_seed).shuffle(l) 160 | 161 | return "".join(l) -------------------------------------------------------------------------------- /scripts/utils/fonts/Arial.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wassermanlab/ExplaiNN/4ed332abc610499f4761307eb7ac283115ca7314/scripts/utils/fonts/Arial.ttf -------------------------------------------------------------------------------- /scripts/utils/jaspar2logo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from Bio import motifs 5 | import logomaker 6 | from matplotlib import font_manager as fm 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import os 10 | import pandas as pd 11 | import warnings 12 | warnings.filterwarnings("ignore") 13 | 14 | # Specify font 15 | f = os.path.join(os.path.dirname(os.path.realpath(__file__)), "fonts", 16 | "Arial.ttf") 17 | prop = fm.FontProperties(fname=f) 18 | 19 | CONTEXT_SETTINGS = { 20 | "help_option_names": ["-h", "--help"], 21 | } 22 | 23 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 24 | @click.argument( 25 | "jaspar_file", 26 | type=click.Path(exists=True, resolve_path=True), 27 | ) 28 | @click.argument( 29 | "logo_file", 30 | type=click.Path(exists=True, resolve_path=True), 31 | ) 32 | @click.option( 33 | "-r", "--rev-complement", 34 | help="Plot the reverse complement logo.", 35 | is_flag=True, 36 | ) 37 | 38 | def cli(**args): 39 | 40 | # Get figure 41 | fig = get_figure(args["motif_file"], rc=args["rev_complement"]) 42 | 43 | # Save 44 | fig.savefig(args["logo_file"], bbox_inches="tight", pad_inches=0) 45 | 46 | def get_figure(motif_file, rc=False): 47 | 48 | # From https://biopython.readthedocs.io/en/latest/chapter_motifs.html 49 | m = motifs.read(open(motif_file), "jaspar") 50 | pwm = list(m.counts.normalize(pseudocounts=.5).values()) 51 | 52 | return(_get_figure(pwm, rc)) 53 | 54 | def _get_figure(pwm, rc=False): 55 | 56 | # From https://www.bioconductor.org/packages/release/bioc/html/seqLogo.html 57 | if rc: 58 | arr = np.array(pwm) 59 | pwm = np.flip(arr).tolist() 60 | IC 
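# --- Illustrative note (added, not part of the original script) ---
# Per-column information content in bits for DNA: IC_j = 2 - H_j, with
# H_j = -sum_b p_bj * log2(p_bj), which is exactly the
# "2 + np.add.reduce(pwm * np.log2(pwm))" expression completed just below.
# Each base's letter height is then p_bj * IC_j, so a near-certain column
# draws an almost 2-bit letter and a uniform column draws nothing (the
# pseudocounts added above keep the log2 finite).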
= 2 + np.add.reduce(pwm * np.log2(pwm)) 61 | df = pd.DataFrame({ 62 | "pos": [i + 1 for i in range(len(IC))], 63 | "A": pwm[0] * IC, 64 | "C": pwm[1] * IC, 65 | "G": pwm[2] * IC, 66 | "T": pwm[3] * IC 67 | }) 68 | df = df.set_index("pos") 69 | 70 | # From https://logomaker.readthedocs.io/en/latest/examples.html 71 | fig, ax = plt.subplots(1, 1, figsize=(len(df)/2.0, 2)) 72 | logo = logomaker.Logo(df, ax=ax, show_spines=False) 73 | logo.style_spines(spines=["left", "bottom"], visible=True) 74 | logo.ax.set_aspect(1.5) 75 | logo.ax.xaxis.set_ticks(list(df.index)) 76 | logo.ax.set_xticklabels(labels=list(df.index), fontproperties=prop) 77 | logo.ax.set_ylabel("Bits", fontproperties=prop) 78 | logo.ax.set_ylim(0, 2) 79 | logo.ax.yaxis.set_ticks([0, 1, 2]) 80 | logo.ax.set_yticklabels(labels=[0, 1, 2], fontproperties=prop) 81 | 82 | return(fig) 83 | 84 | if __name__ == "__main__": 85 | cli() -------------------------------------------------------------------------------- /scripts/utils/match-seqs-by-gc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from Bio import SeqIO 4 | from Bio.Seq import Seq 5 | from Bio.SeqUtils import gc_fraction 6 | import click 7 | import copy 8 | import importlib 9 | import json 10 | import random 11 | import re 12 | import sys 13 | 14 | lib = importlib.import_module("subsample-seqs-by-gc") 15 | 16 | CONTEXT_SETTINGS = { 17 | "help_option_names": ["-h", "--help"], 18 | } 19 | 20 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 21 | @click.argument( 22 | "fasta_file", 23 | type=click.Path(exists=True, resolve_path=True), 24 | nargs=-1, 25 | ) 26 | @click.option( 27 | "-d", "--dna", 28 | type=click.Choice(["uppercase", "lowercase"]), 29 | help="DNA to transform.", 30 | show_default=True, 31 | default="lowercase" 32 | ) 33 | @click.option( 34 | "-o", "--output-file", 35 | help="Output file (in JSON format). [default: STDOUT]", 36 | type=click.Path(writable=True, readable=False, resolve_path=True, 37 | allow_dash=True), 38 | ) 39 | @click.option( 40 | "-r", "--random-seed", 41 | help="Random seed.", 42 | type=int, 43 | default=1714, 44 | show_default=True 45 | ) 46 | @click.option( 47 | "-s", "--subsample", 48 | help="Number of sequences to subsample.", 49 | type=click.IntRange(min=0), 50 | default=0, 51 | show_default=True, 52 | ) 53 | @click.option( 54 | "-t", "--transform", 55 | type=click.Choice(["skip", "shuffle", "mask"]), 56 | help="Skip, shuffle, or mask (i.e. 
convert to Ns) DNA.", 57 | show_default=True 58 | ) 59 | 60 | def cli(**args): 61 | 62 | sampled_seqs = match_seqs_by_GC(args["fasta_file"], args["dna"], 63 | args["random_seed"], args["subsample"], 64 | args["transform"]) 65 | 66 | # Write 67 | if args["output_file"] is not None: 68 | handle = open(args["output_file"], "wt") 69 | else: 70 | handle = sys.stdout 71 | json.dump(sampled_seqs, handle, indent=4, sort_keys=True) 72 | handle.close() 73 | 74 | def match_seqs_by_GC(fasta_files, dna="lowercase", random_seed=1714, 75 | subsample=0, transform=None): 76 | 77 | # Group sequences based on their %GC content 78 | gc_groups = _get_GC_groups(fasta_files, dna, transform) 79 | 80 | # Match sequences based on their %GC content 81 | matched_seqs = _match_seqs_by_GC(gc_groups, random_seed) 82 | 83 | # Subsample sequences based on their %GC content 84 | if subsample: 85 | sampled_seqs = lib._subsample_seqs_by_GC(matched_seqs, 86 | random_seed, 87 | subsample) 88 | else: 89 | sampled_seqs = matched_seqs 90 | sampled_seqs.insert(0, ["labels"] + list(fasta_files)) 91 | 92 | return(sampled_seqs) 93 | 94 | def _get_GC_groups(fasta_files, dna="lowercase", transform=None): 95 | 96 | # Initialize 97 | gc_groups = {} 98 | if dna == "lowercase": 99 | regexp = re.compile(r"[^ACGT]+") 100 | else: 101 | regexp = re.compile(r"[^acgt]+") 102 | 103 | # For each FASTA file 104 | for i in range(len(fasta_files)): 105 | 106 | fasta_file = fasta_files[i] 107 | 108 | # For each SeqRecord... 109 | for record in SeqIO.parse(fasta_file, "fasta"): 110 | 111 | gc = round(gc_fraction(record.seq)*100) 112 | 113 | if transform: 114 | 115 | s = str(record.seq) 116 | 117 | # Skip 118 | if transform == "skip": 119 | if re.search(regexp, s): 120 | continue 121 | 122 | # Shuffle/Mask 123 | else: 124 | # 1) extract blocks of nucleotides matching regexp; 125 | # 2) either shuffle them or create string of Ns; 126 | # and 3) put the nucleotides back 127 | l = list(s) 128 | for m in re.finditer(regexp, s): 129 | if transform == "shuffle": 130 | sublist = l[m.start():m.end()] 131 | random.shuffle(sublist) 132 | l[m.start():m.end()] = copy.copy(sublist) 133 | else: 134 | l[m.start():m.end()] = "N" * (m.end() - m.start()) 135 | record.seq = Seq("".join(l)) 136 | 137 | # Group SeqRecords based on their %GC content 138 | gc_groups.setdefault(gc, [[] for i in range(len(fasta_files))]) 139 | gc_groups[gc][i].append(record) 140 | 141 | return(gc_groups) 142 | 143 | def _match_seqs_by_GC(gc_groups, random_seed=1714): 144 | 145 | # Initialize 146 | matched_seqs = [] 147 | 148 | # For each %GC content group... 149 | for i in sorted(gc_groups): 150 | 151 | # For each set of sequences... 152 | for j in range(len(gc_groups[i])): 153 | 154 | # Shuffle sequences 155 | random.Random(random_seed).shuffle(gc_groups[i][j]) 156 | 157 | # Get the smallest number of sequences in %GC content group 158 | min_len = min(len(gc_groups[i][j]) for j in range(len(gc_groups[i]))) 159 | 160 | # Sequence counter 161 | for k in range(min_len): 162 | 163 | matched_seqs.append([i]) 164 | 165 | # For each set of sequences... 
166 | for j in range(len(gc_groups[i])): 167 | 168 | record = gc_groups[i][j][k] 169 | matched_seqs[-1].extend([[record.id, str(record.seq)]]) 170 | 171 | return(matched_seqs) 172 | 173 | if __name__ == "__main__": 174 | cli() -------------------------------------------------------------------------------- /scripts/utils/meme2logo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from functools import partial 5 | from multiprocessing import Pool 6 | import os 7 | from tqdm import tqdm 8 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 9 | 10 | from jaspar2logo import _get_figure 11 | from meme2scores import _get_PWMs 12 | from utils import get_file_handle 13 | 14 | CONTEXT_SETTINGS = { 15 | "help_option_names": ["-h", "--help"], 16 | } 17 | 18 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 19 | @click.argument( 20 | "meme_file", 21 | type=click.Path(exists=True, resolve_path=True), 22 | ) 23 | @click.option( 24 | "-c", "--cpu-threads", 25 | help="Number of CPU threads to use.", 26 | type=int, 27 | default=1, 28 | show_default=True, 29 | ) 30 | @click.option( 31 | "-f", "--oformat", 32 | help="Output format.", 33 | default="png", 34 | show_default=True, 35 | ) 36 | @click.option( 37 | "-o", "--output-dir", 38 | help="Output directory.", 39 | type=click.Path(resolve_path=True), 40 | default="./", 41 | show_default=True, 42 | ) 43 | 44 | def cli(**args): 45 | 46 | # Create output dir 47 | if not os.path.exists(args["output_dir"]): 48 | os.makedirs(args["output_dir"]) 49 | 50 | # Get PWMs 51 | pwms, names = _get_PWMs(args["meme_file"]) 52 | 53 | # Generate logos 54 | kwargs = {"bar_format": bar_format, "total": len(pwms)} 55 | pool = Pool(args["cpu_threads"]) 56 | p = partial(_generate_logo, oformat=args["oformat"], 57 | output_dir=args["output_dir"]) 58 | for _ in tqdm(pool.imap(p, zip(pwms, names)), **kwargs): 59 | pass 60 | 61 | def _generate_logo(pwm_name, oformat="png", output_dir="./"): 62 | 63 | # Initialize 64 | pwm, name = pwm_name 65 | 66 | for reverse_complement in [False, True]: 67 | if reverse_complement: 68 | logo_file = os.path.join(output_dir, f"{name}.rev.{oformat}") 69 | else: 70 | logo_file = os.path.join(output_dir, f"{name}.fwd.{oformat}") 71 | if not os.path.exists(logo_file): 72 | try: 73 | fig = _get_figure(pwm, reverse_complement) 74 | fig.savefig(logo_file, bbox_inches="tight", pad_inches=0) 75 | except: 76 | # i.e. 
no motif 77 | fh = get_file_handle(logo_file, "wt") 78 | fh.close() 79 | 80 | if __name__ == "__main__": 81 | cli() -------------------------------------------------------------------------------- /scripts/utils/meme2scores.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from click_option_group import optgroup 5 | import json 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | import re 10 | from sklearn.metrics import average_precision_score, roc_auc_score 11 | import sys 12 | sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])), 13 | os.pardir)) 14 | import time 15 | import torch 16 | import torch.nn as nn 17 | from tqdm import tqdm 18 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 19 | 20 | from utils import (get_file_handle, get_seqs_labels_ids, get_data_loader, 21 | get_device) 22 | 23 | class PWM(nn.Module): 24 | """PWM (Position Weight Matrix).""" 25 | 26 | def __init__(self, pwms, input_length, scoring="max"): 27 | """ 28 | initialize the model 29 | 30 | :param pwms: arr, numpy array with shape = (n, 4, PWM length) 31 | :param input_length: int, input sequence length 32 | :param scoring: string, return either the max. score or the sum 33 | occupancy score for each sequence, default max 34 | """ 35 | super(PWM, self).__init__() 36 | 37 | num_pwms, _, filter_size = pwms.shape 38 | 39 | self._options = { 40 | "num_pwms": num_pwms, 41 | "filter_size": filter_size, 42 | "input_length": input_length, 43 | "scoring": scoring 44 | } 45 | 46 | self.conv1d = nn.Conv1d(in_channels=4 * num_pwms, out_channels=1 * num_pwms, kernel_size=filter_size, 47 | groups=num_pwms) 48 | 49 | self.conv1d.bias.data = torch.Tensor([0.] * num_pwms) # no bias 50 | 51 | self.conv1d.weight.data = torch.Tensor(pwms) # set the conv. weights 52 | # to the PWM weights 53 | 54 | for p in self.conv1d.parameters(): 55 | p.requires_grad = False # freeze 56 | 57 | def forward(self, x): 58 | """Forward propagation of a batch.""" 59 | x_rev = _flip(_flip(x, 1), 2) 60 | o = self.conv1d(x.repeat(1, self._options["num_pwms"], 1)) 61 | o_rev = self.conv1d(x_rev.repeat(1, self._options["num_pwms"], 1)) 62 | o = torch.cat((o, o_rev), 2) 63 | if self._options["scoring"] == "max": 64 | return torch.max(o, 2)[0] 65 | else: 66 | return torch.sum(o, 2) 67 | 68 | def _flip(x, dim): 69 | """ 70 | Reverses the elements in a given dimension `dim` of the Tensor. 
71 | source: https://github.com/pytorch/pytorch/issues/229 72 | """ 73 | xsize = x.size() 74 | dim = x.dim() + dim if dim < 0 else dim 75 | x = x.contiguous() 76 | x = x.view(-1, *xsize[dim:]) 77 | x = x.view( 78 | x.size(0), x.size(1), -1)[:, getattr( 79 | torch.arange(x.size(1)-1, -1, -1), 80 | ("cpu","cuda")[x.is_cuda])().long(), :] 81 | 82 | return x.view(xsize) 83 | 84 | CONTEXT_SETTINGS = { 85 | "help_option_names": ["-h", "--help"], 86 | } 87 | 88 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 89 | @click.argument( 90 | "meme_file", 91 | type=click.Path(exists=True, resolve_path=True), 92 | ) 93 | @click.argument( 94 | "tsv_file", 95 | type=click.Path(exists=True, resolve_path=True), 96 | ) 97 | @click.option( 98 | "-c", "--cpu-threads", 99 | help="Number of CPU threads to use.", 100 | type=int, 101 | default=1, 102 | show_default=True, 103 | ) 104 | @click.option( 105 | "-d", "--debugging", 106 | help="Debugging mode.", 107 | is_flag=True, 108 | ) 109 | @click.option( 110 | "-o", "--output-dir", 111 | help="Output directory.", 112 | type=click.Path(resolve_path=True), 113 | default="./", 114 | show_default=True, 115 | ) 116 | @click.option( 117 | "-t", "--time", 118 | help="Return the program's running execution time in seconds.", 119 | is_flag=True, 120 | ) 121 | @optgroup.group("PWM scoring") 122 | @optgroup.option( 123 | "--batch-size", 124 | help="Batch size.", 125 | type=int, 126 | default=100, 127 | show_default=True, 128 | ) 129 | @optgroup.option( 130 | "--input-length", 131 | help="Input length (for longer and shorter sequences, trim or add padding, i.e. Ns, up to the specified length).", 132 | type=int, 133 | required=True, 134 | ) 135 | @click.option( 136 | "-s", "--scoring", 137 | help="Scoring function.", 138 | type=click.Choice(["max", "sum"]), 139 | default="max", 140 | show_default=True, 141 | ) 142 | 143 | def main(**args): 144 | 145 | # Start execution 146 | start_time = time.time() 147 | 148 | # Initialize 149 | if not os.path.exists(args["output_dir"]): 150 | os.makedirs(args["output_dir"]) 151 | 152 | # Save exec. 
parameters as JSON 153 | json_file = os.path.join(args["output_dir"], 154 | f"parameters-{os.path.basename(__file__)}.json") 155 | handle = get_file_handle(json_file, "wt") 156 | handle.write(json.dumps(args, indent=4, sort_keys=True)) 157 | handle.close() 158 | 159 | ############## 160 | # Load Data # 161 | ############## 162 | 163 | # Get training sequences and labels 164 | seqs, labels, _ = get_seqs_labels_ids(args["tsv_file"], 165 | args["debugging"], 166 | False, 167 | args["input_length"]) 168 | 169 | # Get DataLoader 170 | data_loader = get_data_loader(seqs, labels, args["batch_size"]) 171 | 172 | # Load PWMs and names 173 | pwms, names = _get_PWMs(args["meme_file"], resize_pwms=True, 174 | return_log=True) 175 | 176 | ############## 177 | # Score PWMs # 178 | ############## 179 | 180 | # Initialize 181 | idx = 0 182 | scores = np.zeros((len(seqs), len(pwms))) 183 | 184 | # Infer input length/type, and the number of classes 185 | input_length = seqs[0].shape[1] 186 | 187 | # Get device 188 | device = get_device() 189 | 190 | # Get model 191 | m = PWM(pwms, input_length, args["scoring"]).to(device) 192 | 193 | with torch.no_grad(): 194 | for x, _ in tqdm(iter(data_loader), total=len(data_loader), 195 | bar_format=bar_format): 196 | 197 | x = x.to(device) # prepare inputs 198 | 199 | s = m(x) # get scores 200 | scores[idx:idx+x.shape[0], :] = s.cpu().numpy() 201 | 202 | idx += x.shape[0] # increase index 203 | 204 | ############### 205 | # AUC metrics # 206 | ############### 207 | 208 | # Initialize 209 | aucs = [] 210 | metrics = dict(aucROC=roc_auc_score, aucPR=average_precision_score) 211 | 212 | # Compute AUCs 213 | for i in range(len(names)): 214 | s = scores[:, i] 215 | aucs.append([names[i]]) 216 | for m in metrics: 217 | aucs[-1].append(metrics[m](labels, s)) 218 | 219 | ############### 220 | # Output AUCs # 221 | ############### 222 | 223 | tsv_file = os.path.join(args["output_dir"], "scores.tsv") 224 | if not os.path.exists(tsv_file): 225 | df = pd.DataFrame(aucs, columns=["PWM"]+[m for m in metrics]) 226 | df.to_csv(tsv_file, sep="\t", index=False) 227 | 228 | # Finish execution 229 | seconds = format(time.time() - start_time, ".2f") 230 | if args["time"]: 231 | f = os.path.join(args["output_dir"], 232 | f"time-{os.path.basename(__file__)}.txt") 233 | handle = get_file_handle(f, "wt") 234 | handle.write(f"{seconds} seconds") 235 | handle.close() 236 | print(f'Execution time {seconds} seconds') 237 | 238 | def _get_PWMs(meme_file, resize_pwms=False, return_log=False): 239 | 240 | # Initialize 241 | dicts = [] 242 | names = [] 243 | pwms = [] 244 | alphabet = "ACGT" 245 | parse = False 246 | 247 | # Get PWM 248 | handle = get_file_handle(meme_file, "rt") 249 | for line in handle: 250 | line = line.strip("\n") 251 | if line.startswith("MOTIF"): 252 | parse = True 253 | dicts.append({}) 254 | for l in alphabet: 255 | dicts[-1].setdefault(l, []) 256 | names.append(line.split(" ")[1]) 257 | elif not parse: 258 | continue 259 | elif line.startswith("letter-probability matrix:"): 260 | continue 261 | else: 262 | m = re.search("^[\t\s]*(\S+)[\t\s]+(\S+)[\t\s]+(\S+)[\t\s]+(\S+)", line) 263 | if m: 264 | for l in range(len(alphabet)): 265 | # Add pseudocounts 266 | v = max([1e-4, float(m.group(l+1))]) 267 | dicts[-1][alphabet[l]].append(v) 268 | 269 | # Get max. PWM size 270 | max_size = 0 271 | for d in dicts: 272 | for l in alphabet: 273 | max_size = max([len(d[l]), max_size]) 274 | break 275 | 276 | # For each matrix... 
277 | for d in dicts: 278 | pwm = [] 279 | for l in alphabet: 280 | pwm.append(d[l]) 281 | if resize_pwms: 282 | pwm = __resize_PWM(list(zip(*pwm)), max_size) 283 | pwm = list(zip(*pwm)) 284 | if return_log: 285 | pwm = np.log(pwm) 286 | pwms.append(np.array(pwm)) 287 | 288 | return(pwms, names) 289 | 290 | def __resize_PWM(pwm, size): 291 | 292 | # Initialize 293 | lpop = 0 294 | rpop = 0 295 | 296 | pwm = [[.25,.25,.25,.25]]*size+pwm+[[.25,.25,.25,.25]]*size 297 | 298 | while len(pwm) > size: 299 | if max(pwm[0]) < max(pwm[-1]): 300 | pwm.pop(0) 301 | lpop += 1 302 | elif max(pwm[-1]) < max(pwm[0]): 303 | pwm.pop(-1) 304 | rpop += 1 305 | else: 306 | if lpop > rpop: 307 | pwm.pop(-1) 308 | rpop += 1 309 | else: 310 | pwm.pop(0) 311 | lpop += 1 312 | 313 | return(pwm) 314 | 315 | if __name__ == "__main__": 316 | main() -------------------------------------------------------------------------------- /scripts/utils/pwm2scores.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | import numpy as np 5 | import os 6 | import pandas as pd 7 | import re 8 | import sys 9 | import torch 10 | from tqdm import tqdm 11 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 12 | 13 | from explainn.models.networks import PWM 14 | 15 | # Local imports 16 | sys.path.insert(0, os.path.join(os.path.dirname(sys.argv[0]), os.pardir)) 17 | from utils import (get_file_handle, get_seqs_labels_ids, get_data_loader, 18 | get_device) 19 | 20 | CONTEXT_SETTINGS = { 21 | "help_option_names": ["-h", "--help"], 22 | } 23 | 24 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 25 | @click.argument( 26 | "meme_file", 27 | type=click.Path(exists=True, resolve_path=True), 28 | ) 29 | @click.argument( 30 | "tsv_file", 31 | type=click.Path(exists=True, resolve_path=True), 32 | ) 33 | @click.option( 34 | "-b", "--batch-size", 35 | help="Batch size.", 36 | type=int, 37 | default=100, 38 | show_default=True, 39 | ) 40 | @click.option( 41 | "-d", "--debugging", 42 | help="Debugging mode.", 43 | is_flag=True, 44 | ) 45 | @click.option( 46 | "-o", "--output-dir", 47 | help="Output directory.", 48 | type=click.Path(resolve_path=True), 49 | default="./", 50 | show_default=True, 51 | ) 52 | @click.option( 53 | "-p", "--prefix", 54 | help="Output prefix.", 55 | ) 56 | @click.option( 57 | "-s", "--scoring", 58 | help="Scoring function.", 59 | type=click.Choice(["max", "sum"]), 60 | default="max", 61 | show_default=True, 62 | ) 63 | 64 | def main(**args): 65 | 66 | # Create output dir 67 | if not os.path.exists(args["output_dir"]): 68 | os.makedirs(args["output_dir"]) 69 | 70 | ############## 71 | # Load Data # 72 | ############## 73 | 74 | # Initialize 75 | device = "cuda" if torch.cuda.is_available() else "cpu" 76 | 77 | # Get data 78 | seqs, y_true, _ = _get_seqs_labels_ids(args["tsv_file"], args["debugging"]) 79 | 80 | # Get DataLoader 81 | data_loader = _get_data_loader(seqs, y_true, args["batch_size"]) 82 | 83 | # Load model 84 | pwms, names = _get_PWMs(args["meme_file"], resize_pwms=True, 85 | return_log=True) 86 | pwm_model = PWM(pwms, seqs.shape[2], args["scoring"]).to(device) 87 | 88 | ############## 89 | # Score PWMs # 90 | ############## 91 | 92 | # Initialize 93 | idx = 0 94 | scores = np.zeros((len(data_loader.dataset), pwm_model._options["groups"])) 95 | 96 | with torch.no_grad(): 97 | for x, _ in tqdm(iter(data_loader), total=len(data_loader), 98 | bar_format=bar_format): 99 | 100 | # Prepare inputs 101 | x = 
x.to(device) 102 | 103 | # Get scores 104 | s = pwm_model(x) 105 | scores[idx:idx+x.shape[0], :] = s.cpu().numpy() 106 | 107 | # Index increase 108 | idx += x.shape[0] 109 | 110 | ############### 111 | # AUC metrics # 112 | ############### 113 | 114 | # Initialize 115 | aucs = [] 116 | metrics = get_metrics() 117 | 118 | # Compute AUCs 119 | for i in range(len(names)): 120 | y_score = scores[:, i] 121 | aucs.append([names[i]]) 122 | for m in metrics: 123 | aucs[-1].append(metrics[m](y_true, y_score)) 124 | 125 | ############### 126 | # Output AUCs # 127 | ############### 128 | 129 | # Create DataFrame 130 | df = pd.DataFrame(aucs, columns=["PWM"]+[m for m in metrics]) 131 | 132 | # Save AUCs 133 | if args["prefix"] is None: 134 | tsv_file = os.path.join(args["output_dir"], "%s.tsv" % args["scoring"]) 135 | else: 136 | tsv_file = os.path.join(args["output_dir"], 137 | "%s.%s.tsv" % (args["prefix"], args["scoring"])) 138 | df.to_csv(tsv_file, sep="\t", index=False) 139 | 140 | def _get_PWMs(meme_file, resize_pwms=False, return_log=False): 141 | 142 | # Initialize 143 | dicts = [] 144 | names = [] 145 | pwms = [] 146 | alphabet = "ACGT" 147 | parse = False 148 | 149 | # Get PWM 150 | handle = get_file_handle(meme_file, "rt") 151 | for line in handle: 152 | line = line.strip("\n") 153 | if line.startswith("MOTIF"): 154 | parse = True 155 | dicts.append({}) 156 | for l in alphabet: 157 | dicts[-1].setdefault(l, []) 158 | names.append(line.split(" ")[1]) 159 | elif not parse: 160 | continue 161 | elif line.startswith("letter-probability matrix:"): 162 | continue 163 | else: 164 | m = re.search("^\s*(\S+)\s+(\S+)\s+(\S+)\s+(\S+)$", line) 165 | if m: 166 | for l in range(len(alphabet)): 167 | # Add pseudocounts 168 | v = max([1e-4, float(m.group(l+1))]) 169 | dicts[-1][alphabet[l]].append(v) 170 | 171 | # Get max. PWM size 172 | max_size = 0 173 | for d in dicts: 174 | for l in alphabet: 175 | max_size = max([len(d[l]), max_size]) 176 | break 177 | 178 | # For each matrix... 
179 | for d in dicts: 180 | pwm = [] 181 | for l in alphabet: 182 | pwm.append(d[l]) 183 | if resize_pwms: 184 | pwm = __resize_PWM(list(zip(*pwm)), max_size) 185 | pwm = list(zip(*pwm)) 186 | pwms.append(pwm) 187 | 188 | if return_log: 189 | return(np.log(pwms), names) 190 | else: 191 | return(pwms, names) 192 | 193 | def __resize_PWM(pwm, size): 194 | 195 | # Initialize 196 | lpop = 0 197 | rpop = 0 198 | 199 | pwm = [[.25,.25,.25,.25]]*size+pwm+[[.25,.25,.25,.25]]*size 200 | 201 | while len(pwm) > size: 202 | if max(pwm[0]) < max(pwm[-1]): 203 | pwm.pop(0) 204 | lpop += 1 205 | elif max(pwm[-1]) < max(pwm[0]): 206 | pwm.pop(-1) 207 | rpop += 1 208 | else: 209 | if lpop > rpop: 210 | pwm.pop(-1) 211 | rpop += 1 212 | else: 213 | pwm.pop(0) 214 | lpop += 1 215 | 216 | return(pwm) 217 | 218 | if __name__ == "__main__": 219 | main() -------------------------------------------------------------------------------- /scripts/utils/resize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | import pandas as pd 5 | import sys 6 | 7 | CONTEXT_SETTINGS = { 8 | "help_option_names": ["-h", "--help"], 9 | } 10 | 11 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 12 | @click.argument( 13 | "bed_file", 14 | type=click.Path(exists=True, resolve_path=True), 15 | ) 16 | @click.argument( 17 | "chrom_sizes", 18 | type=click.Path(exists=True, resolve_path=True), 19 | ) 20 | @click.argument( 21 | "size", 22 | type=int, 23 | ) 24 | @click.option( 25 | "-o", "--output-file", 26 | help="Output file (in BED format). [default: STDOUT]", 27 | type=click.Path(writable=True, readable=False, resolve_path=True, 28 | allow_dash=True), 29 | ) 30 | 31 | def cli(**args): 32 | 33 | # Get BED file as DataFrame 34 | bed = pd.read_table(args["bed_file"], names=["chrom", "start", "end"], 35 | header=None) 36 | 37 | # Get chrom sizes as dict 38 | sizes = dict.fromkeys(bed["chrom"].to_list(), -1) 39 | df = pd.read_table(args["chrom_sizes"], names=["chrom", "size"], 40 | header=None) 41 | sizes.update(dict(zip(df["chrom"].to_list(), df["size"].to_list()))) 42 | 43 | # Resize intervals 44 | s = args["size"] / 2. 
45 | bed["center"] = list(map(int, bed["start"] + (bed["end"] - bed["start"]) / 2.)) 46 | bed["start"] = list(map(int, bed["center"] - s)) 47 | bed["end"] = list(map(int, bed["center"] + s)) 48 | 49 | # Filter intervals 50 | bed = bed[(bed["start"] >= 0) & \ 51 | (bed["end"] <= bed["chrom"].map(lambda x: sizes[x]))] 52 | 53 | # Write 54 | if args["output_file"] is not None: 55 | handle = open(args["output_file"], "wt") 56 | else: 57 | handle = sys.stdout 58 | bed.to_csv(handle, columns=["chrom", "start", "end"], header=False, 59 | index=False, sep="\t") 60 | handle.close() 61 | 62 | if __name__ == "__main__": 63 | cli() -------------------------------------------------------------------------------- /scripts/utils/subsample-seqs-by-gc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from Bio import SeqIO 4 | from Bio.Seq import Seq 5 | from Bio.SeqRecord import SeqRecord 6 | import click 7 | import importlib 8 | import math 9 | import random 10 | import sys 11 | 12 | lib = importlib.import_module("match-seqs-by-gc") 13 | 14 | CONTEXT_SETTINGS = { 15 | "help_option_names": ["-h", "--help"], 16 | } 17 | 18 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 19 | @click.argument( 20 | "fasta_file", 21 | type=click.Path(exists=True, resolve_path=True), 22 | ) 23 | @click.option( 24 | "-d", "--dna", 25 | type=click.Choice(["uppercase", "lowercase"]), 26 | help="DNA to transform.", 27 | show_default=True, 28 | default="lowercase" 29 | ) 30 | @click.option( 31 | "-o", "--output-file", 32 | help="Output file (in FASTA format). [default: STDOUT]", 33 | type=click.Path(writable=True, readable=False, resolve_path=True, 34 | allow_dash=True), 35 | ) 36 | @click.option( 37 | "-r", "--random-seed", 38 | help="Random seed.", 39 | type=int, 40 | default=1714, 41 | show_default=True 42 | ) 43 | @click.option( 44 | "-s", "--subsample", 45 | help="Number of sequences to subsample.", 46 | type=click.IntRange(min=1000), 47 | default=1000, 48 | show_default=True, 49 | ) 50 | @click.option( 51 | "-t", "--transform", 52 | type=click.Choice(["skip", "shuffle", "mask"]), 53 | help="Skip, shuffle, or mask (i.e. 
convert to Ns) DNA.", 54 | show_default=True 55 | ) 56 | 57 | def cli(**args): 58 | 59 | # Group sequences based on their %GC content 60 | gc_groups = lib._get_GC_groups([args["fasta_file"]], args["dna"], 61 | args["transform"]) 62 | 63 | # Match sequences based on their %GC content 64 | matched_seqs = lib._match_seqs_by_GC(gc_groups, args["random_seed"]) 65 | 66 | # Subsample sequences based on their %GC content 67 | if args["subsample"]: 68 | sampled_seqs = _subsample_seqs_by_GC(matched_seqs, args["random_seed"], 69 | abs(args["subsample"])) 70 | else: 71 | sampled_seqs = matched_seqs 72 | sampled_seqs = [SeqRecord(Seq(arr[1][1]), id=arr[1][0], description="") \ 73 | for arr in sampled_seqs] 74 | 75 | # Write 76 | if args["output_file"] is not None: 77 | handle = open(args["output_file"], "wt") 78 | else: 79 | handle = sys.stdout 80 | SeqIO.write(sampled_seqs, handle, "fasta") 81 | handle.close() 82 | 83 | def _subsample_seqs_by_GC(matched_seqs, random_seed=1714, subsample=1000): 84 | 85 | # Initialize 86 | gc_regroups = {} 87 | sampled_seqs = [] 88 | 89 | # Regroup sequences based on their %GC content 90 | for arr in matched_seqs: 91 | gc_regroups.setdefault(arr[0], []) 92 | gc_regroups[arr[0]].append(arr) 93 | 94 | # Get normalization factor 95 | norm_factor = subsample / sum([len(v) for v in gc_regroups.values()]) 96 | 97 | # Subsample sequences based on their %GC content 98 | for i in sorted(gc_regroups): 99 | random.Random(random_seed).shuffle(gc_regroups[i]) 100 | arr = gc_regroups[i][:math.ceil(len(gc_regroups[i])*norm_factor)] 101 | sampled_seqs.extend(arr) 102 | 103 | # Randomize 104 | random.Random(random_seed).shuffle(sampled_seqs) 105 | 106 | return(sampled_seqs[:subsample]) 107 | 108 | if __name__ == "__main__": 109 | cli() -------------------------------------------------------------------------------- /scripts/utils/tomtom.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import click 4 | from click_option_group import optgroup 5 | from functools import partial 6 | import json 7 | from multiprocessing import Pool 8 | import numpy as np 9 | import os 10 | import pandas as pd 11 | from pandas.errors import EmptyDataError 12 | import sys 13 | sys.path.insert(0, os.path.join(os.path.dirname(sys.argv[0]), os.pardir)) 14 | import subprocess as sp 15 | import time 16 | from tqdm import tqdm 17 | bar_format = "{percentage:3.0f}%|{bar:20}{r_bar}" 18 | import warnings 19 | warnings.filterwarnings("ignore") 20 | 21 | from utils import get_file_handle 22 | 23 | CONTEXT_SETTINGS = { 24 | "help_option_names": ["-h", "--help"], 25 | } 26 | 27 | @click.command(no_args_is_help=True, context_settings=CONTEXT_SETTINGS) 28 | @click.argument( 29 | "query_file", 30 | type=click.Path(exists=True, resolve_path=True), 31 | ) 32 | @click.argument( 33 | "target_file", 34 | type=click.Path(exists=True, resolve_path=True), 35 | ) 36 | @click.option( 37 | "-c", "--cpu-threads", 38 | help="Number of CPU threads to use.", 39 | type=int, 40 | default=1, 41 | show_default=True, 42 | ) 43 | @click.option( 44 | "-o", "--output-dir", 45 | help="Output directory.", 46 | type=click.Path(resolve_path=True), 47 | default="./", 48 | show_default=True, 49 | ) 50 | @click.option( 51 | "-t", "--time", 52 | help="Return the program's running execution time in seconds.", 53 | is_flag=True, 54 | ) 55 | @optgroup.group("Tomtom") 56 | @optgroup.option( 57 | "--dist", 58 | help="Distance metric for scoring alignments.", 59 | type=str, 60 | 
default="pearson", 61 | show_default=True, 62 | ) 63 | @optgroup.option( 64 | "--evalue", 65 | help="Use E-value threshold.", 66 | is_flag=True, 67 | ) 68 | @optgroup.option( 69 | "--min-overlap", 70 | help="Minimum overlap between query and target.", 71 | type=int, 72 | default=5, 73 | show_default=True, 74 | ) 75 | @optgroup.option( 76 | "--motif-pseudo", 77 | help="Apply pseudocounts to the query and target motifs.", 78 | type=float, 79 | default=0.1, 80 | show_default=True, 81 | ) 82 | @optgroup.option( 83 | "--thresh", 84 | help="Significance threshold (i.e., do not show worse motifs).", 85 | type=float, 86 | default=0.05, 87 | show_default=True, 88 | ) 89 | 90 | def cli(**args): 91 | 92 | # Start execution 93 | start_time = time.time() 94 | 95 | # Initialize 96 | if not os.path.exists(args["output_dir"]): 97 | os.makedirs(args["output_dir"]) 98 | 99 | # Save exec. parameters as JSON 100 | json_file = os.path.join(args["output_dir"], 101 | f"parameters-{os.path.basename(__file__)}.json") 102 | handle = get_file_handle(json_file, "wt") 103 | handle.write(json.dumps(args, indent=4, sort_keys=True)) 104 | handle.close() 105 | 106 | # Create output dirs 107 | motifs_dir = os.path.join(args["output_dir"], "motifs") 108 | if not os.path.isdir(motifs_dir): 109 | os.makedirs(motifs_dir) 110 | tomtom_dir = os.path.join(args["output_dir"], "tomtom") 111 | if not os.path.isdir(tomtom_dir): 112 | os.makedirs(tomtom_dir) 113 | 114 | # Get motifs 115 | motifs = [] 116 | _get_motifs(args["query_file"], motifs_dir, args["cpu_threads"]) 117 | for meme_file in os.listdir(motifs_dir): 118 | motifs.append(os.path.join(motifs_dir, meme_file)) 119 | 120 | # Compute Tomtom similarities 121 | kwargs = {"bar_format": bar_format, "total": len(motifs)} 122 | pool = Pool(args["cpu_threads"]) 123 | p = partial(_compute_Tomtom_similarities, target_file=args["target_file"], 124 | tomtom_dir=tomtom_dir, dist=args["dist"], evalue=args["evalue"], 125 | minover=args["min_overlap"], mpseudo=args["motif_pseudo"], 126 | thresh=args["thresh"]) 127 | for _ in tqdm(pool.imap(p, motifs), **kwargs): 128 | pass 129 | 130 | # Save Tomtom file 131 | tsv_file = os.path.join(args["output_dir"], "tomtom.tsv.gz") 132 | if not os.path.exists(tsv_file): 133 | df = _load_Tomtom_files(tomtom_dir) 134 | df.to_csv(tsv_file, sep="\t", index=False) 135 | 136 | # Finish execution 137 | seconds = format(time.time() - start_time, ".2f") 138 | if args["time"]: 139 | f = os.path.join(args["output_dir"], 140 | f"time-{os.path.basename(__file__)}.txt") 141 | handle = get_file_handle(f, "wt") 142 | handle.write(f"{seconds} seconds") 143 | handle.close() 144 | print(f"Execution time {seconds} seconds") 145 | 146 | def _get_motifs(meme_file, motifs_dir, cpu_threads=1): 147 | 148 | # Initialize 149 | motifs = [] 150 | parse = False 151 | 152 | # Get motifs 153 | handle = get_file_handle(meme_file, "rt") 154 | for line in handle: 155 | line = line.strip("\n") 156 | if line.startswith("MOTIF"): 157 | motifs.append([]) 158 | parse = True 159 | if parse: 160 | motifs[-1].append(line) 161 | handle.close() 162 | 163 | # Create motif files 164 | zfill = len(str(len(motifs))) 165 | kwargs = {"bar_format": bar_format, "total": len(motifs)} 166 | pool = Pool(cpu_threads) 167 | p = partial(__write_motif, motifs_dir=motifs_dir, zfill=zfill) 168 | for _ in tqdm(pool.imap(p, enumerate(motifs)), **kwargs): 169 | pass 170 | 171 | def __write_motif(i_motif, motifs_dir, zfill=0): 172 | 173 | # Initialize 174 | i, motif = i_motif 175 | prefix = str(i).zfill(zfill) 176 | 
177 | motif_file = os.path.join(motifs_dir, f"{prefix}.meme") 178 | if not os.path.exists(motif_file): 179 | handle = get_file_handle(motif_file, "wt") 180 | handle.write("MEME version 4\n\n") 181 | handle.write("ALPHABET= ACGT\n\n") 182 | handle.write("strands: + -\n\n") 183 | handle.write( 184 | "Background letter frequencies (from uniform background):\n" 185 | ) 186 | handle.write("A 0.25000 C 0.25000 G 0.25000 T 0.25000\n\n") 187 | for line in motif: 188 | handle.write(f"{line}\n") 189 | handle.close() 190 | 191 | def _compute_Tomtom_similarities(query_file, target_file, tomtom_dir, 192 | dist="pearson", evalue=False, minover=5, 193 | mpseudo=0.1, thresh=0.05): 194 | 195 | # Initialize 196 | prefix = os.path.splitext(os.path.basename(query_file))[0] 197 | tomtom_file = os.path.join(tomtom_dir, f"{prefix}.tsv.gz") 198 | 199 | if not os.path.exists(tomtom_file): 200 | 201 | # Compute motif similarities 202 | if evalue: 203 | cmd = ["tomtom", "-dist", dist, "-motif-pseudo", str(mpseudo), 204 | "-min-overlap", str(minover), "-thresh", str(thresh), 205 | "-text", "-evalue", query_file, target_file] 206 | else: 207 | cmd = ["tomtom", "-dist", dist, "-motif-pseudo", str(mpseudo), 208 | "-min-overlap", str(minover), "-thresh", str(thresh), 209 | "-text", query_file, target_file] 210 | proc = sp.run(cmd, stdout=sp.PIPE, stderr=sp.DEVNULL) 211 | 212 | # Save Tomtom results 213 | handle = get_file_handle(tomtom_file, "wb") 214 | for line in proc.stdout.decode().split("\n"): 215 | handle.write(f"{line}\n".encode()) 216 | handle.close() 217 | 218 | def _load_Tomtom_files(tomtom_dir, col_names=None): 219 | 220 | # Initialize 221 | dfs = [] 222 | 223 | for tsv_file in os.listdir(tomtom_dir): 224 | if tsv_file.endswith(".tsv.gz"): 225 | dfs.append(pd.read_table(os.path.join(tomtom_dir, tsv_file), 226 | header=0, usecols=col_names, 227 | comment="#")) 228 | 229 | return(pd.concat(dfs)) 230 | 231 | if __name__ == "__main__": 232 | cli() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | ExplaiNN: interpretable and transparent neural networks for genomics 3 | """ 4 | 5 | import setuptools 6 | 7 | with open("README.md", "r") as f: 8 | long_description = f.read() 9 | 10 | setuptools.setup( 11 | name="explainn", 12 | version="23.5.1", 13 | author="Gherman Novakovsky, Oriol Fornes", 14 | author_email="g.e.novakovsky@gmail.com, oriol.fornes@gmail.com", 15 | description="ExplaiNN: interpretable and transparent neural networks for genomics", 16 | long_description=long_description, 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/wassermanlab/ExplaiNN", 19 | packages=setuptools.find_packages(), 20 | include_package_data=True, 21 | classifiers=[ 22 | "Intended Audience :: Developers", 23 | "Programming Language :: Python :: 3", 24 | "License :: OSI Approved :: MIT License", 25 | "Operating System :: OS Independent", 26 | ], 27 | install_requires=["numpy", "h5py", "tqdm", "pandas"] 28 | ) --------------------------------------------------------------------------------
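For orientation, the sketch below shows one way the PWM-scoring utilities from scripts/utils/meme2scores.py could be combined in a standalone Python session. It is illustrative only and not a file in this repository: the file names (motifs.meme, sequences.tsv), the input length of 200, and the assumption that scripts/ and scripts/utils/ are on sys.path (so that the local utils package and meme2scores module resolve as they do in the scripts above) are all hypothetical.

# Hypothetical usage sketch (paths and data layout are assumptions, not repository files)
import numpy as np
import torch

from meme2scores import PWM, _get_PWMs  # scripts/utils/meme2scores.py
from utils import get_seqs_labels_ids, get_data_loader, get_device  # local utils package used by the scripts

# Load PWMs from a MEME file, padded/trimmed to a common length and log-transformed
pwms, names = _get_PWMs("motifs.meme", resize_pwms=True, return_log=True)

# One-hot encoded sequences and labels from a TSV, in the format expected by get_seqs_labels_ids
seqs, labels, _ = get_seqs_labels_ids("sequences.tsv", False, False, 200)
data_loader = get_data_loader(seqs, labels, 100)

# Score every sequence with every PWM on the freest available device
device = get_device()
model = PWM(np.array(pwms), seqs[0].shape[1], scoring="max").to(device)
scores = []
with torch.no_grad():
    for x, _ in data_loader:
        scores.append(model(x.to(device)).cpu().numpy())
scores = np.concatenate(scores)  # shape: (number of sequences, number of PWMs)

The PWM module scans both strands by convolving the one-hot input (and its reverse complement) with frozen, grouped Conv1d weights set to the log-PWMs, then returns either the maximum or the summed score per sequence, which is what meme2scores.py and pwm2scores.py feed into the AUC metrics.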