├── .gitignore
├── LICENSE
├── README.md
├── _config.yml
├── code
│   ├── 10x10valid.py
│   ├── aitac.py
│   ├── extract_motifs.py
│   ├── finetune_filter_subset.py
│   ├── plot_utils.py
│   ├── second_layer_motifs
│   │   ├── aitac_v2.py
│   │   ├── extract_layer2_motifs.py
│   │   └── plot_utils.py
│   ├── summary_sheet.py
│   ├── test_human.py
│   └── train_test_aitac.py
├── data_processing
│   ├── human_data_preprocess.py
│   ├── permute_human_data.py
│   ├── permute_sequence.py
│   ├── preprocess_data.py
│   └── preprocess_utils.py
├── figures
│   ├── AI-TAC.png
│   ├── AITAC.pdf
│   └── logo.png
├── models
│   └── aitac.ckpt
└── sample_data
    ├── README
    ├── cell_type_array.npy
    ├── cell_type_names.npy
    ├── filter_set99_index.txt
    ├── one_hot_seqs.npy
    └── peak_names.npy

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

#directories
data/
human_data/
outputs/

#bash scripts
*.sh
*.out
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 smaslova

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
![AI-TAC logo](figures/logo.png)
# Convolutional neural network to predict immune cell chromatin state
AI-TAC is a deep convolutional network for predicting mouse immune cell ATAC-seq signal from peak sequences, using data from the Immunological Genome Project: http://www.immgen.org/.

![Overview of AI-TAC](figures/AI-TAC.png)


## Requirements
The code was written in Python 3.6.3 and PyTorch 1.4.0, and was run on NVIDIA P100 (Pascal) GPUs.

## Tutorial

The processed data files can be downloaded using the following links:
- [bed file containing peak locations](https://www.dropbox.com/s/r8drj2wxc07bt4j/ImmGenATAC1219.peak_matched.txt?dl=0)
- [csv file containing measured ATAC-seq peak heights](https://www.dropbox.com/s/7mmd4v760eux755/mouse_peak_heights.csv?dl=0)

The required inputs are a bed file with ATAC-seq peak locations, the reference genome, and a file with normalized peak heights.
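
As a rough picture of what the preprocessing step produces, the sketch below one-hot encodes peak sequences into an `(n_peaks, 4, 251)` float32 array like the `one_hot_seqs.npy` file the training and motif scripts load with `np.load`. It is a minimal sketch, not the repository's `preprocess_utils` code: the helper name `one_hot_encode` is hypothetical, the A, T, G, C row order is taken from the comment in `plot_utils.get_memes`, and the 251 bp peak width is inferred from the hard-coded layer sizes in `aitac.py`.

```python
import numpy as np

# Assumed row order A, T, G, C (per the comment in plot_utils.get_memes);
# check data_processing/preprocess_utils.py before relying on it.
ALPHABET = "ATGC"


def one_hot_encode(seq, alphabet=ALPHABET):
    """Hypothetical helper: encode a DNA string as a (4, len(seq)) float32 array.

    Characters outside the alphabet (e.g. N) become all-zero columns.
    """
    seq = seq.upper()
    encoded = np.zeros((len(alphabet), len(seq)), dtype=np.float32)
    for pos, base in enumerate(seq):
        row = alphabet.find(base)  # -1 for bases not in the alphabet
        if row >= 0:
            encoded[row, pos] = 1.0
    return encoded


# Stack equal-length peaks into one (n_peaks, 4, 251) array; the 251 bp
# width is an assumption inferred from aitac.py, not read from the data.
peaks = ["ACGT" * 62 + "TGN", "T" * 251]
x = np.stack([one_hot_encode(s) for s in peaks])
print(x.shape)  # (2, 4, 251)
```

Encoding unknown bases as all-zero columns means they contribute nothing to any first-layer filter activation, which is a common choice for convolutional sequence models.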
The code for processing raw data is in `data_processing/`. For example, to convert the ImmGen mouse data set to one-hot encoded sequences and save them in the data directory, run:

```bash
python preprocess_data.py "../data/ImmGenATAC1219.peak_matched.txt" "../data/mouse_peak_heights.csv" "../mm10/" "../data/"
```

The model can then be trained by running:
```bash
python train_test_aitac.py model_name '../data/one_hot_seqs.npy' '../data/cell_type_array.npy' '../data/peak_names.npy'
```

To extract first-layer motifs, run:
```bash
python -u extract_motifs.py model_name '../data/one_hot_seqs.npy' '../data/cell_type_array.npy' '../data/peak_names.npy'
```
## Reference
[Learning immune cell differentiation. Alexandra Maslova, Ricardo N. Ramirez, Ke Ma, Hugo Schmutz, Chendi Wang, Curtis Fox, Bernard Ng, Christophe Benoist, Sara Mostafavi, the Immunological Genome Project, bioRxiv 2019.12.21.885814; doi: https://doi.org/10.1101/2019.12.21.885814](https://www.biorxiv.org/content/10.1101/2019.12.21.885814v1)
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
theme: jekyll-theme-cayman
--------------------------------------------------------------------------------
/code/10x10valid.py:
--------------------------------------------------------------------------------
import numpy as np
from numpy import random
import pandas as pd

from sklearn.model_selection import KFold
import torch.utils.data
import matplotlib
import os
import sys
matplotlib.use('Agg')

import aitac
import plot_utils

#create output directory
output_file_path = "../outputs/valid10x10/"
directory = os.path.dirname(output_file_path)
if not os.path.exists(directory):
    print("Creating directory %s" % output_file_path)
    os.makedirs(directory)
else:
    print("Directory %s exists" % output_file_path)

# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hyper parameters
num_epochs = 10
num_classes = 81
batch_size = 100
learning_rate = 0.001
num_filters = 300
run_num = sys.argv[1]

# Load all data
x = np.load(sys.argv[2])
x = x.astype(np.float32)
y = np.load(sys.argv[3])
y = y.astype(np.float32)
peak_names = np.load(sys.argv[4])


def cross_validate(x, y, peak_names, output_file_path):
    kf = KFold(n_splits=10, shuffle=True)

    pred_all = []
    corr_all = []
    peak_order = []
    for train_index, test_index in kf.split(x):
        train_data, eval_data = x[train_index, :, :], x[test_index, :, :]
        train_labels, eval_labels = y[train_index, :], y[test_index, :]
        train_names, eval_name = peak_names[train_index], peak_names[test_index]

        # Data loader
        train_dataset = torch.utils.data.TensorDataset(torch.from_numpy(train_data), torch.from_numpy(train_labels))
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)

        eval_dataset = torch.utils.data.TensorDataset(torch.from_numpy(eval_data), torch.from_numpy(eval_labels))
        eval_loader = torch.utils.data.DataLoader(dataset=eval_dataset, batch_size=batch_size, shuffle=False)


        # create model
        model = aitac.ConvNet(num_classes, num_filters).to(device)

        # Loss and optimizer
        criterion = aitac.pearson_loss
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        # train model
        model, best_loss = aitac.train_model(train_loader, eval_loader, model, device, criterion, optimizer, num_epochs, output_file_path)

        # Predict on test set
        predictions, max_activations, max_act_index = aitac.test_model(eval_loader, model, device)

        # plot the correlations histogram
        correlations = plot_utils.plot_cors(eval_labels, predictions, output_file_path)

        pred_all.append(predictions)
        corr_all.append(correlations)
        peak_order.append(eval_name)

    pred_all = np.vstack(pred_all)
    corr_all = np.hstack(corr_all)
    peak_order = np.hstack(peak_order)

    return pred_all, corr_all, peak_order


predictions, correlations, peak_order = cross_validate(x, y, peak_names, output_file_path)

np.save(output_file_path + "predictions_trial" + run_num + ".npy", predictions)
np.save(output_file_path + "correlations_trial" + run_num + ".npy", correlations)
np.save(output_file_path + "peak_order" + run_num + ".npy", peak_order)
--------------------------------------------------------------------------------
/code/aitac.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import random
import numpy as np
import copy

# Convolutional neural network
class ConvNet(nn.Module):
    def __init__(self, num_classes, num_filters):
        super(ConvNet, self).__init__()

        # for layer one, separate convolution and relu step from maxpool and batch normalization
        # to extract convolutional filters
        self.layer1_conv = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=num_filters,
                      kernel_size=(4, 19),
                      stride=1,
                      padding=0),  # padding is done in forward method along 1 dimension only
            nn.ReLU())

        self.layer1_process = nn.Sequential(
            nn.MaxPool2d(kernel_size=(1,3), stride=(1,3), padding=(0,1)),
            nn.BatchNorm2d(num_filters))

        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=num_filters,
                      out_channels=200,
                      kernel_size=(1, 11),
                      stride=1,
                      padding=0),  # padding is done in forward method along 1 dimension only
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1,4), stride=(1,4), padding=(0,1)),
            nn.BatchNorm2d(200))

        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels=200,
                      out_channels=200,
                      kernel_size=(1, 7),
                      stride=1,
                      padding=0),  # padding is done in forward method along 1 dimension only
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 4), stride=(1,4), padding=(0,1)),
            nn.BatchNorm2d(200))

        self.layer4 = nn.Sequential(
            nn.Linear(in_features=1000,
                      out_features=1000),
            nn.ReLU(),
            nn.Dropout(p=0.03))

        self.layer5 = nn.Sequential(
            nn.Linear(in_features=1000,
                      out_features=1000),
            nn.ReLU(),
            nn.Dropout(p=0.03))

        self.layer6 = nn.Sequential(
            nn.Linear(in_features=1000,
                      out_features=num_classes))#,
            #nn.Sigmoid())


    def forward(self, input):
        # run all layers on input data
        # add dummy dimension to input (for num channels=1)
        input = torch.unsqueeze(input, 1)

        # Run convolutional layers
        input = F.pad(input, (9, 9), mode='constant', value=0) # padding - last dimension goes first
        out = self.layer1_conv(input)
        # squeeze only the height dimension so a batch of size 1 keeps its batch axis
        activations = torch.squeeze(out, dim=2)
        out = self.layer1_process(out)

        out = F.pad(out, (5, 5), mode='constant', value=0)
        out = self.layer2(out)

        out = F.pad(out, (3, 3), mode='constant', value=0)
        out = self.layer3(out)

        # Flatten output of convolutional layers
        out = out.view(out.size()[0], -1)

        # run fully connected layers
        out = self.layer4(out)
        out = self.layer5(out)
        predictions = self.layer6(out)

        activations, act_index = torch.max(activations, dim=2)

        return predictions, activations, act_index

# define model for extracting motifs from first convolutional layer
# and determining importance of each filter on prediction
class motifCNN(nn.Module):
    def __init__(self, original_model):
        super(motifCNN, self).__init__()
        self.layer1_conv = nn.Sequential(*list(original_model.children())[0])
        self.layer1_process = nn.Sequential(*list(original_model.children())[1])
        self.layer2 = nn.Sequential(*list(original_model.children())[2])
        self.layer3 = nn.Sequential(*list(original_model.children())[3])

        self.layer4 = nn.Sequential(*list(original_model.children())[4])
        self.layer5 = nn.Sequential(*list(original_model.children())[5])
        self.layer6 = nn.Sequential(*list(original_model.children())[6])


    def forward(self, input, num_filters):
        # add dummy dimension to input (for num channels=1)
        input = torch.unsqueeze(input, 1)

        # Run convolutional layers
        input = F.pad(input, (9, 9), mode='constant', value=0) # padding - last dimension goes first
        out = self.layer1_conv(input)
        # squeeze only the height dimension so a batch of size 1 keeps its batch axis
        layer1_activations = torch.squeeze(out, dim=2)

        #do maxpooling and batch normalization for layer 1
        layer1_out = self.layer1_process(out)
        layer1_out = F.pad(layer1_out, (5, 5), mode='constant', value=0)

        #calculate average activation by filter for the whole batch
        filter_means_batch = layer1_activations.mean(0).mean(1)

        # run all other layers with 1 filter left out at a time
        batch_size = layer1_out.shape[0]
        predictions = torch.zeros(batch_size, num_filters, 81)


        for i in range(num_filters):
            #modify filter i of first layer output
            filter_input = layer1_out.clone()

            filter_input[:,i,:,:] = filter_input.new_full((batch_size, 1, 94), fill_value=filter_means_batch[i])

            out = self.layer2(filter_input)
            out = F.pad(out, (3, 3), mode='constant', value=0)
            out = self.layer3(out)

            # Flatten output of convolutional layers
            out = out.view(out.size()[0], -1)

            # run fully connected layers
            out = self.layer4(out)
            out = self.layer5(out)
            out = self.layer6(out)

            predictions[:,i,:] = out

        activations, act_index = torch.max(layer1_activations, dim=2)

        return predictions, layer1_activations, act_index


#define the model loss: summed over the batch, 1 - Pearson correlation between
#predicted and observed profiles (cosine similarity of the mean-centered vectors)
def pearson_loss(x,y):
    mx = torch.mean(x, dim=1, keepdim=True)
    my = torch.mean(y, dim=1, keepdim=True)
    xm, ym = x - mx, y - my

    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    loss = torch.sum(1-cos(xm,ym))
    return loss


def train_model(train_loader, test_loader, model, device, criterion, optimizer, num_epochs, output_directory):
    total_step = len(train_loader)

    #open files to log error
    train_error = open(output_directory + "training_error.txt", "a")
    test_error = open(output_directory + "test_error.txt", "a")

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss_valid = float('inf')
    best_epoch = 1

    for epoch in range(num_epochs):
        # (re)enter training mode every epoch: the validation pass below calls
        # model.eval(), which would otherwise disable dropout and freeze the
        # batch-norm statistics for all later epochs
        model.train()
        running_loss = 0.0
        for i, (seqs, labels) in enumerate(train_loader):
            seqs = seqs.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs, act, idx = model(seqs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i+1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

        #save training loss to file
        epoch_loss = running_loss / len(train_loader.dataset)
        print("%s, %s" % (epoch, epoch_loss), file=train_error)

        #calculate test loss for epoch
        test_loss = 0.0
        with torch.no_grad():
            model.eval()
            for i, (seqs, labels) in enumerate(test_loader):
                x = seqs.to(device)
                y = labels.to(device)
                outputs, act, idx = model(x)
                loss = criterion(outputs, y)
                test_loss += loss.item()

        test_loss = test_loss / len(test_loader.dataset)

        #save outputs for epoch
        print("%s, %s" % (epoch, test_loss), file=test_error)

        if test_loss < best_loss_valid:
            best_loss_valid = test_loss
            best_epoch = epoch
            best_model_wts = copy.deepcopy(model.state_dict())
            print('Saving the best model weights at Epoch [{}], Best Valid Loss: {:.4f}'
                  .format(epoch+1, best_loss_valid))


    train_error.close()
    test_error.close()

    model.load_state_dict(best_model_wts)
    return model, best_loss_valid


def test_model(test_loader, model, device):
    num_filters=model.layer1_conv[0].out_channels
    predictions = torch.zeros(0, 81)
    max_activations = torch.zeros(0, num_filters)
    act_index = torch.zeros(0, num_filters)

    with torch.no_grad():
        model.eval()
        for seqs, labels in test_loader:
            seqs = seqs.to(device)
            pred, act, idx = model(seqs)
            predictions = torch.cat((predictions, pred.type(torch.FloatTensor)), 0)
            max_activations = torch.cat((max_activations, act.type(torch.FloatTensor)), 0)
            act_index = torch.cat((act_index, idx.type(torch.FloatTensor)), 0)

    predictions = predictions.numpy()
    max_activations = max_activations.numpy()
    act_index = act_index.numpy()
    return predictions, max_activations, act_index



def get_motifs(data_loader, model, device):
    num_filters=model.layer1_conv[0].out_channels
    activations = torch.zeros(0, num_filters, 251)
    predictions = torch.zeros(0, num_filters, 81)
    with torch.no_grad():
        model.eval()
        for seqs, labels in data_loader:
            seqs = seqs.to(device)
            pred, act, idx = model(seqs, num_filters)

            activations = torch.cat((activations, act.type(torch.FloatTensor)), 0)
            predictions = torch.cat((predictions, pred.type(torch.FloatTensor)), 0)

    predictions = predictions.numpy()
    activations = activations.numpy()
    return activations, predictions
--------------------------------------------------------------------------------
/code/extract_motifs.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.utils.data
import matplotlib
import os
import sys
matplotlib.use('Agg')

import aitac
import plot_utils

import time
from sklearn.model_selection import train_test_split

# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hyper parameters
num_classes = 81
batch_size = 100
num_filters = 300

#create output figure directory
model_name = sys.argv[1]
output_file_path = "../outputs/" + model_name + "/motifs/"
directory = os.path.dirname(output_file_path)
if not os.path.exists(directory):
    os.makedirs(directory)


# Load all data
x = np.load(sys.argv[2])
x = x.astype(np.float32)
y = np.load(sys.argv[3])
peak_names = np.load(sys.argv[4])

# Data loader
dataset = torch.utils.data.TensorDataset(torch.from_numpy(x), torch.from_numpy(y))
data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

# load trained model (map_location lets a GPU-saved checkpoint load on a CPU-only machine)
model = aitac.ConvNet(num_classes, num_filters).to(device)
checkpoint = torch.load('../models/' + model_name + '.ckpt', map_location=device)
model.load_state_dict(checkpoint)

#copy trained model weights to motif extraction model
motif_model = aitac.motifCNN(model).to(device)
motif_model.load_state_dict(model.state_dict())

# run predictions with full model on all data
pred_full_model, max_activations, activation_idx = aitac.test_model(data_loader, model, device)
correlations = plot_utils.plot_cors(y, pred_full_model, output_file_path)


# find well predicted OCRs
idx = np.argwhere(np.asarray(correlations)>0.75).squeeze()

#get data subset for well predicted OCRs to run further test
x2 = x[idx, :, :]
y2 = y[idx, :]

dataset = torch.utils.data.TensorDataset(torch.from_numpy(x2), torch.from_numpy(y2))
data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

# non-modified results for well-predicted OCRs only
pred_full_model2 = pred_full_model[idx,:]
correlations2 = plot_utils.plot_cors(y2, pred_full_model2, output_file_path)


# get first layer activations and predictions with leave-one-filter-out
start = time.time()
activations, predictions = aitac.get_motifs(data_loader, motif_model, device)
print(time.time() - start)

filt_corr, filt_infl, ave_filt_infl = plot_utils.plot_filt_corr(predictions, y2, correlations2, output_file_path)

infl, infl_by_OCR = plot_utils.plot_filt_infl(pred_full_model2, predictions, output_file_path)

pwm, act_ind, nseqs, activated_OCRs, n_activated_OCRs, OCR_matrix = plot_utils.get_memes(activations, x2, y2, output_file_path)


#save predictions
np.save(output_file_path + "filter_predictions.npy", predictions)
np.save(output_file_path + "predictions.npy", pred_full_model)

#save correlations
np.save(output_file_path + "correlations.npy", correlations)
np.save(output_file_path + "correaltions_per_filter.npy", filt_corr)

#overall influence:
np.save(output_file_path + "influence.npy", ave_filt_infl)
np.save(output_file_path + "influence_by_OCR.npy", filt_infl)

#influence by cell type:
np.save(output_file_path + "filter_cellwise_influence.npy", infl)
np.save(output_file_path + "cellwise_influence_by_OCR.npy", infl_by_OCR)

#other metrics
np.savetxt(output_file_path + "nseqs_per_filters.txt", nseqs)
np.save(output_file_path + "mean_OCR_activation.npy", activated_OCRs)
np.save(output_file_path + "n_activated_OCRs.npy", n_activated_OCRs)
np.save(output_file_path + "OCR_matrix.npy", OCR_matrix)
--------------------------------------------------------------------------------
/code/finetune_filter_subset.py:
--------------------------------------------------------------------------------
import numpy as np
from numpy import random
from sklearn.model_selection import train_test_split
import torch.utils.data
import matplotlib
import os
import sys
matplotlib.use('Agg')

import aitac
import plot_utils


# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hyper parameters
num_epochs = 10
num_classes = 81
batch_size = 100
learning_rate = 0.0001
num_filters = 99

#create output figure directory
model_name = sys.argv[1]
output_file_path = "../outputs/" + model_name + "/training/"
directory = os.path.dirname(output_file_path)
if not os.path.exists(directory):
    print("Creating directory %s" % output_file_path)
    os.makedirs(directory)
else:
    print("Directory %s exists" % output_file_path)


# Load all data
x = np.load(sys.argv[2])
x = x.astype(np.float32)
y = np.load(sys.argv[3])
y = y.astype(np.float32)
peak_names = np.load(sys.argv[4])

#file containing original model weights
original_model = sys.argv[5]

#load names of test set from original model
test_peaks = np.load("../outputs/" + original_model + "/training/test_OCR_names.npy")
idx = np.in1d(peak_names, test_peaks)

# split the data into training and test sets
eval_data, eval_labels, eval_names = x[idx, :, :], y[idx, :], peak_names[idx]
train_data, train_labels, train_names = x[~idx, :, :], y[~idx, :], peak_names[~idx]

# split the data into training and validation sets
train_data, valid_data, train_labels, valid_labels, train_names, valid_names = train_test_split(train_data, train_labels, train_names, test_size=0.1, random_state=40)

# Data loader
train_dataset = torch.utils.data.TensorDataset(torch.from_numpy(train_data), torch.from_numpy(train_labels))
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)

valid_dataset = torch.utils.data.TensorDataset(torch.from_numpy(valid_data), torch.from_numpy(valid_labels))
valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)

eval_dataset = torch.utils.data.TensorDataset(torch.from_numpy(eval_data), torch.from_numpy(eval_labels))
eval_loader = torch.utils.data.DataLoader(dataset=eval_dataset, batch_size=batch_size, shuffle=False)


# create model
model = aitac.ConvNet(num_classes, num_filters).to(device)

# weights from model with 300 filters
checkpoint = torch.load("../models/" + original_model + ".ckpt", map_location=device)

#indices of filters in original model (np.loadtxt returns floats, but tensor indices must be integers)
filters = torch.from_numpy(np.loadtxt('../data/filter_set99_index.txt').astype(np.int64))

#start from the new (99-filter) model's state dict and overwrite its entries
checkpoint2 = model.state_dict()

#copy original model weights into new model, matching tensors by name rather than position
for layer_name, layer_weights in checkpoint.items():

    # first-layer conv and batch-norm tensors: keep only the selected filters
    # (0-dim buffers such as num_batches_tracked are copied unchanged)
    if layer_name.startswith(('layer1_conv', 'layer1_process')) and layer_weights.dim() > 0:
        checkpoint2[layer_name] = layer_weights[filters, ...]

    # second-layer conv weights: keep only the input channels of the selected filters
    elif layer_name == 'layer2.0.weight':
        checkpoint2[layer_name] = layer_weights[:, filters, ...]

    #for remainder of layers take all weights
    else:
        checkpoint2[layer_name] = layer_weights

model.load_state_dict(checkpoint2)

#freeze first layer
def freeze_layer(layer):
    for param in layer.parameters():
        param.requires_grad = False

freeze_layer(model.layer1_conv)


# Loss and optimizer
criterion = aitac.pearson_loss
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)

# train model
model, best_loss_valid = aitac.train_model(train_loader, valid_loader, model, device, criterion, optimizer, num_epochs, output_file_path)

# save the model checkpoint
torch.save(model.state_dict(), '../models/model' + model_name + '.ckpt')

#save the whole model
torch.save(model, '../models/model' + model_name + '.pth')

# Predict on test set
predictions, max_activations, max_act_index = aitac.test_model(eval_loader, model, device)


#-------------------------------------------#
#                Create Plots               #
#-------------------------------------------#

# plot the correlations histogram
# returns correlation measurement for every prediction-label pair
print("Creating plots...")

#plot_utils.plot_training_loss(training_loss, output_file_path)

correlations = plot_utils.plot_cors(eval_labels, predictions, output_file_path)

plot_utils.plot_corr_variance(eval_labels, correlations, output_file_path)

quantile_indx = plot_utils.plot_piechart(correlations, eval_labels, output_file_path)


#-------------------------------------------#
#                 Save Files                #
#-------------------------------------------#

#save predictions
np.save(output_file_path + "predictions.npy", predictions)

#save correlations
np.save(output_file_path + "correlations.npy", correlations)

#save max first layer activations
np.save(output_file_path + "max_activations.npy", max_activations)
np.save(output_file_path + "max_activation_index.npy", max_act_index)

#save test data set
np.save(output_file_path + "test_OCR_names.npy", eval_names)
--------------------------------------------------------------------------------
/code/plot_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import math
import random
import matplotlib
from matplotlib.colors import LogNorm
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns


def plot_training_loss(training_loss, output_file_path):
    #plot line plot of loss function per epoch
    plt.figure(figsize=(16, 9))
    plt.plot(np.arange(len(training_loss)), training_loss)
    plt.xlabel('Epoch', fontsize=18)
    plt.ylabel('Loss', fontsize=18)
    plt.savefig(output_file_path + "training_loss.svg")
    plt.close()


def plot_cors(obs, pred, output_file_path):
    correlations = []
    vars = []
    for i in range(len(pred)):
        var = np.var(obs[i, :])
        vars.append(var)
        x = np.corrcoef(pred[i, :], obs[i, :])[0, 1]
        correlations.append(x)

    # variance-weighted mean correlation, skipping OCRs whose correlation is NaN
    # (e.g. constant predictions or observations), which would otherwise make the result NaN
    cors_arr, vars_arr = np.asarray(correlations), np.asarray(vars)
    valid = ~np.isnan(cors_arr)
    weighted_cor = np.dot(cors_arr[valid], vars_arr[valid]) / np.sum(vars_arr[valid])
    print('weighted_cor is {}'.format(weighted_cor))

    nan_cors = [value for value in correlations if math.isnan(value)]
    print("number of NaN values: %d" % len(nan_cors))
    correlations = [value for value in correlations if not math.isnan(value)]


    plt.clf()
    plt.hist(correlations, bins=30)
    plt.axvline(np.mean(correlations), color='r', linestyle='dashed', linewidth=2)
    plt.axvline(0, color='k', linestyle='solid', linewidth=2)
    try:
        plt.title("histogram of correlation. Avg cor = %f" % np.mean(correlations))
    except Exception as e:
        print("could not set the title for graph")
        print(e)
    plt.ylabel("Frequency")
    plt.xlabel("correlation")
    plt.savefig(output_file_path + "basset_cor_hist.svg")
    plt.close()

    return correlations


def plot_piechart(correlations, eval_labels, output_file_path):
    ind_collection = []
    Q0_idx = []
    Q1_idx = []
    Q2_idx = []
    Q3_idx = []
    Q4_idx = []
    Q5_idx = []
    ind_collection.append(Q0_idx)
    ind_collection.append(Q1_idx)
    ind_collection.append(Q2_idx)
    ind_collection.append(Q3_idx)
    ind_collection.append(Q4_idx)
    ind_collection.append(Q5_idx)

    for i, x in enumerate(correlations):
        if x > 0.75:
            Q1_idx.append(i)
            if x > 0.9:
                Q0_idx.append(i)
        elif x > 0.5 and x <= 0.75:
            Q2_idx.append(i)
        elif x > 0.25 and x <= 0.5:
            Q3_idx.append(i)
        elif x > 0 and x <= 0.25:
            Q4_idx.append(i)
        elif x < 0:
            Q5_idx.append(i)

    # pie chart of correlations distribution
    pie_labels = "cor>0.75", "0.50.75", "0.5 activation_threshold[i])

            #save ground truth peak heights of OCRs activated by each filter
            if indices[0].shape[0]>0:
                act_OCRs_tmp.append(y[j, :])
                OCR_matrix[i, j] = 1

            for start in indices[0]:
                activation_indices.append(start)
                end = start+19
                act_seqs_list.append(sequences[j,:,start:end])

        #convert act_seqs from list to array
        if act_seqs_list:
            act_seqs = np.stack(act_seqs_list)
            pwm_tmp = np.sum(act_seqs, axis=0)
            pfm_tmp=pwm_tmp
            total = np.sum(pwm_tmp, axis=0)
            pwm_tmp = np.nan_to_num(pwm_tmp/total)

            #permute pwm from A, T, G, C order to A, C, G, T order
            order = [0, 3, 2, 1]
            pwm[i,:,:] = pwm_tmp[order, :]
            pfm[i,:,:] = pfm_tmp[order, :]

            #store total number of sequences that activated that filter
total_seq[i] = len(act_seqs_list) 442 | 443 | #save mean OCR activation 444 | act_OCRs_tmp = np.stack(act_OCRs_tmp) 445 | activated_OCRs[i, :] = np.mean(act_OCRs_tmp, axis=0) 446 | 447 | #save the number of activated OCRs 448 | n_activated_OCRs[i] = act_OCRs_tmp.shape[0] 449 | 450 | 451 | activated_OCRs = np.stack(activated_OCRs) 452 | 453 | #write motifs to meme format 454 | #PWM file: 455 | meme_file = open(output_file_path + "filter_motifs_pwm.meme", 'w') 456 | meme_file.write("MEME version 4 \n") 457 | 458 | #PFM file: 459 | meme_file_pfm = open(output_file_path + "filter_motifs_pfm.meme", 'w') 460 | meme_file_pfm.write("MEME version 4 \n") 461 | 462 | for i in range(0, 300): 463 | if np.sum(pwm[i,:,:]) >0: 464 | meme_file.write("\n") 465 | meme_file.write("MOTIF filter%s \n" % i) 466 | meme_file.write("letter-probability matrix: alength= 4 w= %d \n" % np.count_nonzero(np.sum(pwm[i,:,:], axis=0))) 467 | 468 | meme_file_pfm.write("\n") 469 | meme_file_pfm.write("MOTIF filter%s \n" % i) 470 | meme_file_pfm.write("letter-probability matrix: alength= 4 w= %d \n" % np.count_nonzero(np.sum(pwm[i,:,:], axis=0))) 471 | 472 | for j in range(0, 19): 473 | if np.sum(pwm[i,:,j]) > 0: 474 | meme_file.write(str(pwm[i,0,j]) + "\t" + str(pwm[i,1,j]) + "\t" + str(pwm[i,2,j]) + "\t" + str(pwm[i,3,j]) + "\n") 475 | meme_file_pfm.write(str(pfm[i,0,j]) + "\t" + str(pfm[i,1,j]) + "\t" + str(pfm[i,2,j]) + "\t" + str(pfm[i,3,j]) + "\n") 476 | 477 | meme_file.close() 478 | meme_file_pfm.close() 479 | 480 | #plot indices of first position in sequence that activates the filters 481 | activation_indices_array = np.stack(activation_indices) 482 | 483 | plt.clf() 484 | plt.hist(activation_indices_array.flatten(), bins=260) 485 | plt.title("histogram of position indices.") 486 | plt.ylabel("Frequency") 487 | plt.xlabel("Position") 488 | plt.savefig(output_file_path + "position_hist.svg") 489 | plt.close() 490 | 491 | #plot total sequences that activated each filter 492 | total_seq_array = 
np.stack(total_seq) 493 | 494 | plt.clf() 495 | plt.bar(np.arange(300), total_seq_array) 496 | plt.title("Number of sequences activating each filter") 497 | plt.ylabel("N sequences") 498 | plt.xlabel("Filter") 499 | plt.savefig(output_file_path + "nseqs_bar_graph.svg") 500 | plt.close() 501 | 502 | return pwm, activation_indices_array, total_seq_array, activated_OCRs, n_activated_OCRs, OCR_matrix 503 | 504 | 505 | #get cell-wise correlations 506 | def plot_cell_cors(obs, pred, cell_labels, output_file_path): 507 | correlations = [] 508 | for i in range(pred.shape[1]): 509 | x = np.corrcoef(pred[:, i], obs[:, i])[0, 1] 510 | correlations.append(x) 511 | 512 | #plot cell-wise correlation histogram 513 | plt.clf() 514 | plt.hist(correlations, bins=30) 515 | plt.axvline(np.mean(correlations), color='k', linestyle='dashed', linewidth=2) 516 | try: 517 | plt.title("histogram of correlation. Avg cor = {:.2f}".format(np.mean(correlations))) 518 | except Exception as e: 519 | print("could not set the title for graph") 520 | print(e) 521 | plt.ylabel("Frequency") 522 | plt.xlabel("Correlation") 523 | plt.savefig(output_file_path + "basset_cell_wise_cor_hist.svg") 524 | plt.close() 525 | 526 | #plot cell-wise correlations by cell type 527 | plt.clf() 528 | plt.bar(np.arange(len(correlations)), correlations) # one bar per cell type (output column) 529 | plt.title("Correlations by Cell Type") 530 | plt.ylabel("Correlation") 531 | plt.xlabel("Cell Type") 532 | plt.xticks(np.arange(len(cell_labels)), cell_labels, rotation='vertical', fontsize=3.5) 533 | plt.savefig(output_file_path + "cellwise_cor_bargraph.svg") 534 | plt.close() 535 | 536 | return correlations 537 | 538 | #convert mouse model predictions to human cell predictions 539 | def mouse2human(mouse_predictions, mouse_cell_types, map, method='average'): 540 | 541 | human_cells = np.unique(map[:,1]) 542 | 543 | human_predictions = np.zeros((mouse_predictions.shape[0], human_cells.shape[0])) 544 | 545 | for i, celltype in enumerate(human_cells): 546 |
matches = map[np.where(map[:,1] == celltype)][:,0] 547 | idx = np.in1d(mouse_cell_types[:,1], matches).nonzero() 548 | human_predictions[:, i] = np.mean(mouse_predictions[:, idx], axis=2).squeeze() 549 | 550 | return human_predictions, human_cells -------------------------------------------------------------------------------- /code/second_layer_motifs/aitac_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.data 5 | import random 6 | import numpy as np 7 | import copy 8 | 9 | # Convolutional neural network 10 | class ConvNet(nn.Module): 11 | def __init__(self, num_classes): 12 | super(ConvNet, self).__init__() 13 | 14 | # for layer one, separate convolution and relu step from maxpool and batch normalization 15 | # to extract convolutional filters 16 | self.layer1_conv = nn.Sequential( 17 | nn.Conv2d(in_channels=1, 18 | out_channels=300, 19 | kernel_size=(4, 19), 20 | stride=1, 21 | padding=0), # padding is done in forward method along 1 dimension only 22 | nn.ReLU()) 23 | 24 | self.layer1_process = nn.Sequential( 25 | nn.MaxPool2d(kernel_size=(1,3), stride=(1,3), padding=(0,1)), 26 | nn.BatchNorm2d(300)) 27 | 28 | self.layer2 = nn.Sequential( 29 | nn.Conv2d(in_channels=300, 30 | out_channels=200, 31 | kernel_size=(1, 11), 32 | stride=1, 33 | padding=0), # padding is done in forward method along 1 dimension only 34 | nn.ReLU()) 35 | 36 | self.layer2_process = nn.Sequential( 37 | nn.MaxPool2d(kernel_size=(1,4), stride=(1,4), padding=(0,1)), 38 | nn.BatchNorm2d(200)) 39 | 40 | self.layer3 = nn.Sequential( 41 | nn.Conv2d(in_channels=200, 42 | out_channels=200, 43 | kernel_size=(1, 7), 44 | stride=1, 45 | padding=0), # padding is done in forward method along 1 dimension only 46 | nn.ReLU(), 47 | nn.MaxPool2d(kernel_size=(1, 4), stride=(1,4), padding=(0,1)), 48 | nn.BatchNorm2d(200)) 49 | 50 | self.layer4 = nn.Sequential( 51 | 
nn.Linear(in_features=1000, 52 | out_features=1000), 53 | nn.ReLU(), 54 | nn.Dropout(p=0.03)) 55 | 56 | self.layer5 = nn.Sequential( 57 | nn.Linear(in_features=1000, 58 | out_features=1000), 59 | nn.ReLU(), 60 | nn.Dropout(p=0.03)) 61 | 62 | self.layer6 = nn.Sequential( 63 | nn.Linear(in_features=1000, 64 | out_features=num_classes))#, 65 | #nn.Sigmoid()) 66 | 67 | 68 | def forward(self, input): 69 | # run all layers on input data 70 | # add dummy dimension to input (for num channels=1) 71 | input = torch.unsqueeze(input, 1) 72 | 73 | # Run convolutional layers 74 | input = F.pad(input, (9, 9), mode='constant', value=0) # padding - last dimension goes first 75 | out = self.layer1_conv(input) 76 | activations = torch.squeeze(out) 77 | out = self.layer1_process(out) 78 | 79 | out = F.pad(out, (5, 5), mode='constant', value=0) 80 | out = self.layer2(out) 81 | activations2 = torch.squeeze(out) 82 | out = self.layer2_process(out) 83 | 84 | out = F.pad(out, (3, 3), mode='constant', value=0) 85 | out = self.layer3(out) 86 | 87 | # Flatten output of convolutional layers 88 | out = out.view(out.size()[0], -1) 89 | 90 | # run fully connected layers 91 | out = self.layer4(out) 92 | out = self.layer5(out) 93 | predictions = self.layer6(out) 94 | 95 | activations, act_index = torch.max(activations, dim=2) 96 | 97 | return predictions, activations, activations2, act_index 98 | 99 | # define model for extracting motifs from first convolutional layer 100 | # and determining importance of each filter on prediction 101 | class motifCNN(nn.Module): 102 | def __init__(self, original_model): 103 | super(motifCNN, self).__init__() 104 | self.layer1_conv = nn.Sequential(*list(original_model.children())[0]) 105 | self.layer1_process = nn.Sequential(*list(original_model.children())[1]) 106 | self.layer2 = nn.Sequential(*list(original_model.children())[2]) 107 | self.layer2_process = nn.Sequential(*list(original_model.children())[3]) 108 | self.layer3 = 
nn.Sequential(*list(original_model.children())[4]) 109 | 110 | self.layer4 = nn.Sequential(*list(original_model.children())[5]) 111 | self.layer5 = nn.Sequential(*list(original_model.children())[6]) 112 | self.layer6 = nn.Sequential(*list(original_model.children())[7]) 113 | 114 | 115 | def forward(self, input): 116 | # add dummy dimension to input (for num channels=1) 117 | input = torch.unsqueeze(input, 1) 118 | 119 | # Run convolutional layers 120 | input = F.pad(input, (9, 9), mode='constant', value=0) # padding - last dimension goes first 121 | out= self.layer1_conv(input) 122 | layer1_activations = torch.squeeze(out) 123 | 124 | #do maxpooling and batch normalization for layer 1 125 | layer1_out = self.layer1_process(out) 126 | layer1_out = F.pad(layer1_out, (5, 5), mode='constant', value=0) 127 | 128 | #calculate average activation by filter for the whole batch 129 | filter_means_batch = layer1_activations.mean(0).mean(1) 130 | 131 | # run all other layers with 1 filter left out at a time 132 | batch_size = layer1_out.shape[0] 133 | predictions = torch.zeros(batch_size, 300, 81) 134 | 135 | #filter_matches = np.load("../outputs/motifs2/run2_motif_matches.npy") 136 | 137 | for i in range(300): 138 | #modify filter i of first layer output 139 | filter_input = layer1_out.clone() 140 | 141 | filter_input[:,i,:,:] = filter_input.new_full((batch_size, 1, 94), fill_value=filter_means_batch[i]) 142 | 143 | out = self.layer2_process(self.layer2(filter_input)) # conv, then max-pool + batch norm as in ConvNet 144 | out = F.pad(out, (3, 3), mode='constant', value=0) 145 | out = self.layer3(out) 146 | 147 | # Flatten output of convolutional layers 148 | out = out.view(out.size()[0], -1) 149 | # run fully connected layers 150 | out = self.layer4(out) 151 | out = self.layer5(out) 152 | out = self.layer6(out) 153 | 154 | predictions[:,i,:] = out 155 | 156 | activations, act_index = torch.max(layer1_activations, dim=2) 157 | 158 | return predictions, layer1_activations, act_index 159 | 160 | 161 | #define the model loss 162 | def
pearson_loss(x,y): 163 | mx = torch.mean(x, dim=1, keepdim=True) 164 | my = torch.mean(y, dim=1, keepdim=True) 165 | xm, ym = x - mx, y - my 166 | 167 | cos = nn.CosineSimilarity(dim=1, eps=1e-6) 168 | loss = torch.sum(1-cos(xm,ym)) 169 | return loss 170 | 171 | 172 | def train_model(train_loader, test_loader, model, device, criterion, alpha, optimizer, num_epochs, output_directory): 173 | total_step = len(train_loader) 174 | model.train() 175 | 176 | #open files to log error 177 | train_error = open(output_directory + "training_error.txt", "a") 178 | test_error = open(output_directory + "test_error.txt", "a") 179 | 180 | best_model_wts = copy.deepcopy(model.state_dict()) 181 | best_loss_valid = float('inf') 182 | best_epoch = 1 183 | 184 | for epoch in range(num_epochs): 185 | running_loss = 0.0 186 | for i, (seqs, labels) in enumerate(train_loader): 187 | seqs = seqs.to(device) 188 | labels = labels.to(device) 189 | 190 | # Forward pass 191 | outputs, act, act2, idx = model(seqs) 192 | loss = criterion(outputs, labels, alpha) # change input to 193 | running_loss += loss.item() 194 | 195 | # Backward and optimize 196 | optimizer.zero_grad() 197 | loss.backward() 198 | optimizer.step() 199 | 200 | #if (i+1) % 100 == 0: 201 | # print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 202 | # .format(epoch+1, num_epochs, i+1, total_step, loss.item())) 203 | 204 | #save training loss to file 205 | epoch_loss = running_loss / len(train_loader.dataset) 206 | print("%s, %s" % (epoch, epoch_loss), file=train_error) 207 | 208 | #calculate test loss for epoch 209 | test_loss = 0.0 210 | with torch.no_grad(): 211 | model.eval() 212 | for i, (seqs, labels) in enumerate(test_loader): 213 | x = seqs.to(device) 214 | y = labels.to(device) 215 | outputs, act, act2, idx = model(x) # ConvNet returns four values 216 | loss = criterion(outputs, y, alpha) 217 | test_loss += loss.item() 218 | 219 | test_loss = test_loss / len(test_loader.dataset) 220 | model.train() # switch dropout and batch norm back to training mode for the next epoch 221 | #save outputs for epoch 222 | print("%s, %s" % (epoch, test_loss),
file=test_error) 223 | 224 | if test_loss < best_loss_valid: 225 | best_loss_valid = test_loss 226 | best_epoch = epoch 227 | best_model_wts = copy.deepcopy(model.state_dict()) 228 | print ('Saving the best model weights at Epoch [{}], Best Valid Loss: {:.4f}' 229 | .format(epoch+1, best_loss_valid)) 230 | 231 | 232 | train_error.close() 233 | test_error.close() 234 | 235 | #model.load_state_dict(best_model_wts) 236 | return model, best_loss_valid 237 | 238 | 239 | def test_model(test_loader, model, device): 240 | predictions = torch.zeros(0, 81) 241 | max_activations = torch.zeros(0, 300) 242 | activations2 = torch.zeros(0,200,84) 243 | act_index = torch.zeros(0, 300) 244 | 245 | with torch.no_grad(): 246 | model.eval() 247 | for seqs, labels in test_loader: 248 | seqs = seqs.to(device) 249 | pred, act, act2, idx = model(seqs) 250 | predictions = torch.cat((predictions, pred.type(torch.FloatTensor)), 0) 251 | max_activations = torch.cat((max_activations, act.type(torch.FloatTensor)), 0) 252 | activations2 = torch.cat((activations2, act2.type(torch.FloatTensor)), 0) 253 | act_index = torch.cat((act_index, idx.type(torch.FloatTensor)), 0) 254 | 255 | predictions = predictions.numpy() 256 | max_activations = max_activations.numpy() 257 | activations2 = activations2.numpy() 258 | act_index = act_index.numpy() 259 | return predictions, max_activations, activations2, act_index 260 | 261 | 262 | 263 | def get_motifs(data_loader, model, device): 264 | activations = torch.zeros(0, 300, 251) 265 | predictions = torch.zeros(0, 300, 81) 266 | with torch.no_grad(): 267 | model.eval() 268 | for seqs, labels in data_loader: 269 | seqs = seqs.to(device) 270 | pred, act, _ = model(seqs) # motifCNN returns predictions, layer-1 activations, act_index 271 | 272 | activations = torch.cat((activations, act.type(torch.FloatTensor)), 0) 273 | predictions = torch.cat((predictions, pred.type(torch.FloatTensor)), 0) 274 | 275 | predictions = predictions.numpy() 276 | activations = activations.numpy() 277 | return activations, predictions 278 |
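The `in_features=1000` of `layer4`, the `(0, 200, 84)` buffer in `test_model` and the 94-wide fill in `motifCNN.forward` all follow from the padding, kernel and pooling sizes in `aitac_v2.ConvNet`. A minimal sketch of that arithmetic, assuming 251 bp inputs (the width of the `torch.zeros(0, 300, 251)` buffer in `get_motifs`) and PyTorch's default floor rule for conv/pool output lengths; `out_len` and `aitac_v2_widths` are hypothetical helpers, not part of the repository:

```python
# Sketch: trace the sequence-axis width through aitac_v2.ConvNet.
# Assumes 251 bp inputs; helper names are hypothetical, not part of the repo.


def out_len(length, kernel, stride=1, pad=0):
    """Output length of a 1-D conv/pool window (PyTorch floor rule, dilation 1)."""
    return (length + 2 * pad - kernel) // stride + 1


def aitac_v2_widths(seq_len=251):
    widths = {}
    x = out_len(seq_len + 18, 19)        # F.pad(9, 9), then layer1_conv (kernel 19)
    widths["layer1_conv"] = x
    x = out_len(x, 3, stride=3, pad=1)   # layer1_process max-pool
    widths["layer1_pool"] = x
    x = out_len(x + 10, 11)              # F.pad(5, 5), then layer2 conv (kernel 11)
    widths["layer2_conv"] = x
    x = out_len(x, 4, stride=4, pad=1)   # layer2_process max-pool
    widths["layer2_pool"] = x
    x = out_len(x + 6, 7)                # F.pad(3, 3), then layer3 conv (kernel 7)
    x = out_len(x, 4, stride=4, pad=1)   # layer3 max-pool
    widths["layer3_pool"] = x
    widths["flattened"] = 200 * x        # 200 channels times the remaining width
    return widths


widths = aitac_v2_widths()
print(widths)
# {'layer1_conv': 251, 'layer1_pool': 84, 'layer2_conv': 84, 'layer2_pool': 21, 'layer3_pool': 5, 'flattened': 1000}

# If layer2_process were skipped, layer3 would see the 84-wide layer2 output instead:
skipped = 200 * out_len(out_len(widths["layer2_conv"] + 6, 7), 4, stride=4, pad=1)
print(skipped)  # 4200
```

The last computation shows why the layer-2 max-pool and batch norm have to run before `layer3`: without them the flattened tensor would have 4200 features instead of the 1000 that `layer4` expects.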
-------------------------------------------------------------------------------- /code/second_layer_motifs/extract_layer2_motifs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy import random 3 | from sklearn.model_selection import train_test_split 4 | import torch.utils.data 5 | import matplotlib 6 | import os 7 | import sys 8 | matplotlib.use('Agg') 9 | 10 | import aitac_v2 11 | import plot_utils 12 | 13 | 14 | # Device configuration 15 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 16 | 17 | # Hyper parameters 18 | num_epochs = 10 19 | num_classes = 81 20 | batch_size = 100 21 | learning_rate = 0.001 22 | 23 | #pre-trained model to use 24 | model_name = sys.argv[1] 25 | 26 | #create output figure directory 27 | output_file_path = "../../outputs/" + model_name + "/layer2_motifs/" 28 | directory = os.path.dirname(output_file_path) 29 | if not os.path.exists(directory): 30 | print("Creating directory %s" % output_file_path) 31 | os.makedirs(directory) 32 | else: 33 | print("Directory %s exists" % output_file_path) 34 | 35 | 36 | # Load all data 37 | x = np.load(sys.argv[2]) 38 | x = x.astype(np.float32) 39 | y = np.load(sys.argv[3]) 40 | y = y.astype(np.float32) 41 | peak_names = np.load(sys.argv[4]) 42 | 43 | #load names of test set from original model 44 | test_peaks = np.load("../../outputs/" + model_name + "/training/test_OCR_names.npy") 45 | idx = np.in1d(peak_names, test_peaks) 46 | 47 | # split the data into training and test sets 48 | eval_data, eval_labels, eval_names = x[idx, :, :], y[idx, :], peak_names[idx] 49 | train_data, train_labels, train_names = x[~idx, :, :], y[~idx, :], peak_names[~idx] 50 | 51 | # Data loader 52 | eval_dataset = torch.utils.data.TensorDataset(torch.from_numpy(eval_data), torch.from_numpy(eval_labels)) 53 | eval_loader = torch.utils.data.DataLoader(dataset=eval_dataset, batch_size=batch_size, shuffle=False) 54 | 55 | #load 
trained model weights 56 | checkpoint = torch.load("../../models/" + model_name + ".ckpt") 57 | 58 | # initialize model 59 | model = aitac_v2.ConvNet(num_classes).to(device) 60 | checkpoint2 = model.state_dict() 61 | 62 | #copy original model weights into new model 63 | for i, (layer_name, layer_weights) in enumerate(checkpoint.items()): 64 | new_name = list(checkpoint2.keys())[i] 65 | checkpoint2[new_name] = layer_weights 66 | 67 | #load weights into new model 68 | model.load_state_dict(checkpoint2) 69 | 70 | #get layer 2 motifs 71 | predictions, max_act_layer2, activations_layer2, act_index_layer2 = aitac_v2.test_model(eval_loader, model, device) 72 | 73 | correlations = plot_utils.plot_cors(eval_labels, predictions, output_file_path) 74 | 75 | #get PWMs of second layer motifs 76 | plot_utils.get_memes2(activations_layer2, eval_data, eval_labels, output_file_path) 77 | 78 | #save files 79 | np.save(output_file_path + "second_layer_maximum_activations.npy", max_act_layer2) 80 | np.save(output_file_path + "second_layer_maxact_index.npy", act_index_layer2) 81 | 82 | -------------------------------------------------------------------------------- /code/second_layer_motifs/plot_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import random 4 | import matplotlib 5 | from matplotlib.colors import LogNorm 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | import seaborn as sns 9 | 10 | 11 | def plot_training_loss(training_loss, output_file_path): 12 | #plot line plot of loss function per epoch 13 | plt.figure(figsize=(16, 9)) 14 | plt.plot(np.arange(len(training_loss)), training_loss) 15 | plt.xlabel('Epoch', fontsize=18) 16 | plt.ylabel('Loss', fontsize=18) 17 | plt.savefig(output_file_path + "training_loss.svg") 18 | plt.close() 19 | 20 | 21 | def plot_cors(obs, pred, output_file_path): 22 | correlations = [] 23 | vars = [] 24 | for i in range(len(pred)): 25 | var = 
np.var(obs[i, :]) 26 | vars.append(var) 27 | x = np.corrcoef(pred[i, :], obs[i, :])[0, 1] 28 | correlations.append(x) 29 | 30 | weighted_cor = np.dot(correlations, vars) / np.sum(vars) 31 | print('weighted_cor is {}'.format(weighted_cor)) 32 | 33 | nan_cors = [value for value in correlations if math.isnan(value)] 34 | print("number of NaN values: %d" % len(nan_cors)) 35 | correlations = [value for value in correlations if not math.isnan(value)] 36 | 37 | 38 | plt.clf() 39 | plt.hist(correlations, bins=30) 40 | plt.axvline(np.mean(correlations), color='r', linestyle='dashed', linewidth=2) 41 | plt.axvline(0, color='k', linestyle='solid', linewidth=2) 42 | try: 43 | plt.title("histogram of correlation. Avg cor = {%f}" % np.mean(correlations)) 44 | except Exception as e: 45 | print("could not set the title for graph") 46 | print(e) 47 | plt.ylabel("Frequency") 48 | plt.xlabel("correlation") 49 | plt.savefig(output_file_path + "basset_cor_hist.svg") 50 | plt.close() 51 | 52 | return correlations 53 | 54 | 55 | def plot_piechart(correlations, eval_labels, output_file_path): 56 | ind_collection = [] 57 | Q0_idx = [] 58 | Q1_idx = [] 59 | Q2_idx = [] 60 | Q3_idx = [] 61 | Q4_idx = [] 62 | Q5_idx = [] 63 | ind_collection.append(Q0_idx) 64 | ind_collection.append(Q1_idx) 65 | ind_collection.append(Q2_idx) 66 | ind_collection.append(Q3_idx) 67 | ind_collection.append(Q4_idx) 68 | ind_collection.append(Q5_idx) 69 | 70 | for i, x in enumerate(correlations): 71 | if x > 0.75: 72 | Q1_idx.append(i) 73 | if x > 0.9: 74 | Q0_idx.append(i) 75 | elif x > 0.5 and x <= 0.75: 76 | Q2_idx.append(i) 77 | elif x > 0.25 and x <= 0.5: 78 | Q3_idx.append(i) 79 | elif x > 0 and x <= 0.25: 80 | Q4_idx.append(i) 81 | elif x < 0: 82 | Q5_idx.append(i) 83 | 84 | # pie chart of correlations distribution 85 | pie_labels = "cor>0.75", "0.50.75", "0.5 activation_threshold[i]) 417 | 418 | #save ground truth peak heights of OCRs activated by each filter 419 | if indices[0].shape[0]>0: 420 | 
act_OCRs_tmp.append(y[j, :]) 421 | OCR_matrix[i, j] = 1 422 | 423 | for start in indices[0]: 424 | activation_indices.append(start) 425 | end = start+19 426 | act_seqs_list.append(sequences[j,:,start:end])#*activations[j, i, start]) 427 | 428 | #convert act_seqs from list to array 429 | if act_seqs_list: 430 | act_seqs = np.stack(act_seqs_list) 431 | pwm_tmp = np.sum(act_seqs, axis=0) 432 | pfm_tmp=pwm_tmp 433 | total = np.sum(pwm_tmp, axis=0) 434 | pwm_tmp = np.nan_to_num(pwm_tmp/total) 435 | 436 | #permute pwm from A, T, G, C order to A, C, G, T order 437 | order = [0, 3, 2, 1] 438 | pwm[i,:,:] = pwm_tmp[order, :] 439 | pfm[i,:,:] = pfm_tmp[order, :] 440 | 441 | #store total number of sequences that activated that filter 442 | total_seq[i] = len(act_seqs_list) 443 | 444 | #save mean OCR activation 445 | act_OCRs_tmp = np.stack(act_OCRs_tmp) 446 | activated_OCRs[i, :] = np.mean(act_OCRs_tmp, axis=0) 447 | 448 | #save the number of activated OCRs 449 | n_activated_OCRs[i] = act_OCRs_tmp.shape[0] 450 | 451 | 452 | activated_OCRs = np.stack(activated_OCRs) 453 | 454 | #write motifs to meme format 455 | #PWM file: 456 | meme_file = open(output_file_path + "filter_motifs_pwm.meme", 'w') 457 | meme_file.write("MEME version 4 \n") 458 | 459 | #PFM file: 460 | meme_file_pfm = open(output_file_path + "filter_motifs_pfm.meme", 'w') 461 | meme_file_pfm.write("MEME version 4 \n") 462 | 463 | for i in range(0, 300): 464 | if np.sum(pwm[i,:,:]) >0: 465 | meme_file.write("\n") 466 | meme_file.write("MOTIF filter%s \n" % i) 467 | meme_file.write("letter-probability matrix: alength= 4 w= %d \n" % np.count_nonzero(np.sum(pwm[i,:,:], axis=0))) 468 | 469 | meme_file_pfm.write("\n") 470 | meme_file_pfm.write("MOTIF filter%s \n" % i) 471 | meme_file_pfm.write("letter-probability matrix: alength= 4 w= %d \n" % np.count_nonzero(np.sum(pwm[i,:,:], axis=0))) 472 | 473 | for j in range(0, 19): 474 | if np.sum(pwm[i,:,j]) > 0: 475 | meme_file.write(str(pwm[i,0,j]) + "\t" + str(pwm[i,1,j]) + 
"\t" + str(pwm[i,2,j]) + "\t" + str(pwm[i,3,j]) + "\n") 476 | meme_file_pfm.write(str(pfm[i,0,j]) + "\t" + str(pfm[i,1,j]) + "\t" + str(pfm[i,2,j]) + "\t" + str(pfm[i,3,j]) + "\n") 477 | 478 | meme_file.close() 479 | meme_file_pfm.close() 480 | 481 | #plot indices of first position in sequence that activates the filters 482 | activation_indices_array = np.stack(activation_indices) 483 | 484 | plt.clf() 485 | plt.hist(activation_indices_array.flatten(), bins=260) 486 | plt.title("histogram of position indices.") 487 | plt.ylabel("Frequency") 488 | plt.xlabel("Position") 489 | plt.savefig(output_file_path + "position_hist.svg") 490 | plt.close() 491 | 492 | #plot total sequences that activated each filter 493 | total_seq_array = np.stack(total_seq) 494 | 495 | plt.clf() 496 | plt.bar(np.arange(300), total_seq_array) 497 | plt.title("Number of sequences activating each filter") 498 | plt.ylabel("N sequences") 499 | plt.xlabel("Filter") 500 | plt.savefig(output_file_path + "nseqs_bar_graph.svg") 501 | plt.close() 502 | 503 | return pwm, activation_indices_array, total_seq_array, activated_OCRs, n_activated_OCRs, OCR_matrix 504 | 505 | #get layer 2 motifs 506 | def get_memes2(activations, sequences, y, output_file_path): 507 | #find the threshold value for activation 508 | activation_threshold = 0.5*np.amax(activations, axis=(0,2)) 509 | 510 | #pad sequences: 511 | npad = ((0, 0), (0, 0), (26, 26)) 512 | sequences = np.pad(sequences, pad_width=npad, mode='constant', constant_values=0) 513 | 514 | pwm = np.zeros((200, 4, 51)) 515 | pfm = np.zeros((200, 4, 51)) 516 | nsamples = activations.shape[0] 517 | 518 | for i in range(200): 519 | #create list to store 19 bp sequences that activated filter 520 | act_seqs_list = [] 521 | act_OCRs_tmp = [] 522 | for j in range(nsamples): 523 | # find all indices where filter is activated 524 | indices = np.where(activations[j,i,:] > activation_threshold[i]) 525 | 526 | for start in indices[0]: 527 | start = start*3 528 | end = 
start+51 529 | act_seqs_list.append(sequences[j,:,start:end]) #*activations[j, i, start] 530 | 531 | #convert act_seqs from list to array 532 | if act_seqs_list: 533 | act_seqs = np.stack(act_seqs_list) 534 | pwm_tmp = np.sum(act_seqs, axis=0) 535 | pfm_tmp=pwm_tmp 536 | total = np.sum(pwm_tmp, axis=0) 537 | pwm_tmp = np.nan_to_num(pwm_tmp/total) 538 | 539 | #permute pwm from A, T, G, C order to A, C, G, T order 540 | order = [0, 3, 2, 1] 541 | pwm[i,:,:] = pwm_tmp[order, :] 542 | pfm[i,:,:] = pfm_tmp[order, :] 543 | 544 | #store total number of sequences that activated that filter 545 | #total_seq[i] = len(act_seqs_list) 546 | 547 | #save mean OCR activation 548 | #act_OCRs_tmp = np.stack(act_OCRs_tmp) 549 | #activated_OCRs[i, :] = np.mean(act_OCRs_tmp, axis=0) 550 | 551 | #save the number of activated OCRs 552 | #n_activated_OCRs[i] = act_OCRs_tmp.shape[0] 553 | 554 | 555 | #activated_OCRs = np.stack(activated_OCRs) 556 | 557 | #write motifs to meme format 558 | #PWM file: 559 | meme_file = open(output_file_path + "layer2_filter_motifs_pwm.meme", 'w') 560 | meme_file.write("MEME version 4 \n") 561 | 562 | #PFM file: 563 | meme_file_pfm = open(output_file_path + "layer2_filter_motifs_pfm.meme", 'w') 564 | meme_file_pfm.write("MEME version 4 \n") 565 | 566 | for i in range(0, 200): 567 | if np.sum(pwm[i,:,:]) >0: 568 | meme_file.write("\n") 569 | meme_file.write("MOTIF filter%s \n" % i) 570 | meme_file.write("letter-probability matrix: alength= 4 w= %d \n" % np.count_nonzero(np.sum(pwm[i,:,:], axis=0))) 571 | 572 | meme_file_pfm.write("\n") 573 | meme_file_pfm.write("MOTIF filter%s \n" % i) 574 | meme_file_pfm.write("letter-probability matrix: alength= 4 w= %d \n" % np.count_nonzero(np.sum(pwm[i,:,:], axis=0))) 575 | 576 | for j in range(0, 51): 577 | if np.sum(pwm[i,:,j]) > 0: 578 | meme_file.write(str(pwm[i,0,j]) + "\t" + str(pwm[i,1,j]) + "\t" + str(pwm[i,2,j]) + "\t" + str(pwm[i,3,j]) + "\n") 579 | meme_file_pfm.write(str(pfm[i,0,j]) + "\t" + str(pfm[i,1,j]) + 
"\t" + str(pfm[i,2,j]) + "\t" + str(pfm[i,3,j]) + "\n") 580 | 581 | meme_file.close() 582 | meme_file_pfm.close() 583 | 584 | 585 | 586 | #get cell-wise correlations 587 | def plot_cell_cors(obs, pred, cell_labels, output_file_path): 588 | correlations = [] 589 | for i in range(pred.shape[1]): 590 | x = np.corrcoef(pred[:, i], obs[:, i])[0, 1] 591 | correlations.append(x) 592 | 593 | #plot cell-wise correlation histogram 594 | plt.clf() 595 | plt.hist(correlations, bins=30) 596 | plt.axvline(np.mean(correlations), color='k', linestyle='dashed', linewidth=2) 597 | try: 598 | plt.title("histogram of correlation. Avg cor = {1:.2f}".format(np.mean(correlations))) 599 | except Exception as e: 600 | print("could not set the title for graph") 601 | print(e) 602 | plt.ylabel("Frequency") 603 | plt.xlabel("Correlation") 604 | plt.savefig(output_file_path + "basset_cell_wise_cor_hist.svg") 605 | plt.close() 606 | 607 | #plot cell-wise correlations by cell type 608 | plt.clf() 609 | plt.bar(np.arange(len(cell_labels.shape)), correlations) 610 | plt.title("Correlations by Cell Type") 611 | plt.ylabel("Correlation") 612 | plt.xlabel("Cell Type") 613 | plt.xticks(np.arange(len(cell_labels)), cell_labels, rotation='vertical', fontsize=3.5) 614 | plt.savefig(output_file_path + "cellwise_cor_bargraph.svg") 615 | plt.close() 616 | 617 | return correlations 618 | 619 | #convert mouse model predictions to human cell predictions 620 | def mouse2human(mouse_predictions, mouse_cell_types, map, method='average'): 621 | 622 | human_cells = np.unique(map[:,1]) 623 | 624 | human_predictions = np.zeros((mouse_predictions.shape[0], human_cells.shape[0])) 625 | 626 | for i, celltype in enumerate(human_cells): 627 | matches = map[np.where(map[:,1] == celltype)][:,0] 628 | idx = np.in1d(mouse_cell_types[:,1], matches).nonzero() 629 | human_predictions[:, i] = np.mean(mouse_predictions[:, idx], axis=2).squeeze() 630 | 631 | return human_predictions, human_cells 
-------------------------------------------------------------------------------- /code/summary_sheet.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import os, sys 4 | import seaborn as sns 5 | import matplotlib.pyplot as plt 6 | import matplotlib.image as mpimg 7 | 8 | #create output figure directory 9 | model_name = sys.argv[1] 10 | file_path = "../outputs/" + model_name + "/motifs/" 11 | directory = os.path.dirname(file_path) 12 | if not os.path.exists(directory): 13 | print("Creating directory %s" % file_path) 14 | os.makedirs(directory) 15 | else: 16 | print("Directory %s exists" % file_path) 17 | 18 | #files 19 | meme_file = file_path + "filter_motifs_pwm.meme" 20 | infl_file = file_path + "influence.npy" 21 | cell_infl_file = file_path + "filter_cellwise_influence.npy" 22 | nseqs_file = file_path + "nseqs_per_filters.txt" 23 | 24 | cell_type_file = "../data/cell_type_names.npy" 25 | tomtom_file = file_path + "tomtom/tomtom.tsv" 26 | 27 | #load cell type names 28 | col_names = np.load(cell_type_file) 29 | #filter names 30 | rows = ['filter' + str(i) for i in range(300)] 31 | 32 | #tomtom results 33 | #load tomtom results 34 | tomtom = pd.read_csv(tomtom_file, sep='\t') 35 | tomtom = tomtom[:-3] 36 | 37 | #read in motif TF information 38 | #files downloaded from the CisBP database 39 | TF_info = pd.read_csv('../data/Mus_musculus_2018_09_27_3_44_pm/TF_Information_all_motifs.txt', sep='\t') 40 | 41 | #leave only rows with q-value <0.05 42 | tomtom = tomtom.loc[tomtom['q-value']<0.05] 43 | 44 | #merge results with TF information 45 | tomtom = tomtom.merge(TF_info[['Motif_ID', 'TF_Name', 'Family_Name', 'TF_Status', 'Motif_Type']], how='left', left_on='Target_ID', right_on='Motif_ID') 46 | tomtom = tomtom.drop(columns=['Motif_ID']) 47 | 48 | # collapse TF annotations so each matched CisBP motif (Target_ID) has one comma-joined row 49 | temp = tomtom[['Target_ID', 'TF_Name', 'Family_Name', 'TF_Status',
'Motif_Type']].drop_duplicates() 50 | temp = temp.groupby(['Target_ID']).agg(lambda x : ', '.join(x.astype(str))) 51 | 52 | # combine matrices 53 | tomtom = tomtom.drop(['TF_Name', 'Family_Name', 'TF_Status', 'Motif_Type'], axis=1) 54 | tomtom = tomtom.merge(temp, how='left', left_on=['Target_ID'], right_on=['Target_ID']).drop_duplicates() 55 | 56 | tomtom['log_qval'] = -np.log(tomtom['q-value']) 57 | tomtom['log_qval'].fillna(0, inplace=True) 58 | 59 | #load influence score 60 | infl = np.load(infl_file) 61 | 62 | rows = ['filter' + str(i) for i in range(infl.shape[0])] 63 | infl = pd.DataFrame(infl, columns=["Influence"]) 64 | infl.index = rows 65 | tomtom = tomtom.merge(infl, how='outer', left_on='Query_ID',right_index=True) 66 | 67 | #load cell-wise influence for individual filters 68 | cell_infl = np.load(cell_infl_file) 69 | cell_infl = pd.DataFrame(cell_infl, index=rows, columns = col_names[:,1]) 70 | tomtom = tomtom.merge(cell_infl, how='outer', left_on='Query_ID', right_index=True) 71 | 72 | # number of sequences per filter: 73 | nseqs = pd.read_csv(nseqs_file, header = None, names = ['num_seqs']) 74 | nseqs.index = rows 75 | 76 | #add to data frame 77 | tomtom = tomtom.merge(nseqs, how='outer', left_on='Query_ID', right_index=True) 78 | 79 | 80 | #read in meme file 81 | with open(meme_file) as fp: 82 | line = fp.readline() 83 | motifs=[] 84 | motif_names=[] 85 | while line: 86 | #determine length of next motif 87 | if line.split(" ")[0]=='MOTIF': 88 | #add motif number to separate array 89 | motif_names.append(line.split(" ")[1]) 90 | 91 | #get length of motif 92 | line2=fp.readline().split(" ") 93 | motif_length = int(float(line2[5])) 94 | 95 | #read in motif 96 | current_motif=np.zeros((19, 4)) 97 | for i in range(motif_length): 98 | current_motif[i,:] = fp.readline().split("\t") 99 | 100 | motifs.append(current_motif) 101 | 102 | line = fp.readline() 103 | 104 | motifs = np.stack(motifs) 105 | motif_names = np.stack(motif_names) 106 | 107 | fp.close() 
108 | 109 | #set background frequencies of nucleotides 110 | bckgrnd = [0.25, 0.25, 0.25, 0.25] 111 | 112 | #compute information content of each motif 113 | info_content = [] 114 | position_ic = [] 115 | for i in range(motifs.shape[0]): 116 | length = motifs[i,:,:].shape[0] 117 | position_wise_ic = np.subtract(np.sum(np.multiply(motifs[i,:,:],np.log2(motifs[i,:,:] + 0.00000000001)), axis=1),np.sum(np.multiply(bckgrnd,np.log2(bckgrnd)))) 118 | 119 | position_ic.append(position_wise_ic) 120 | ic = np.sum(position_wise_ic, axis=0) 121 | info_content.append(ic) 122 | 123 | info_content = np.stack(info_content) 124 | position_ic = np.stack(position_ic) 125 | 126 | #length of motif with high info content 127 | n_info = np.sum(position_ic>0.2, axis=1) 128 | 129 | #"length of motif", i.e. difference between first and last informative base 130 | ic_idx = pd.DataFrame(np.argwhere(position_ic>0.2), columns=['row', 'idx']).groupby('row')['idx'].apply(list) 131 | motif_length = [] 132 | for row in ic_idx: 133 | motif_length.append(np.max(row)-np.min(row) + 1) 134 | 135 | motif_length = np.stack(motif_length) 136 | 137 | #create pandas data frame: 138 | info_content_df = pd.DataFrame(data=[motif_names, info_content, n_info, pd.to_numeric(motif_length)]).T 139 | info_content_df.columns=['Filter', 'IC', 'Num_Informative_Bases', 'Motif_Length'] 140 | 141 | tomtom = tomtom.merge(info_content_df, how='left', left_on='Query_ID', right_on='Filter') 142 | tomtom = tomtom.drop(columns=['Filter']) 143 | tomtom['IC'] = pd.to_numeric(tomtom['IC']) 144 | tomtom['Num_Informative_Bases'] = pd.to_numeric(tomtom['Num_Informative_Bases']) 145 | 146 | #compute consensus sequence 147 | for filter in tomtom[pd.isnull(tomtom['Query_consensus'])]['Query_ID']: 148 | sequence = '' 149 | indx = np.argwhere(motif_names == filter) 150 | pwm = motifs[indx,:,:].squeeze() 151 | for j in range(pwm.shape[0]): 152 | letter = np.argmax(pwm[j, :]) 153 | if(letter==0): 154 | sequence = sequence + 'A' 155 | 
if(letter==1): 156 | sequence = sequence + 'C' 157 | if(letter==2): 158 | sequence = sequence + 'G' 159 | if(letter==3): 160 | sequence = sequence + 'T' 161 | 162 | tomtom.loc[(tomtom['Query_ID']==filter), 'Query_consensus'] = sequence 163 | 164 | 165 | tomtom.to_csv("../outputs/motif_summary.csv") 166 | -------------------------------------------------------------------------------- /code/test_human.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.utils.data 5 | import matplotlib 6 | import os 7 | import sys 8 | matplotlib.use('Agg') 9 | 10 | import aitac 11 | import plot_utils 12 | 13 | import time 14 | from sklearn.model_selection import train_test_split 15 | 16 | # Device configuration 17 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 18 | 19 | # Hyper parameters 20 | num_classes = 81 21 | batch_size = 100 22 | num_filters = 300 23 | 24 | #create output figure directory 25 | output_file_path = "../outputs/" + sys.argv[1] + "/" 26 | directory = os.path.dirname(output_file_path) 27 | if not os.path.exists(directory): 28 | os.makedirs(directory) 29 | 30 | # Load all data 31 | x= np.load(sys.argv[2]) 32 | x = x.astype(np.float32) 33 | y= np.load(sys.argv[3]) 34 | y= y[:,[ 0, 1, 2, 9, 14, 16, 11, 17]] 35 | peak_names= np.load(sys.argv[4]) 36 | 37 | model_name = sys.argv[5] 38 | 39 | # Data loader 40 | dataset = torch.utils.data.TensorDataset(torch.from_numpy(x), torch.from_numpy(y)) 41 | data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False) 42 | 43 | # load trained model (map_location lets a checkpoint saved on a GPU load on a CPU-only machine) 44 | model = aitac.ConvNet(num_classes, num_filters).to(device) 45 | checkpoint = torch.load('../models/' + model_name + '.ckpt', map_location=device) 46 | model.load_state_dict(checkpoint) 47 | 48 | 49 | # run predictions with full model on all data 50 | mouse_predictions, max_activations, act_index = aitac.test_model(data_loader, model, device) 51 | 52 |
# convert predictions from mouse cell types to human cell types 53 | celltype_map = np.genfromtxt("../human_data/mouse_human_celltypes.txt",dtype='str') 54 | mouse_cell_types = np.genfromtxt("../data/cell_type_names.txt", dtype='str') 55 | predictions, cell_names = plot_utils.mouse2human(mouse_predictions, mouse_cell_types, celltype_map) 56 | print(cell_names) 57 | 58 | #-------------------------------------------# 59 | # Create Plots # 60 | #-------------------------------------------# 61 | 62 | # plot the correlations histogram 63 | # returns correlation measurement for every prediction-label pair 64 | print("Creating plots...") 65 | 66 | correlations = plot_utils.plot_cors(y, predictions, output_file_path) 67 | 68 | plot_utils.plot_corr_variance(y, correlations, output_file_path) 69 | 70 | quantile_indx = plot_utils.plot_piechart(correlations, y, output_file_path) 71 | 72 | 73 | #-------------------------------------------# 74 | # Save Files # 75 | #-------------------------------------------# 76 | 77 | #save predictions 78 | np.save(output_file_path + "predictions.npy", predictions) 79 | 80 | #save correlations 81 | np.save(output_file_path + "correlations.npy", correlations) 82 | -------------------------------------------------------------------------------- /code/train_test_aitac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy import random 3 | from sklearn.model_selection import train_test_split 4 | import torch.utils.data 5 | import matplotlib 6 | import os 7 | import sys 8 | import pathlib 9 | matplotlib.use('Agg') 10 | 11 | import aitac 12 | import plot_utils 13 | 14 | 15 | # Device configuration 16 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 17 | 18 | # Hyper parameters 19 | num_epochs = 10 20 | num_classes = 81 21 | batch_size = 100 22 | learning_rate = 0.001 23 | num_filters = 300 24 | 25 | #create output figure directory 26 | model_name = sys.argv[1] 27 |
output_file_path = "../outputs/" + model_name + "/training/" 28 | directory = os.path.dirname(output_file_path) 29 | if not os.path.exists(directory): 30 | print("Creating directory %s" % output_file_path) 31 | pathlib.Path(output_file_path).mkdir(parents=True, exist_ok=True) 32 | else: 33 | print("Directory %s exists" % output_file_path) 34 | 35 | 36 | 37 | # Load all data 38 | x = np.load(sys.argv[2]) 39 | x = x.astype(np.float32) 40 | y = np.load(sys.argv[3]) 41 | y = y.astype(np.float32) 42 | peak_names = np.load(sys.argv[4]) 43 | 44 | 45 | # split the data into training and test sets 46 | train_data, eval_data, train_labels, eval_labels, train_names, eval_names = train_test_split(x, y, peak_names, test_size=0.1, random_state=40) 47 | 48 | # Data loader 49 | train_dataset = torch.utils.data.TensorDataset(torch.from_numpy(train_data), torch.from_numpy(train_labels)) 50 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False) 51 | 52 | eval_dataset = torch.utils.data.TensorDataset(torch.from_numpy(eval_data), torch.from_numpy(eval_labels)) 53 | eval_loader = torch.utils.data.DataLoader(dataset=eval_dataset, batch_size=batch_size, shuffle=False) 54 | 55 | 56 | # create model 57 | model = aitac.ConvNet(num_classes, num_filters).to(device) 58 | 59 | # Loss and optimizer 60 | criterion = aitac.pearson_loss 61 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 62 | 63 | # train model 64 | model, best_loss = aitac.train_model(train_loader, eval_loader, model, device, criterion, optimizer, num_epochs, output_file_path) 65 | 66 | # save the model checkpoint 67 | models_file_path = "../models/" 68 | models_directory = os.path.dirname(models_file_path) 69 | if not os.path.exists(models_directory): 70 | print("Creating directory %s" % models_file_path) 71 | os.makedirs(models_directory) 72 | else: 73 | print("Directory %s exists" % models_file_path) 74 | 75 | torch.save(model.state_dict(), '../models/' + 
model_name + '.ckpt') 76 | 77 | #save the whole model 78 | torch.save(model, '../models/' + model_name + '.pth') 79 | 80 | # Predict on test set 81 | predictions, max_activations, max_act_index = aitac.test_model(eval_loader, model, device) 82 | 83 | 84 | #-------------------------------------------# 85 | # Create Plots # 86 | #-------------------------------------------# 87 | 88 | # plot the correlations histogram 89 | # returns correlation measurement for every prediction-label pair 90 | print("Creating plots...") 91 | 92 | 93 | correlations = plot_utils.plot_cors(eval_labels, predictions, output_file_path) 94 | 95 | plot_utils.plot_corr_variance(eval_labels, correlations, output_file_path) 96 | 97 | quantile_indx = plot_utils.plot_piechart(correlations, eval_labels, output_file_path) 98 | 99 | plot_utils.plot_random_predictions(eval_labels, predictions, correlations, quantile_indx, eval_names, output_file_path) 100 | 101 | 102 | #-------------------------------------------# 103 | # Save Files # 104 | #-------------------------------------------# 105 | 106 | #save predictions 107 | np.save(output_file_path + "predictions.npy", predictions) 108 | 109 | #save correlations 110 | np.save(output_file_path + "correlations.npy", correlations) 111 | 112 | #save max first layer activations 113 | np.save(output_file_path + "max_activations.npy", max_activations) 114 | np.save(output_file_path + "max_activation_index.npy", max_act_index) 115 | 116 | #save test data set 117 | np.save(output_file_path + "test_OCR_names.npy", eval_names) -------------------------------------------------------------------------------- /data_processing/human_data_preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import os 4 | import pickle 5 | import sys 6 | import preprocess_utils 7 | 8 | #parameters (chromosomes chr1 .. chr{num_chr - 1} are read, i.e. chr1-chr22) 9 | num_chr = 23 10 | 11 | #input files: 12 | data_file = sys.argv[1] #path to bed file with human data 13 |
intensity_file = sys.argv[2] #path to file with normalized ATAC-Seq data 14 | reference_genome = sys.argv[3] #path to reference genome fasta files "../human_data/hg19/" 15 | output_directory = "../human_data/" 16 | 17 | 18 | # read bed file with peak names and positions 19 | positions = preprocess_utils.read_bed(data_file) 20 | 21 | # read reference genome fasta files into a dictionary, cached as a pickle after the first run 22 | if not os.path.exists('../human_data/hg19_chr_dict.pickle'): 23 | chr_dict = preprocess_utils.read_fasta(reference_genome, num_chr) 24 | pickle.dump(chr_dict, open("../human_data/hg19_chr_dict.pickle", "wb")) 25 | 26 | chr_dict = pickle.load(open("../human_data/hg19_chr_dict.pickle", "rb")) 27 | 28 | 29 | one_hot_seqs, peak_seqs, invalid_ids, peak_names = preprocess_utils.get_sequences(positions, chr_dict, num_chr) 30 | 31 | # read intensities; peaks without a valid sequence are dropped below via np.intersect1d 32 | cell_type_array, peak_names2 = preprocess_utils.format_intensities(intensity_file, invalid_ids) 33 | 34 | cell_type_array = cell_type_array.astype(np.float32) 35 | 36 | 37 | #take one_hot_sequences of only peaks that have associated intensity values in cell_type_array 38 | peak_ids = np.intersect1d(peak_names, peak_names2) 39 | 40 | idx = np.isin(peak_names, peak_ids) 41 | peak_names = peak_names[idx] 42 | one_hot_seqs = one_hot_seqs[idx, :, :] 43 | peak_seqs = peak_seqs[idx] 44 | 45 | 46 | idx2 = np.isin(peak_names2, peak_ids) 47 | peak_names2 = peak_names2[idx2] 48 | cell_type_array = cell_type_array[idx2, :] 49 | 50 | if np.sum(peak_names != peak_names2) > 0: 51 | print("Order of peaks not matching for sequences/intensities!") 52 | 53 | 54 | # write to file 55 | np.save(output_directory + 'one_hot_seqs.npy', one_hot_seqs) 56 | np.save(output_directory + 'peak_names.npy', peak_names) 57 | np.save(output_directory + 'peak_seqs.npy', peak_seqs) 58 | 59 | with open(output_directory + 'invalid_ids.txt', 'w') as f: 60 |
f.write(json.dumps(invalid_ids)) 61 | 62 | np.save(output_directory + 'cell_type_array.npy', cell_type_array) -------------------------------------------------------------------------------- /data_processing/permute_human_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | directory = "../human_data/" 4 | 5 | x = np.load(directory + 'one_hot_seqs.npy') 6 | y = np.load(directory + 'cell_type_array.npy') 7 | peak_names = np.load(directory + 'peak_names.npy') 8 | 9 | #shuffle order of ocrs in peak intensity array 10 | idx = np.random.permutation(y.shape[0]) 11 | y_shuffled_ocrs = y[idx, :] 12 | peak_names_shuffled = peak_names[idx] 13 | np.save(directory + 'shuffled_ocrs/peak_intensities.npy', y_shuffled_ocrs) 14 | np.save(directory + 'shuffled_ocrs/one_hot_seqs.npy', x) 15 | np.save(directory + 'shuffled_ocrs/peak_names.npy', peak_names_shuffled) 16 | 17 | #shuffle sequences for each sample 18 | x_shuffled = np.zeros(x.shape) 19 | for sample in range(x.shape[0]): 20 | idx = np.random.permutation(np.arange(x.shape[2])) 21 | x_shuffled[sample, :, :] = x[np.ix_([sample], np.arange(4), idx)] 22 | np.save(directory + "shuffled_seqs/peak_intensities.npy", y) 23 | np.save(directory + "shuffled_seqs/one_hot_seqs.npy", x_shuffled) 24 | np.save(directory + "shuffled_seqs/peak_names.npy", peak_names) 25 | 26 | 27 | #shuffle order of cells in peak intensity array 28 | y_shuffled_cells = np.zeros(y.shape) 29 | for peak in range(y.shape[0]): 30 | idx = np.random.permutation(np.arange(y.shape[1])) 31 | y_shuffled_cells[peak, :] = y[peak, idx] 32 | 33 | np.save(directory + "shuffled_cells/peak_intensities.npy", y_shuffled_cells) 34 | np.save(directory + "shuffled_cells/one_hot_seqs.npy", x) 35 | np.save(directory + "shuffled_cells/peak_names.npy", peak_names) 36 | -------------------------------------------------------------------------------- /data_processing/permute_sequence.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | x = np.load('../data/processed/one_hot_seqs.npy') 4 | y = np.load('../data/processed/cell_type_array.npy') 5 | peak_names = np.load('../data/processed/peak_names.npy') 6 | 7 | #shuffle order of ocrs in peak intensity array 8 | idx = np.random.permutation(y.shape[0]) 9 | y_shuffled_ocrs = y[idx, :] 10 | peak_names_shuffled = peak_names[idx] 11 | np.save("../data/shuffled_ocrs/peak_intensities.npy", y_shuffled_ocrs) 12 | np.save("../data/shuffled_ocrs/one_hot_seqs.npy", x) 13 | np.save("../data/shuffled_ocrs/peak_names.npy", peak_names_shuffled) 14 | 15 | #shuffle sequences for each sample 16 | x_shuffled = np.zeros(x.shape) 17 | for sample in range(x.shape[0]): 18 | idx = np.random.permutation(np.arange(x.shape[2])) 19 | x_shuffled[sample, :, :] = x[np.ix_([sample], np.arange(4), idx)] 20 | np.save("../data/shuffled_seqs/peak_intensities.npy", y) 21 | np.save("../data/shuffled_seqs/one_hot_seqs.npy", x_shuffled) 22 | np.save("../data/shuffled_seqs/peak_names.npy", peak_names) 23 | 24 | 25 | #shuffle order of cells in peak intensity array 26 | y_shuffled_cells = np.zeros(y.shape) 27 | for peak in range(y.shape[0]): 28 | idx = np.random.permutation(np.arange(y.shape[1])) 29 | y_shuffled_cells[peak, :] = y[peak, idx] 30 | 31 | np.save("../data/shuffled_cells/peak_intensities.npy", y_shuffled_cells) 32 | np.save("../data/shuffled_cells/one_hot_seqs.npy", x) 33 | np.save("../data/shuffled_cells/peak_names.npy", peak_names) 34 | -------------------------------------------------------------------------------- /data_processing/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import os 4 | import pickle 5 | import sys 6 | import preprocess_utils 7 | 8 | #parameters (chromosomes chr1 .. chr{num_chr - 1} are read, i.e. chr1-chr19) 9 | num_chr = 20 10 | 11 | #input files: 12 | data_file = sys.argv[1] #bed file with peak locations 13 |
intensity_file = sys.argv[2] #text file with normalized peak heights 14 | reference_genome = sys.argv[3] #directory with reference genome fasta files 15 | output_directory = sys.argv[4] #output directory (must end with '/', since file names are appended directly) 16 | 17 | directory = os.path.dirname(output_directory) 18 | if not os.path.exists(directory): 19 | os.makedirs(directory) 20 | 21 | 22 | 23 | # read bed file with peak names and positions 24 | positions = preprocess_utils.read_bed(data_file) 25 | 26 | # read reference genome fasta files into a dictionary, cached as a pickle after the first run 27 | if not os.path.exists('../data/chr_dict.pickle'): 28 | chr_dict = preprocess_utils.read_fasta(reference_genome, num_chr) 29 | pickle.dump(chr_dict, open('../data/chr_dict.pickle', "wb")) 30 | 31 | chr_dict = pickle.load(open('../data/chr_dict.pickle', "rb")) 32 | 33 | 34 | one_hot_seqs, peak_seqs, invalid_ids, peak_names = preprocess_utils.get_sequences(positions, chr_dict, num_chr) 35 | 36 | # read intensities; peaks without a valid sequence are dropped below via np.intersect1d 37 | cell_type_array, peak_names2 = preprocess_utils.format_intensities(intensity_file, invalid_ids) 38 | 39 | cell_type_array = cell_type_array.astype(np.float32) 40 | 41 | 42 | #take one_hot_sequences of only peaks that have associated intensity values in cell_type_array 43 | peak_ids = np.intersect1d(peak_names, peak_names2) 44 | 45 | idx = np.isin(peak_names, peak_ids) 46 | peak_names = peak_names[idx] 47 | one_hot_seqs = one_hot_seqs[idx, :, :] 48 | peak_seqs = peak_seqs[idx] 49 | 50 | 51 | idx2 = np.isin(peak_names2, peak_ids) 52 | peak_names2 = peak_names2[idx2] 53 | cell_type_array = cell_type_array[idx2, :] 54 | 55 | if np.sum(peak_names != peak_names2) > 0: 56 | print("Order of peaks not matching for sequences/intensities!") 57 | 58 | 59 | 60 | # write to file 61 | np.save(output_directory + 'one_hot_seqs.npy', one_hot_seqs) 62 | np.save(output_directory + 'peak_names.npy', peak_names) 63 | np.save(output_directory + 'peak_seqs.npy', peak_seqs) 64
| 65 | with open(output_directory + 'invalid_ids.txt', 'w') as f: 66 | f.write(json.dumps(invalid_ids)) 67 | 68 | 69 | #write fasta file 70 | with open(output_directory + 'sequences.fasta', 'w') as f: 71 | for i in range(peak_seqs.shape[0]): 72 | f.write('>' + peak_names[i] + '\n') 73 | f.write(peak_seqs[i] + '\n') 74 | 75 | 76 | 77 | np.save(output_directory + 'cell_type_array.npy', cell_type_array) -------------------------------------------------------------------------------- /data_processing/preprocess_utils.py: -------------------------------------------------------------------------------- 1 | from Bio import SeqIO 2 | from Bio.Seq import Seq 3 | from Bio.SeqRecord import SeqRecord 4 | from collections import defaultdict 5 | import numpy as np 6 | 7 | # takes DNA sequence, outputs one-hot-encoded matrix with rows A, T, G, C (or the string "contains_N" if any other character occurs) 8 | def one_hot_encoder(sequence): 9 | l = len(sequence) 10 | x = np.zeros((4,l),dtype = 'int8') 11 | for j, i in enumerate(sequence): 12 | if i == "A" or i == "a": 13 | x[0][j] = 1 14 | elif i == "T" or i == "t": 15 | x[1][j] = 1 16 | elif i == "G" or i == "g": 17 | x[2][j] = 1 18 | elif i == "C" or i == "c": 19 | x[3][j] = 1 20 | else: 21 | return "contains_N" 22 | return x 23 | 24 | #read peak names and positions from a whitespace-separated file with columns name, chr, start, stop 25 | def read_bed(filename): 26 | positions = defaultdict(list) 27 | with open(filename) as f: 28 | for line in f: 29 | name, chr, start, stop = line.split() 30 | positions[name].append((chr, int(start), int(stop))) 31 | 32 | return positions 33 | 34 | 35 | # parse per-chromosome fasta files chr1.fa .. chr{num_chr - 1}.fa into one dictionary keyed by sequence ID 36 | def read_fasta(genome_dir, num_chr): 37 | chr_dict = dict() 38 | for chr in range(1, num_chr): 39 | chr_file_path = genome_dir + "chr{}.fa".format(chr) 40 | chr_dict.update(SeqIO.to_dict(SeqIO.parse(chr_file_path, 'fasta'))) 41 | 42 | return chr_dict 43 | 44 | 45 | #get sequences for peaks from reference genome 46 | def get_sequences(positions, chr_dict, num_chr): 47 | one_hot_seqs = [] 48 |
peak_seqs = [] 49 | invalid_ids = [] 50 | peak_names = [] 51 | 52 | target_chr = ['chr{}'.format(i) for i in range(1, num_chr)] 53 | 54 | for name in positions: 55 | for (chr, start, stop) in positions[name]: 56 | if chr in target_chr: 57 | chr_seq = chr_dict[chr].seq 58 | peak_seq = str(chr_seq[start - 1:stop]).lower() # slice first, so the whole chromosome is not converted to a string for every peak 59 | one_hot_seq = one_hot_encoder(peak_seq) 60 | 61 | if isinstance(one_hot_seq, np.ndarray): # valid sequence (only A/C/G/T) 62 | peak_names.append(name) 63 | peak_seqs.append(peak_seq) 64 | one_hot_seqs.append(one_hot_seq) 65 | else: 66 | invalid_ids.append(name[20:]) 67 | else: 68 | invalid_ids.append(name[20:]) 69 | 70 | one_hot_seqs = np.stack(one_hot_seqs) 71 | peak_seqs = np.stack(peak_seqs) 72 | peak_names = np.stack(peak_names) 73 | 74 | return one_hot_seqs, peak_seqs, invalid_ids, peak_names 75 | 76 | # read peak names and intensities (header line skipped); invalid_ids is not used here, callers match peaks with np.intersect1d 77 | def format_intensities(intensity_file, invalid_ids): 78 | cell_type_array = [] 79 | peak_names = [] 80 | with open(intensity_file) as f: 81 | for i, line in enumerate(f): 82 | if i == 0: continue 83 | columns = line.split() 84 | peak_name = columns[0] 85 | if '\x1a' not in columns: # skip a stray Ctrl-Z end-of-file marker 86 | cell_act = columns[1:] 87 | cell_type_array.append(cell_act) 88 | peak_names.append(peak_name) 89 | 90 | cell_type_array = np.stack(cell_type_array) 91 | peak_names = np.stack(peak_names) 92 | 93 | return cell_type_array, peak_names 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /figures/AI-TAC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smaslova/AI-TAC/ec0e62c2fd8d200bf46d8e49a09f0760013d10b5/figures/AI-TAC.png -------------------------------------------------------------------------------- /figures/AITAC.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smaslova/AI-TAC/ec0e62c2fd8d200bf46d8e49a09f0760013d10b5/figures/AITAC.pdf
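`one_hot_encoder` in data_processing/preprocess_utils.py encodes one base at a time in a Python loop, with rows in the order A, T, G, C, and rejects any sequence containing another character. Below is a minimal vectorized sketch of the same encoding. `one_hot_encode_fast` is a hypothetical name, not part of the repository, and unlike the original it returns `None` rather than the string `"contains_N"` for rejected sequences.

```python
import numpy as np

# Row order used by one_hot_encoder in preprocess_utils.py: A, T, G, C
BASES = np.frombuffer(b"ATGC", dtype=np.uint8)


def one_hot_encode_fast(sequence):
    """Vectorized one-hot encoding of a DNA string (hypothetical helper).

    Returns a (4, len(sequence)) int8 matrix with rows A, T, G, C, or None
    if the sequence contains any character other than A/T/G/C
    (case-insensitive), e.g. an N.
    """
    # Non-ASCII characters become "?" and are therefore rejected below
    codes = np.frombuffer(sequence.upper().encode("ascii", errors="replace"), dtype=np.uint8)
    # matches[k, j] is True when position j holds base BASES[k]
    matches = BASES[:, None] == codes[None, :]
    # every position must match exactly one of the four bases
    if not matches.any(axis=0).all():
        return None
    return matches.astype(np.int8)
```

Calling `one_hot_encode_fast("acgT")` yields a 4 x 4 matrix whose columns set rows 0 (A), 3 (C), 2 (G) and 1 (T). Keep this row order in mind when comparing against the MEME-derived matrices in code/summary_sheet.py, whose consensus loop maps columns 0 to 3 to A, C, G, T.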
-------------------------------------------------------------------------------- /figures/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smaslova/AI-TAC/ec0e62c2fd8d200bf46d8e49a09f0760013d10b5/figures/logo.png -------------------------------------------------------------------------------- /models/aitac.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smaslova/AI-TAC/ec0e62c2fd8d200bf46d8e49a09f0760013d10b5/models/aitac.ckpt -------------------------------------------------------------------------------- /sample_data/README: -------------------------------------------------------------------------------- 1 | This folder contains examples of the data files that serve as input into the AI-TAC model. 2 | 3 | - one_hot_seqs.npy - tensor of one-hot-encoded sequences of 100 ATAC-seq open chromatin regions (OCRs) 4 | - cell_type_array.npy - matrix of normalized ATAC-seq peak values for the same 100 OCRs 5 | - peak_names.npy - IDs of 100 OCRs in sample data 6 | - cell_type_names.npy - names of the 81 cell types in ATAC-seq dataset, in the same order as the data in cell_type_array.npy 7 | - filter_set99_index.txt - indices of 99 reproducible filters in the AI-TAC model 8 | -------------------------------------------------------------------------------- /sample_data/cell_type_array.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smaslova/AI-TAC/ec0e62c2fd8d200bf46d8e49a09f0760013d10b5/sample_data/cell_type_array.npy -------------------------------------------------------------------------------- /sample_data/cell_type_names.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smaslova/AI-TAC/ec0e62c2fd8d200bf46d8e49a09f0760013d10b5/sample_data/cell_type_names.npy 
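The sample_data README above lists the model inputs. The scripts in code/ rely on these arrays lining up: one row of cell_type_array.npy and one entry of peak_names.npy per OCR in one_hot_seqs.npy, and one column of cell_type_array.npy per cell type. Below is a minimal consistency check. It assumes the sequences are stored as (n_ocrs, 4, seq_len), which is the layout produced by stacking `one_hot_encoder` outputs in `preprocess_utils.get_sequences`. `check_sample_data` is a hypothetical helper, not part of the repository.

```python
import numpy as np


def check_sample_data(x, y, peak_names, cell_type_names):
    """Check that AI-TAC input arrays are mutually consistent (hypothetical helper).

    Assumes x has shape (n_ocrs, 4, seq_len) with one-hot columns,
    y has shape (n_ocrs, n_cell_types), peak_names has one entry per OCR,
    and cell_type_names has one entry (row) per cell type.
    Returns (n_ocrs, seq_len, n_cell_types); raises ValueError otherwise.
    """
    if x.ndim != 3 or x.shape[1] != 4:
        raise ValueError(f"expected x of shape (n_ocrs, 4, seq_len), got {x.shape}")
    n_ocrs, _, seq_len = x.shape
    if y.ndim != 2 or y.shape[0] != n_ocrs or len(peak_names) != n_ocrs:
        raise ValueError("number of OCRs differs between x, y and peak_names")
    if y.shape[1] != len(cell_type_names):
        raise ValueError("number of cell types differs between y and cell_type_names")
    # every position of every OCR should have exactly one base set
    if not np.all(x.sum(axis=1) == 1):
        raise ValueError("x is not one-hot encoded")
    return n_ocrs, seq_len, y.shape[1]
```

For the bundled files, pass the four arrays loaded with `np.load` from sample_data/. Name arrays saved with object dtype also need `allow_pickle=True`. Per the README, the expected counts are 100 OCRs and 81 cell types, assuming cell_type_names.npy has one row per cell type.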
-------------------------------------------------------------------------------- /sample_data/filter_set99_index.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 10 3 | 102 4 | 105 5 | 106 6 | 11 7 | 112 8 | 115 9 | 120 10 | 122 11 | 123 12 | 124 13 | 127 14 | 129 15 | 132 16 | 133 17 | 135 18 | 14 19 | 15 20 | 151 21 | 154 22 | 155 23 | 157 24 | 165 25 | 166 26 | 167 27 | 170 28 | 173 29 | 174 30 | 178 31 | 183 32 | 185 33 | 186 34 | 190 35 | 194 36 | 195 37 | 196 38 | 198 39 | 202 40 | 205 41 | 209 42 | 214 43 | 217 44 | 218 45 | 219 46 | 220 47 | 223 48 | 224 49 | 227 50 | 23 51 | 230 52 | 231 53 | 238 54 | 242 55 | 244 56 | 247 57 | 250 58 | 252 59 | 255 60 | 257 61 | 260 62 | 264 63 | 270 64 | 275 65 | 276 66 | 279 67 | 28 68 | 281 69 | 283 70 | 284 71 | 286 72 | 288 73 | 290 74 | 292 75 | 297 76 | 34 77 | 35 78 | 40 79 | 43 80 | 44 81 | 47 82 | 50 83 | 51 84 | 6 85 | 65 86 | 66 87 | 67 88 | 68 89 | 69 90 | 70 91 | 78 92 | 8 93 | 80 94 | 85 95 | 89 96 | 9 97 | 93 98 | 94 99 | 97 100 | -------------------------------------------------------------------------------- /sample_data/one_hot_seqs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smaslova/AI-TAC/ec0e62c2fd8d200bf46d8e49a09f0760013d10b5/sample_data/one_hot_seqs.npy -------------------------------------------------------------------------------- /sample_data/peak_names.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/smaslova/AI-TAC/ec0e62c2fd8d200bf46d8e49a09f0760013d10b5/sample_data/peak_names.npy --------------------------------------------------------------------------------
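code/summary_sheet.py scores every position of each MEME motif by its information content against a uniform background. It then counts the positions above 0.2 bits and measures the span from the first to the last such position. Below is a minimal per-motif sketch of that computation, assuming a (length, 4) probability matrix. `motif_information_content` is a hypothetical helper. Returning a span of 0 when no position passes the threshold, rather than omitting the motif, is a choice made here so that per-motif results stay aligned with the motif list.

```python
import numpy as np


def motif_information_content(pwm, background=(0.25, 0.25, 0.25, 0.25),
                              threshold=0.2, eps=1e-11):
    """Information content of a (length, 4) probability matrix (hypothetical helper).

    Per position j: IC_j = sum_b p_jb*log2(p_jb) - sum_b q_b*log2(q_b),
    the formula used in code/summary_sheet.py. With a uniform background
    this equals 2 - H_j bits, where H_j is the Shannon entropy of position j.
    Returns (per-position IC, total IC, number of positions above threshold,
    span from the first to the last such position, or 0 if there are none).
    """
    pwm = np.asarray(pwm, dtype=float)
    q = np.asarray(background, dtype=float)
    # eps avoids log2(0); a zero probability still contributes 0 * log2(eps) = 0
    position_ic = np.sum(pwm * np.log2(pwm + eps), axis=1) - np.sum(q * np.log2(q))
    informative = np.flatnonzero(position_ic > threshold)
    span = int(informative[-1] - informative[0] + 1) if informative.size else 0
    return position_ic, float(position_ic.sum()), int(informative.size), span
```

A fully conserved position scores about 2 bits and a uniform one about 0. A motif made of a conserved A, a uniform position, a conserved T, and another uniform position therefore has two informative positions spanning three bases.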