├── Data ├── ALLgenes_ism_attributions_driversfw_refseq_winsize13.npz ├── ALLgenes_ism_attributions_driversfw_varseq_winsize13.npz ├── ALLgenes_ism_attributions_driversfwmain_ism_significance_stats.txt ├── ALLgenes_ism_attributions_driversfwmain_refseq_winsize13.npz ├── ALLgenes_ism_attributions_driversfwmain_varseq_winsize13.npz ├── Enformer_predictions.txt.gz ├── MeanGeXPredFineTuned.txt ├── Observed_gene_expression.txt.gz ├── PrediXcanCorrelationWithExpressionENSG.tsv ├── Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list ├── Prediction_correlationsCageAdultBrain_Allstats.txt ├── README.md ├── SupplementaryTable1.tsv ├── SupplementaryTable2.tsv ├── SupplementaryTable3.txt ├── enformer_test_set_genes.npy ├── gene-ids-and-positions.tsv ├── geneTSS.txt ├── ism_res.tar ├── maindrivernpz.tar ├── snp_positions.tar ├── tss1000bpnpz.tar └── tss_attribution_stats.txt ├── Fig1.sh ├── Fig2.sh ├── FigS.sh ├── LICENSE ├── README.md ├── enformer_analysis ├── README.md ├── TSS_win_ISM.py ├── attribution_plot.py ├── basic_pred_gene_expr.py ├── cluster_grad_attributions.py ├── combine_eqtl_types.py ├── compute_correlation.py ├── compute_eqtl.py ├── compute_population_frequency.py ├── compute_tstatpvalue.py ├── count_drivers.py ├── driver_distance_to_tss.py ├── eqtl_attribution_plot.py ├── eqtl_types.py ├── extract_ism_stats_around_drivers.py ├── find_common_motifs.py ├── from_main_var_drivers_ISM.py ├── from_ref_drivers_ISM.py ├── generate_null.py ├── gradient_attributions.py ├── per_SNP_ISM.py ├── plot_attribution_alongsequence.py ├── plot_common_motifs.py ├── plot_distribution_enformer_correlations.py ├── plot_driver_motifstats.py ├── plot_drivercounts.py ├── plot_driverdistance.py ├── plot_drivertype.py ├── plot_individual_gene_scatter.py ├── plot_refandvar_attribution.py ├── plot_snp_clustering.py ├── predict_with_augmented_data.py ├── replace_genename.py ├── scatter_correlations.py ├── scatter_mean_prediction.py ├── 
scatter_pvalue_vs_correlationprediction.py ├── scatterplot.py └── select_drivers.py ├── fine_tuned_model ├── GTExWeights.csv ├── README.md └── enformerFinetuned.m ├── process_genomic_data ├── Genotype_AMPAD_WGS_04_phasing_Eagle_script.sh ├── README.md ├── extractSeq.py ├── save_ref_seqs.py └── save_snp_info.txt └── runPrediXcan.txt /Data/ALLgenes_ism_attributions_driversfw_refseq_winsize13.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostafavilabuw/EnformerAssessment/afac6658d57bfe6e2c1f17d9a96e3593843000a2/Data/ALLgenes_ism_attributions_driversfw_refseq_winsize13.npz -------------------------------------------------------------------------------- /Data/ALLgenes_ism_attributions_driversfw_varseq_winsize13.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostafavilabuw/EnformerAssessment/afac6658d57bfe6e2c1f17d9a96e3593843000a2/Data/ALLgenes_ism_attributions_driversfw_varseq_winsize13.npz -------------------------------------------------------------------------------- /Data/ALLgenes_ism_attributions_driversfwmain_refseq_winsize13.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostafavilabuw/EnformerAssessment/afac6658d57bfe6e2c1f17d9a96e3593843000a2/Data/ALLgenes_ism_attributions_driversfwmain_refseq_winsize13.npz -------------------------------------------------------------------------------- /Data/ALLgenes_ism_attributions_driversfwmain_varseq_winsize13.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostafavilabuw/EnformerAssessment/afac6658d57bfe6e2c1f17d9a96e3593843000a2/Data/ALLgenes_ism_attributions_driversfwmain_varseq_winsize13.npz -------------------------------------------------------------------------------- /Data/Enformer_predictions.txt.gz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostafavilabuw/EnformerAssessment/afac6658d57bfe6e2c1f17d9a96e3593843000a2/Data/Enformer_predictions.txt.gz -------------------------------------------------------------------------------- /Data/Observed_gene_expression.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostafavilabuw/EnformerAssessment/afac6658d57bfe6e2c1f17d9a96e3593843000a2/Data/Observed_gene_expression.txt.gz -------------------------------------------------------------------------------- /Data/Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list: -------------------------------------------------------------------------------- 1 | ENSG00000001460 2 | ENSG00000005156 3 | ENSG00000006607 4 | ENSG00000013573 5 | ENSG00000014138 6 | ENSG00000015592 7 | ENSG00000023041 8 | ENSG00000029639 9 | ENSG00000031691 10 | ENSG00000032742 11 | ENSG00000040531 12 | ENSG00000041802 13 | ENSG00000042493 14 | ENSG00000053900 15 | ENSG00000064199 16 | ENSG00000065060 17 | ENSG00000071894 18 | ENSG00000074410 19 | ENSG00000075131 20 | ENSG00000075336 21 | ENSG00000083535 22 | ENSG00000084207 23 | ENSG00000085514 24 | ENSG00000085982 25 | ENSG00000088854 26 | ENSG00000089041 27 | ENSG00000095485 28 | ENSG00000096872 29 | ENSG00000099330 30 | ENSG00000100299 31 | ENSG00000100325 32 | ENSG00000100483 33 | ENSG00000101104 34 | ENSG00000103037 35 | ENSG00000104979 36 | ENSG00000105136 37 | ENSG00000105248 38 | ENSG00000105497 39 | ENSG00000106733 40 | ENSG00000109771 41 | ENSG00000109854 42 | ENSG00000110074 43 | ENSG00000111906 44 | ENSG00000112110 45 | ENSG00000112149 46 | ENSG00000112852 47 | ENSG00000114331 48 | ENSG00000116132 49 | ENSG00000119285 50 | ENSG00000119328 51 | ENSG00000119392 52 | ENSG00000119673 53 | ENSG00000119943 54 | ENSG00000119979 55 | ENSG00000120451 56 | 
ENSG00000120539 57 | ENSG00000120675 58 | ENSG00000120860 59 | ENSG00000122085 60 | ENSG00000122335 61 | ENSG00000122971 62 | ENSG00000123297 63 | ENSG00000123545 64 | ENSG00000124380 65 | ENSG00000124508 66 | ENSG00000124523 67 | ENSG00000124587 68 | ENSG00000125375 69 | ENSG00000125630 70 | ENSG00000125734 71 | ENSG00000125885 72 | ENSG00000125895 73 | ENSG00000126790 74 | ENSG00000126870 75 | ENSG00000127311 76 | ENSG00000127364 77 | ENSG00000127720 78 | ENSG00000128463 79 | ENSG00000128944 80 | ENSG00000129055 81 | ENSG00000130348 82 | ENSG00000130363 83 | ENSG00000130517 84 | ENSG00000131849 85 | ENSG00000132305 86 | ENSG00000133028 87 | ENSG00000133433 88 | ENSG00000133466 89 | ENSG00000133597 90 | ENSG00000133704 91 | ENSG00000134014 92 | ENSG00000134108 93 | ENSG00000134202 94 | ENSG00000134265 95 | ENSG00000135698 96 | ENSG00000135828 97 | ENSG00000136059 98 | ENSG00000136104 99 | ENSG00000136425 100 | ENSG00000137185 101 | ENSG00000137563 102 | ENSG00000138031 103 | ENSG00000138073 104 | ENSG00000138381 105 | ENSG00000142082 106 | ENSG00000142102 107 | ENSG00000143891 108 | ENSG00000148468 109 | ENSG00000148690 110 | ENSG00000149600 111 | ENSG00000150403 112 | ENSG00000151348 113 | ENSG00000151689 114 | ENSG00000152240 115 | ENSG00000152348 116 | ENSG00000152749 117 | ENSG00000154065 118 | ENSG00000154548 119 | ENSG00000154803 120 | ENSG00000154978 121 | ENSG00000155093 122 | ENSG00000155660 123 | ENSG00000156171 124 | ENSG00000156253 125 | ENSG00000157379 126 | ENSG00000157637 127 | ENSG00000159063 128 | ENSG00000160282 129 | ENSG00000160336 130 | ENSG00000160352 131 | ENSG00000160999 132 | ENSG00000163597 133 | ENSG00000164347 134 | ENSG00000164414 135 | ENSG00000164597 136 | ENSG00000164880 137 | ENSG00000164904 138 | ENSG00000166913 139 | ENSG00000167528 140 | ENSG00000168393 141 | ENSG00000169715 142 | ENSG00000169826 143 | ENSG00000169925 144 | ENSG00000169967 145 | ENSG00000170275 146 | ENSG00000170889 147 | ENSG00000171202 148 | ENSG00000172175 
149 | ENSG00000172785 150 | ENSG00000172803 151 | ENSG00000173715 152 | ENSG00000174136 153 | ENSG00000174446 154 | ENSG00000175264 155 | ENSG00000175701 156 | ENSG00000177600 157 | ENSG00000177697 158 | ENSG00000177885 159 | ENSG00000177932 160 | ENSG00000178201 161 | ENSG00000178295 162 | ENSG00000178922 163 | ENSG00000179698 164 | ENSG00000180104 165 | ENSG00000180113 166 | ENSG00000180914 167 | ENSG00000182093 168 | ENSG00000182118 169 | ENSG00000182310 170 | ENSG00000182362 171 | ENSG00000182551 172 | ENSG00000183604 173 | ENSG00000183891 174 | ENSG00000184619 175 | ENSG00000184787 176 | ENSG00000185324 177 | ENSG00000186088 178 | ENSG00000187231 179 | ENSG00000187837 180 | ENSG00000188312 181 | ENSG00000188599 182 | ENSG00000188613 183 | ENSG00000196247 184 | ENSG00000196268 185 | ENSG00000196295 186 | ENSG00000196378 187 | ENSG00000196655 188 | ENSG00000196821 189 | ENSG00000197020 190 | ENSG00000198130 191 | ENSG00000198502 192 | ENSG00000198890 193 | ENSG00000204947 194 | ENSG00000205609 195 | ENSG00000205704 196 | ENSG00000213965 197 | ENSG00000214046 198 | ENSG00000214087 199 | ENSG00000214174 200 | ENSG00000214253 201 | ENSG00000214402 202 | ENSG00000214435 203 | ENSG00000215039 204 | ENSG00000215559 205 | ENSG00000225706 206 | ENSG00000226742 207 | ENSG00000226752 208 | ENSG00000228672 209 | ENSG00000228716 210 | ENSG00000229980 211 | ENSG00000230091 212 | ENSG00000230453 213 | ENSG00000233930 214 | ENSG00000234327 215 | ENSG00000235903 216 | ENSG00000236816 217 | ENSG00000237517 218 | ENSG00000237651 219 | ENSG00000239389 220 | ENSG00000239521 221 | ENSG00000239665 222 | ENSG00000239704 223 | ENSG00000240344 224 | ENSG00000242686 225 | ENSG00000242779 226 | ENSG00000243646 227 | ENSG00000243789 228 | ENSG00000244045 229 | ENSG00000245958 230 | ENSG00000246067 231 | ENSG00000247572 232 | ENSG00000253982 233 | ENSG00000254635 234 | ENSG00000254685 235 | ENSG00000256269 236 | ENSG00000259865 237 | ENSG00000259969 238 | ENSG00000259994 239 | 
ENSG00000260077 240 | ENSG00000260563 241 | ENSG00000260708 242 | ENSG00000261126 243 | ENSG00000261338 244 | ENSG00000261455 245 | ENSG00000261556 246 | ENSG00000268575 247 | ENSG00000270021 248 | ENSG00000272831 249 | ENSG00000273253 250 | ENSG00000274253 251 | ENSG00000274627 252 | ENSG00000276791 253 | ENSG00000278730 254 | ENSG00000279233 255 | ENSG00000279267 256 | ENSG00000282033 257 | -------------------------------------------------------------------------------- /Data/README.md: -------------------------------------------------------------------------------- 1 | This directory contains intermediate data files that were generated with the scripts in enformer_analysis/ and can be used to reproduce intermediate results. 2 | 3 | # ALLgenes_ism_attributions_driversfw_refseq_winsize13.npz 4 | - Contains one-hot-encodings of 13bp around the driving SNVs 5 | - The base in the center is the base that is present in the reference sequence 6 | 7 | # ALLgenes_ism_attributions_driversfw_varseq_winsize13.npz 8 | - Contains one-hot-encodings of 13bp around the driving SNVs 9 | - The base in the center is the variant instead of the reference 10 | 11 | # ALLgenes_ism_attributions_driversfwmain_ism_significance_stats.txt 12 | - Statistics from comparing the ISM values around the main driver of each gene with the maximum ISM that we observed within 2000bp around the TSS 13 | - Almost identical to SupplementaryTable3.txt 14 | 15 | # ALLgenes_ism_attributions_driversfwmain_refseq_winsize13.npz 16 | - Contains one-hot-encodings of 13bp around the main driving SNVs 17 | - The base in the center is the base that is present in the reference sequence 18 | 19 | # ALLgenes_ism_attributions_driversfwmain_varseq_winsize13.npz 20 | - Contains one-hot-encodings of 13bp around the main driving SNVs 21 | - The base in the center is the variant instead of the reference 22 | 23 | # Enformer_predictions.txt.gz 24 | - Contains the predicted expression values for all expressed 
genes for 839 individual genotypes from the combination of all human output tracks with the fine-tuned elastic net model. 25 | 26 | # MeanGeXPredFineTuned.txt 27 | - Contains the predicted expression values for all expressed genes from their reference sequence from the combination of all human output tracks with the fine-tuned elastic net model. 28 | - Contains the mean observed expression value 29 | 30 | # Observed_gene_expression.txt.gz 31 | - Contains the measured expression values for all expressed genes for 839 individual genotypes. 32 | 33 | # PrediXcanCorrelationWithExpressionENSG.tsv 34 | - Contains the Pearson correlation coefficient between the observed gene expression values of 839 individuals and the predicted values from PrediXcan. 35 | 36 | # Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list 37 | - Selected gene set to analyze with ISM 38 | - Enformer correlations are significantly better than random models and absolute correlation is above 0.2 39 | 40 | # Prediction_correlationsCageAdultBrain_Allstats.txt 41 | - Contains Pearson correlation coefficients between the observed gene expression values of 839 individuals and the predicted values with Enformer's human CAGE,brain,adult track 42 | 43 | # SupplementaryTable1.tsv 44 | - Contains combined statistics from different files for all expressed genes 45 | 46 | # SupplementaryTable2.tsv 47 | - Contains statistics for all detected driver SNVs 48 | 49 | # SupplementaryTable3.txt 50 | - Contains statistics about the ISM values around all main drivers 51 | 52 | # enformer_test_set_genes.npy 53 | - Gene set that was not used for training Enformer 54 | - See Method section in paper for details 55 | 56 | # gene-ids-and-positions.tsv 57 | - Gene ids, gene names, location in hg38, location in hg19, and strand information 58 | 59 | # geneTSS.txt 60 | - Transcription start sites 61 | 62 | # ism_res.tar 63 | - Contains directory ism_res/ which contains CAGE,brain,adult 
track prediction from reference sequence, and reference sequence with individually inserted main variants 64 | 65 | # maindrivernpz.tar 66 | - Contains .npz files with ISM values for 41 bp around the main drivers. 67 | 68 | # snp_positions.tar 69 | - Contains the loci of all the SNVs that were within the window of Enformer inputs. 70 | - Loci match the variant predictions in ism_res.tar 71 | 72 | # tss1000bpnpz.tar 73 | - Contains .npz files with ISM values for +-1000 bp around the TSS 74 | 75 | # tss_attribution_stats.txt 76 | - Contains statistics from all files in tss1000bpnpz.tar, such as the max absolute ISM, and the standard deviation of ISMs within +-1000 bp around the TSS 77 | 78 | # variant_info_100k.tar 79 | **Personal genotype data cannot be shared with a third party unless approved by the RADC.** All requests to the investigator from the third party must be directed to the 80 | RADC in order to assure knowledge of and compliance with the rules and regulations (https://www.radc.rush.edu/) 81 | - Contains snp_info.npz files 82 | - Files contain a matrix representing the genotype of all individuals, i.e., rows are SNVs within the input window and columns represent individuals. 
83 | - Individuals with two copies of the major allele (genotype 0), those with one copy of the major allele (genotype 1) and those with two copies of the minor allele (genotype 2) 84 | 85 | -------------------------------------------------------------------------------- /Data/enformer_test_set_genes.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostafavilabuw/EnformerAssessment/afac6658d57bfe6e2c1f17d9a96e3593843000a2/Data/enformer_test_set_genes.npy -------------------------------------------------------------------------------- /Data/ism_res.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostafavilabuw/EnformerAssessment/afac6658d57bfe6e2c1f17d9a96e3593843000a2/Data/ism_res.tar -------------------------------------------------------------------------------- /Data/maindrivernpz.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostafavilabuw/EnformerAssessment/afac6658d57bfe6e2c1f17d9a96e3593843000a2/Data/maindrivernpz.tar -------------------------------------------------------------------------------- /Data/snp_positions.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostafavilabuw/EnformerAssessment/afac6658d57bfe6e2c1f17d9a96e3593843000a2/Data/snp_positions.tar -------------------------------------------------------------------------------- /Data/tss1000bpnpz.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostafavilabuw/EnformerAssessment/afac6658d57bfe6e2c1f17d9a96e3593843000a2/Data/tss1000bpnpz.tar -------------------------------------------------------------------------------- /Fig1.sh: -------------------------------------------------------------------------------- 1 | # 1C: Plot mean expression of 
all individuals against predicted expression with enformer from reference sequence 2 | python3 scatter_mean_prediction.py 3 | 4 | # 1D: DDX11 scatter plot 5 | python3 plot_individual_gene_scatter.py Observed_gene_expression.txt Enformer_predictions.txt DDX11 --figsize 4 3 6 | 7 | # List of all genes with npy files to get attributions 8 | gene=$(ls *_ref_pred.npy) 9 | # Compute attributions all genes 10 | for g in $gene 11 | do 12 | echo ${g%_ref_pred.npy} 13 | python3 compute_attribution.py ${g%_ref_pred.npy} 14 | done 15 | 16 | gene=$(ls *snp_info.npz) 17 | # Compute the population frequency for every SNV 18 | for g in $gene 19 | do 20 | python3 ../compute_population_frequency.py ${g%snp_info.txt} 21 | done 22 | 23 | # 1E: DDX11 ISM attribution plot 24 | python3 ../attribution_plot.py ENSG00000013573_ism_attributions.txt ISM --colors ../variant_info_100k/ENSG00000013573_frequency.txt --tss ../geneTSS.txt ENSG00000013573 25 | 26 | 27 | -------------------------------------------------------------------------------- /Fig2.sh: -------------------------------------------------------------------------------- 1 | # Replace gene names with ENSG IDs 2 | python3 replace_genename.py Observed_gene_expression.txt 3 | python3 replace_genename.py Enformer_predictions.txt 4 | 5 | # compute correlation between observed and predicted expression across 839 individuals 6 | #python3 compute_correlation.py Observed_gene_expressionENSG.tsv ../Enformer_predictions_CageAdultBrain.txt > Prediction_correlationsCageAdultBrain.txt 7 | #python3 compute_correlation.py Observed_gene_expressionENSG.tsv ../Enformer_predictions_CageAdultBrain.txt --get_allstats > Prediction_correlationsCageAdultBrain_Allstats.txt 8 | 9 | # compute gene specific random null that is dependent on SNV structure 10 | python3 generate_null.py 11 | 12 | # compute gene specific p-values for the predictions 13 | python3 compute_tstatpvalue.py ../Prediction_correlationsCageAdultBrain.txt 
GeneSpecific_CorrelationtoObsRandomNull.txt 14 | 15 | # 2A: Plot p-value versus correlation of predicted expression with enformer from individual sequences 16 | python3 scatter_pvalue_vs_correlationprediction.py Prediction_correlationsCageAdultBrain_Allstats.txt Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats.txt --markersize 4 --printset 0 17 | 18 | # 2B: C2orf74 scatter plot 19 | python3 plot_individual_gene_scatter.py Full_analysis/Observed_gene_expressionENSG.tsv Enformer_predictions_CageAdultBrain.txt ENSG00000237651 20 | python3 plot_individual_gene_scatter.py Full_analysis/Observed_gene_expressionENSG.tsv Enformer_predictions_CageAdultBrain.txt ENSG00000134202 --setylim 0.057,0.071 21 | python3 plot_individual_gene_scatter.py Full_analysis/Observed_gene_expressionENSG.tsv Enformer_predictions_CageAdultBrain.txt ENSG00000128944 22 | python3 plot_individual_gene_scatter.py Full_analysis/Observed_gene_expressionENSG.tsv Enformer_predictions_CageAdultBrain.txt ENSG00000133433 23 | python3 plot_individual_gene_scatter.py Full_analysis/Observed_gene_expressionENSG.tsv Enformer_predictions_CageAdultBrain.txt ENSG00000226752 24 | python3 plot_individual_gene_scatter.py Full_analysis/Observed_gene_expressionENSG.tsv Enformer_predictions_CageAdultBrain.txt ENSG00000120675 25 | 26 | # 2C: Scatter plot comparison between PrediXcan and Enformer correlations to observed individual expression 27 | python3 scatter_correlations.py --colors Full_analysis/Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats.txt --colorscut 1.301 --vlim -1.5,1 28 | 29 | # compute eqtl 30 | cd eqtl/ 31 | python3 -W ignore compute_eqtl.py ../Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list 32 | cd .. 
33 | 34 | cd ism_res/ 35 | # determine drivers with forward method 36 | genes=$(ls *_ism_attributions.txt) 37 | for g in genes 38 | do 39 | python3 ../select_drivers.py $g ../variant_info_100k/${g%_ism_attributions.txt}snp_info.txt ${g%_ism_attributions.txt} --forward 40 | done 41 | 42 | # 2D: Eqtl versus ISM plot for GSTM3 43 | python3 ../eqtl_attribution_plot.py ENSG00000134202_ism_attributions.txt ../eqtl/ENSG00000134202_eqtl.txt ISM eQTL --colors ../variant_info_100k/ENSG00000134202_frequency.txt --drivers ENSG00000134202_ism_attributions_driversfw.txt --dpi 450 --fmt '.svg' --minfreq 0.01 44 | 45 | # determine driver types 46 | genes=$(ls *_ism_attributions.txt) 47 | for g in genes 48 | do 49 | python3 ../eqtl_types.py $g ../eqtl/${g%_ism_attributions.txt}_eqtl.txt ${g%_ism_attributions.txt}_ism_attributions_driversfw.txt 50 | done 51 | 52 | # combine driver types 53 | python3 ../combine_eqtl_types.py _ism_attributions_driversfw_types.txt > ALLgenes_ism_attributions_driversfw_types.txt 54 | 55 | # 2E: Enrichment of drivertypes for positive and negative correlated genes 56 | python3 ../plot_drivertype.py ALLgenes_ism_attributions_driversfw_types.txt ../Prediction_correlationsCageAdultBrain.txt ALL_genes_ism_attributions_sumpersonal_mp_vs_Enformer_predictions_CageAdultBrain.txt --genelist ../Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list -weighted ALLgenes_ism_attributions_driversfw.txt -3 57 | 58 | # Compute driver distance to TSS 59 | python3 ../driver_distance_to_tss.py _ism_attributions_driversfw.txt ../geneTSS.txt > DistancetoTSS_ism_attributions_driversfw.txt 60 | 61 | # 2F: Location of drivers with main drivers in different colors 62 | python3 ../plot_driverdistance.py DistancetoTSS_ism_attributions_driversbw.txt ../Prediction_correlationsCageAdultBrain.txt ALL_genes_ism_attributions_sumpersonal_mp_vs_Enformer_predictions_CageAdultBrain.txt --genelist 
../Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list --plot_main_in_all 63 | 64 | # 2G: Gradient plot for GSTM3 with drivers and other variants 65 | python3 ../plot_attribution_alongsequence.py ENSG00000134202 0 ../ism_res/ENSG00000134202_ism_attributions.txt ../variant_info_100k/ENSG00000134202_frequency.txt ../ism_res/ENSG00000134202_ism_attributions_driversfw.txt ../ism_res/ENSG00000134202_ism_attributions_driversfw_types.txt --savefig 66 | 67 | # Generate ISM figure around the TSS 68 | python3 ../plot_refandvar_attribution.py ENSG00000134202_109741038_attribution.npz --squaresize 0.12 --include_snvs 109741038 ../ism_res/ENSG00000134202_ism_attributions.txt ../eqtl/ENSG00000134202_corr.txt --markdrivers ../ism_res/ENSG00000134202_ism_attributions_driversfw.txt --dpi 350 --include_conservation ../PhyloP100/ENSG00000134202_1000tss_in_hg38.phyloP100way.txt --excludeheatmap --enlargepwm 1.8 69 | python3 ../plot_refandvar_attribution.py $g --figsize 75,1 --include_snvs $loc ../ism_res/${gene}_ism_attributions.txt ../eqtl/${gene}_corr.txt --markdrivers ../ism_res/${gene}_ism_attributions_driversfw.txt --dpi 250 --include_conservation ../PhyloP100/${gene}_1000tss_in_hg38.phyloP100way.txt --excludeheatmap --enlargepwm 1.2 70 | 71 | -------------------------------------------------------------------------------- /FigS.sh: -------------------------------------------------------------------------------- 1 | # Files and scripts to generate the supplementary figures 2 | 3 | # S1. Prediction of mean expression for the test set 4 | python3 scatter_mean_prediction.py --testset enformer_test_set_genes.npy 5 | 6 | #S2. 
Distribution of number of variants per person for DDX11 7 | python3 ../plot_snp_clustering.py ENSG00000013573snp_info.npz --select_individuals ../Enformer_predictions_individuals.txt --minfreq 0.1 --minsnps 10 --combine_snps 0.9 --combine_individuals 0.9 --savefig ENSG00000013573_individual_enformer_snp 8 | 9 | #S4 Scatter plot between correlations to observed for fine-tuned and CAGE-track 10 | python3 scatterplot.py Prediction_correlationsCageAdultBrain_Allstats.txt ism_res/Refpred.txt "CAGE,brain,adult,MeanSumlog10+-2indv" "CAGE,brain,adult,log10sum+-1ref" --columns -2 -1 --density --alpha 0.5 --label --filternan --linewidth 0.1 --log10y 11 | 12 | #NOT included: Mean of Predicted CAGE tracks and sum of Observed, both are logged 13 | python3 scatterplot.py Prediction_correlationsCageAdultBrain_Allstats.txt Prediction_correlationsCageAdultBrain_Allstats.txt "MeanObs" "CAGE,brain,adult,MeanSumlog10+-2indv" --columns -4 -2 --density --alpha 0.5 --label --filternan --linewidth 0.1 14 | python3 scatterplot.py Prediction_correlationsCageAdultBrain_Allstats.txt Prediction_correlationsCageAdultBrain_Allstats.txt "StdObs" "CAGE,brain,adult,StdSumlog10+-2indv" --columns -3 -1 --density --alpha 0.5 --label --filternan --linewidth 0.1 15 | python3 scatterplot.py Prediction_correlationsCageAdultBrain_Allstats.txt Prediction_correlationsCageAdultBrain_Allstats.txt "MeanObs" "StdObs" --columns -4 -3 --density --alpha 0.5 --label --filternan --linewidth 0.1 16 | python3 scatterplot.py Prediction_correlationsCageAdultBrain_Allstats.txt Prediction_correlationsCageAdultBrain_Allstats.txt "MeanCAGE,adult,brain" "StdCAGE,adult,brain" --columns -2 -1 --density --alpha 0.5 --label --filternan --linewidth 0.1 17 | 18 | #S6 Distribution of Susie genes 19 | python3 plot_distribution_enformer_correlations.py Full_analysis/Prediction_correlationsCageAdultBrain.txt susie_SNP_gene_CortexENSG.txt 20 | 21 | # S7. 
Correlation between ISM and gradient attributions between variants of all investigated genes 22 | python3 correlation_ism_grad_attributions.py 23 | 24 | # compute the sum of snp attributions 25 | cd ism_res/ 26 | genes=$(ls ENSG*_ism_attributions.txt) 27 | for g in $genes 28 | do 29 | python3 ../sum_meanattributions.py $g ../variant_info_100k/${g%_ism_attributions.txt}snp_info.npz 30 | done 31 | # combine the sum of predictions 32 | python3 ../combine_predictions.py ENSG _ism_attributions_sumpersonal_mp.txt 33 | # compute correlation to Enformers predictions 34 | python3 ../compute_correlation.py ALL_genes_ism_attributions_sumpersonal_mp.txt ../Enformer_predictions_CageAdultBrain.txt > ALL_genes_ism_attributions_sumpersonal_mp_vs_Enformer_predictions_CageAdultBrain.txt 35 | cd .. 36 | 37 | # compute sum of gradient attributions 38 | cd ref_attribs/ 39 | genes=$(ls *_grad_attributions.txt) 40 | for g in $genes 41 | do 42 | time python3 ../sum_meanattributions.py $g ../variant_info_100k/${g%_grad_attributions.txt}snp_info.npz 43 | done 44 | # combine the sum of predictions 45 | python3 ../combine_predictions.py ENSG _grad_attributions_sumpersonal_mp.txt 46 | # compute correlation to Enformers predictions 47 | python3 ../compute_correlation.py ALL_genes_grad_attributions_sumpersonal_mp.txt ../Enformer_predictions_CageAdultBrain.txt > ALL_genes_grad_attributions_sumpersonal_mp_vs_Enformer_predictions_CageAdultBrain.txt 48 | cd .. 49 | 50 | # S8. 
Correlation between ISM and gradient sums for every individual with predicted values of enformer for every individual 51 | python3 plot_correlations_predictions_tosumattributions.py Prediction_correlationsCageAdultBrain.txt ism_res/ALL_genes_ism_attributions_sumpersonal_mp_vs_Enformer_predictions_CageAdultBrain.txt ref_attribs/ALL_genes_grad_attributions_sumpersonal_mp_vs_Enformer_predictions_CageAdultBrain.txt --savefig 52 | 53 | # Not shown anymore: Intermediate plots for forward identification of drivers 54 | # GSTM3 55 | python3 ../select_drivers.py ENSG00000134202_ism_attributions.txt ../variant_info_100k/ENSG00000134202snp_info.txt ENSG00000134202 --forward --plot_test 40 56 | # DDX11 57 | python3 ../select_drivers.py ENSG00000013573_ism_attributions.txt ../variant_info_100k/ENSG00000013573snp_info.txt ENSG0000013573 --forward --plot_test 40 58 | 59 | # combine drivers 60 | python3 ../combine_drivers.py _ism_attributions_driversfw.txt > ALLgenes_ism_attributions_driversfw.txt 61 | 62 | # S9. 
Plot number of SNP drivers 63 | python3 ../count_drivers.py _ism_attributions_driversfw.txt > Counts_ism_attributions_driversfw.txt 64 | python3 ../count_drivers.py _ism_attributions_driversbw.txt > Counts_ism_attributions_driversbw.txt 65 | 66 | python3 ../plot_drivercounts.py Counts_ism_attributions_driversfw.txt ../Prediction_correlationsCageAdultBrain.txt ALL_genes_ism_attributions_sumpersonal_mp_vs_Enformer_predictions_CageAdultBrain.txt --genelist ../Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list --split_sets 67 | 68 | # S10 sequence patterns at the driver snps 69 | 70 | #Make list with main drivers 71 | python3 ../select_main_drivers.py ALLgenes_ism_attributions_driversfw.txt > ALLgenes_ism_attributions_driversfwmain.txt 72 | # make type file with main drivers 73 | python3 ../select_driverset.py ALLgenes_ism_attributions_driversfwmain.txt ALLgenes_ism_attributions_driversfw_types.txt ALLgenes_ism_attributions_driversfwmain_types.txt 74 | 75 | # extract sequence windows of size 13 around drivers 76 | python3 ../extract_sequence_around_drivers.py ALLgenes_ism_attributions_driversfw.list 6 ../geneTSS.txt 77 | python3 ../extract_sequence_around_drivers.py ALLgenes_ism_attributions_driversfwmain.list 6 ../geneTSS.txt 78 | # cluster 13-mers for ref seq 79 | python3 ../find_common_motifs.py ALLgenes_ism_attributions_driversfw_refseq_winsize13.npz complete 0.25 5 1 --geneset ../Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list 80 | python3 ../find_common_motifs.py ALLgenes_ism_attributions_driversfwmain_refseq_winsize13.npz complete 0.25 5 1 --geneset ../Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list 81 | # check enrichment of specific types of clusters 82 | python3 ../plot_common_motifs.py ALLgenes_ism_attributions_driversfw_types.txt ALLgenes_ism_attributions_driversfw_refseq_winsize13_clust_ms5-1_complete0.25 --savefig 
ALLgenes_ism_attributions_driversfw_refseq_winsize13_clust_ms5-1_complete0.2 83 | # check enrichment only for main drivers 84 | python3 ../plot_common_motifs.py ALLgenes_ism_attributions_driversfwmain_types.txt ALLgenes_ism_attributions_driversfwmain_refseq_winsize13_clust_ms5-1_complete0.25 --savefig ALLgenes_ism_attributions_driversfwmain_refseq_winsize13_clust_ms5-1_complete0.25 85 | 86 | # cluster 13mers for varseq 87 | python3 ../find_common_motifs.py ALLgenes_ism_attributions_driversfw_varseq_winsize13.npz complete 0.25 5 1 --geneset ../Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list 88 | python3 ../find_common_motifs.py ALLgenes_ism_attributions_driversfwmain_varseq_winsize13.npz complete 0.25 5 1 --geneset ../Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list 89 | # check enrichment of specific types of clusters 90 | python3 ../plot_common_motifs.py ALLgenes_ism_attributions_driversfw_types.txt ALLgenes_ism_attributions_driversfw_varseq_winsize13_clust_ms5-1_complete0.25 --savefig ALLgenes_ism_attributions_driversfw_varseq_winsize13_clust_ms5-1_complete0.25 91 | # check enrichment only for main drivers 92 | python3 ../plot_common_motifs.py ALLgenes_ism_attributions_driversfwmain_types.txt ALLgenes_ism_attributions_driversfwmain_varseq_winsize13_clust_ms5-1_complete0.25 --savefig ALLgenes_ism_attributions_driversfwmain_varseq_winsize13_clust_ms5-1_complete0.25 93 | 94 | # S11. Full size figures of driver distance from tss 95 | # see Fig2.sh 96 | 97 | # S12. Clusters of gradient attribution locations 98 | python3 ../cluster_grad_attributions.py ../Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list ../Prediction_correlationsCageAdultBrain.txt 99 | 100 | #Not included. 
Causal SNP enrichment in positive genes 101 | #python3 susie_SNP_gene_analysis.py susie_SNP_gene_brain_union_b.txt --genelist Full_analysis/Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.txt 102 | ## potentially 103 | #python3 susie_SNP_gene_analysis.py susie_SNP_gene_CortexENSG.txt --genelist Full_analysis/Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.txt 104 | 105 | 106 | # S13. Analysis of motifs and attribution signals within and around drivers 107 | # Analyze signal at drivers from ism 108 | python3 extract_ism_stats_around_drivers.py ../ism_res/ALLgenes_ism_attributions_driversfwmain_types.txt --checkmax --geneset Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list 109 | 110 | python3 ../plot_driver_motifstats.py ../ism_res/ALLgenes_ism_attributions_driversfwmain_types.txt ALLgenes_ism_attributions_driversfwmain_ism_significance_stats.txt 2,15 --absolute --combine_columns max --scale 100 --lim 100 --cumulative -1 --nbins 22 --savefig ISMstats256_mainmaxzscore_in_var_or_ref.jpg 111 | 112 | python3 ../plot_driver_motifstats.py ../ism_res/ALLgenes_ism_attributions_driversfwmain_types.txt ALLgenes_ism_attributions_driversfwmain_ism_significance_stats.txt 4,17 --combine_columns max --lim 10 --cumulative -1 --nbins 11 --savefig ISMstats256_mainmotifssize_in_var_or_ref_10percofmax.jpg --xlim 0,10 113 | 114 | python3 ../plot_driver_motifstats.py ../ism_res/ALLgenes_ism_attributions_driversfwmain_types.txt ALLgenes_ism_attributions_driversfwmain_ism_significance_stats.txt 5,18 --combine_columns max --lim 10 --cumulative -1 --nbins 11 --savefig ISMstats256_mainmotifssize_in_var_or_ref_20percofmax.jpg --xlim 0,10 115 | 116 | python3 ../plot_driver_motifstats.py ../ism_res/ALLgenes_ism_attributions_driversfwmain_types.txt ALLgenes_ism_attributions_driversfwmain_ism_significance_stats.txt 6,19 --combine_columns max --lim 10 --cumulative -1 
To reproduce our analysis [[2]](#2), please follow the exemplary commands in Fig1.sh, Fig2.sh, and FigS.sh. These bash scripts were not intended to be run by themselves.
it is **essential to obtain access to personal genotype data from the Rush Memory and Aging Project (ROSMAP)**
The Eagle software (version 2.4.1) was used to phase the genotypes with the default settings.
# Get ISM results for Enformer on a given gene set
# For this analysis, get Enformer output for inserting each nucleotide at every position within 1000bp of the TSS

import time
import pandas as pd
import numpy as np
import sparse
import os
import tensorflow as tf
import tensorflow_hub as hub
import time as time
import argparse

track_idx = 4980 # CAGE:brain, adult,
center_bins = [447,448,449] # three center bins of enformer output

save_path = '/data/aspiro17/enformer_res/'
model_path = 'https://tfhub.dev/deepmind/enformer/1'

padded_input_len = 393216 # full input to enformer
input_len = 196608
starting_seq_len = 200001 # starting length of reference seq
pad_before = int((padded_input_len - input_len)/2) # "pad" sequence to padded_input_len

# indices that crop the central input_len window out of the starting sequence
mid_index = int((starting_seq_len-1)/2)
start_index = int(mid_index - (input_len-1)/2)
end_index = int(mid_index + (input_len-1)/2) + 1

nuc_order =['A', 'C', 'G', 'T']
nuc_onehots = [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]

gene_win_info = pd.read_csv('/data/mostafavilab/bng/rosmapAD/data/expressionData/DLPFC/20220207-bulk-RNAseq/gene-ids-and-positions.tsv', sep = '\t', index_col = 1) # contains TSS position info

win = 1000 # window around TSS to get results for

# enformer helper functs
class Enformer:
    """Thin wrapper around the TF-Hub Enformer SavedModel."""

    def __init__(self, tfhub_url):
        # load the SavedModel once; all predictions go through ._model
        self._model = hub.load(tfhub_url).model

    def predict_on_batch(self, inputs):
        """Run the model on one batch and return numpy arrays per output head."""
        predictions = self._model.predict_on_batch(inputs)
        return {k: v.numpy() for k, v in predictions.items()}

def get_tss_ism(gene_id: str):
    """In-silico mutagenesis around the TSS of one gene.

    For every position within +/- `win` bp of the gene's TSS, substitute each of
    the four nucleotides into the reference sequence, run Enformer, and record
    the summed CAGE adult-brain signal (track_idx) over the three center bins.
    Saves the reference prediction (<gene_id>_ref_pred.npy) and a
    (2*win+1, 4) result array (<gene_id>_<win>bp.npy) under
    attrib_res/TSS_win_ISM_res/.
    """

    gene_res = np.zeros([win*2+1,4])

    # the model is (re)loaded per gene; snp_info is read but never used below
    # NOTE(review): candidate for removal / hoisting the model load out of the loop
    model = Enformer(model_path)
    snp_info = pd.read_csv('/data/aspiro17/enformer_res/variant_info_100k/' + gene_id + '.csv',header=None,encoding='latin1')
    current_gene = np.load(save_path+'ref_seqs/'+gene_id+'.npy')

    starting_seq_len = np.shape(current_gene)[1]
    current_tss = int(gene_win_info['tss_hg38'][gene_id])
    attrib_start_pos = current_tss - int(input_len/2) + 1

    # adjust sequence to be input to model
    current_gene = np.transpose(current_gene, (1,0)) # transpose to be seq_len x 4
    current_gene = current_gene[:, [0,3,2,1]] # go from previous: seqP=='A',seqP=='T',seqP=='G',seqP=='C', to Enformer: 'ACGT'
    current_gene = np.reshape(current_gene, (1, starting_seq_len, 4)) # add a 1 dimen
    current_gene = np.pad(current_gene, pad_width=((0,0),(pad_before-start_index, pad_before-(starting_seq_len - end_index)), (0,0))) # pad seq

    # prediction for the unmodified reference sequence
    ref_out = model.predict_on_batch(current_gene)
    ref_out = ref_out['human'][0]
    ref_val = np.sum(ref_out[center_bins,track_idx]) # 448 is center bin

    np.save(save_path + 'attrib_res/TSS_win_ISM_res/' +gene_id + '_ref_pred', ref_val)

    # position of the TSS within the padded array, then left edge of the window
    center_idx = (current_tss - attrib_start_pos) + pad_before # where in this array
    window_start_pos = center_idx - win

    for i in range(win*2+1):

        # fill in for ref: the reference nucleotide's column gets ref_val
        current_onehot = list(current_gene[:,window_start_pos+i,:][0])
        current_nuc_idx = nuc_onehots.index(current_onehot)
        gene_res[i,current_nuc_idx] = ref_val # this nuc represents no change from the ref
        rel_nuc_idxs = list(range(4))
        rel_nuc_idxs.remove(current_nuc_idx) # don't have to re run for ref nuc

        for idx in rel_nuc_idxs: # the other 3 nucs
            inserted_seq = current_gene.copy()
            current_onehot = nuc_onehots[idx]
            inserted_seq[:,window_start_pos+i,:] = current_onehot # get current seq

            var_out = model.predict_on_batch(inserted_seq)
            var_out = var_out['human'][0]
            var_val = np.sum(var_out[center_bins,track_idx])

            gene_res[i,idx] = var_val

    np.save(save_path + 'attrib_res/TSS_win_ISM_res/' + gene_id + '_' + str(win) + 'bp', gene_res)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('which_gpu', action="store", default='0')
    parser.add_argument('gene_file', action="store", default='test')

    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"]=args.which_gpu
    genes = np.load(args.gene_file)

    for gene_id in genes:
        print(gene_id)
        get_tss_ism(gene_id)
figsize = (15,2) # None 30 | fig = plt.figure(figsize = figsize) 31 | ax = fig.add_subplot(111) 32 | ax.spines['top'].set_visible(False) 33 | ax.spines['right'].set_visible(False) 34 | sort = np.argsort(colors) 35 | a = ax.scatter(attributions[sort,0], attributions[sort,1], cmap = 'Blues', vmin = vmin, vmax = vmax, c = colors[sort], edgecolor = 'grey') 36 | fig.colorbar(a, pad = 0.01, fraction = 0.09, shrink = 0.25, aspect = 2, anchor = (0.,0.9)) 37 | 38 | # Show markers with increased size and red edgecolor 39 | if '--drivers' in sys.argv: 40 | dobj = open(sys.argv[sys.argv.index('--drivers')+1], 'r').readlines() 41 | dfile = [line.strip().split() for line in dobj] 42 | dfile = np.array(dfile,dtype =float) 43 | dloc = np.where(np.isin(attributions[:,0], dfile[:,0])) 44 | dloc = dloc[0] 45 | size0 = plt.rcParams['lines.markersize'] ** 2 46 | ax.scatter(attributions[dloc,0], attributions[dloc,1], s = size0 * (1+2*dfile[:,-3]), linewidths = 1.2, cmap = 'Blues', vmin = 0, vmax =1, c = colors[dloc], edgecolor = 'red') 47 | for t in dloc: 48 | print(attributions[t,0]) 49 | ax.text(attributions[t,0], attributions[t,1], str(int(attributions[t,0])), ha = 'left', va = 'bottom') 50 | 51 | else: 52 | dloc = np.argsort(-np.absolute(attributions[:,1]))[:6] 53 | # Add the name of the top attributions to the figure 54 | if '--name_top_attributions' in sys.argv: 55 | for t in dloc: 56 | print(attributions[t,0]) 57 | ax.text(attributions[t,0], attributions[t,1], str(int(attributions[t,0])), ha = 'left') 58 | 59 | # Mark the location of the TSS in the figure 60 | if '--tss' in sys.argv: 61 | tss = np.genfromtxt(sys.argv[sys.argv.index('--tss')+1], dtype = str) 62 | tss = int(tss[list(tss[:,0]).index(sys.argv[sys.argv.index('--tss')+2]),1]) 63 | ax.plot([tss, tss],[np.amin(attributions[:,1]), np.amax(attributions[:,1])], ls = '--', color = 'goldenrod', alpha = 0.8) 64 | ax.set_title(sys.argv[sys.argv.index('--tss')+2]) 65 | xmin, xmax = np.amin(attributions[:,0]), 
# Predict gene expression for each individual's paternal and maternal haplotype
# with Enformer and save a (num_subjects, 2) array of predictions per gene
# (column 0 = paternal, column 1 = maternal).

import argparse
import os
import time as time

import numpy as np
import pandas as pd
import sparse
import tensorflow as tf
import tensorflow_hub as hub

save_path = '/data/aspiro17/enformer_res/'
model_path = 'https://tfhub.dev/deepmind/enformer/1'
data_path = '/data/mostafavilab/bng/rosmapAD/projects/insilicoMutagenesis/extractSequence/results/sequence100K/'

track_idx = 4980 # CAGE:brain, adult,
center_bins = [447,448,449] # three center bins of enformer output

padded_input_len = 393216 # full input to enformer
input_len = 196608
starting_seq_len = 200001 # starting length of reference seq
pad_before = int((padded_input_len - input_len)/2) # "pad" sequence to padded_input_len

mid_index = int((starting_seq_len-1)/2)
start_index = int(mid_index - (input_len-1)/2)
end_index = int(mid_index + (input_len-1)/2) + 1


# enformer helper functs
class Enformer:
    """Thin wrapper around the TF-Hub Enformer SavedModel."""

    def __init__(self, tfhub_url):
        self._model = hub.load(tfhub_url).model

    def predict_on_batch(self, inputs):
        """Run the model on one batch and return numpy arrays per output head."""
        predictions = self._model.predict_on_batch(inputs)
        return {k: v.numpy() for k, v in predictions.items()}


# other helper functs
def pad_genes(pat_current_gene, mat_current_gene, shift=0):
    """Pad both haplotype arrays to padded_input_len, shifting the window by `shift` bp."""
    padded_pat_current_gene = np.pad(pat_current_gene, pad_width=((0,0),(pad_before-start_index+shift, pad_before-(starting_seq_len - end_index)-shift), (0,0)))
    padded_mat_current_gene = np.pad(mat_current_gene, pad_width=((0,0),(pad_before-start_index+shift, pad_before-(starting_seq_len - end_index)-shift), (0,0)))

    return padded_pat_current_gene, padded_mat_current_gene


def get_pred(gene_id, curr_chr):
    """Predict expression for every subject's two haplotypes of one gene.

    Loads the sparse (8, seq_len, num_subj) one-hot tensor for `gene_id` on
    chromosome `curr_chr`, runs Enformer on each haplotype, and saves a
    (num_subj, 2) array of summed center-bin CAGE adult-brain signals.
    Uses the module-level `model` defined in __main__.
    """
    curr_gene_path = data_path + 'chr' + str(curr_chr) + '/' + gene_id + '.npz'
    current_sparse = sparse.load_npz(curr_gene_path)
    current_gene = current_sparse.todense()

    # first four channels = paternal haplotype, last four = maternal
    pat_current_gene = current_gene[:4,:]
    mat_current_gene = current_gene[4:,:]

    # transpose to be num_sub x seq_len x 4
    pat_current_gene = np.transpose(pat_current_gene, (2,1,0))
    mat_current_gene = np.transpose(mat_current_gene, (2,1,0))

    # fix: num_subj was never defined in the original; derive it from the data
    num_subj = np.shape(pat_current_gene)[0]
    # init results: column 0 = paternal prediction, column 1 = maternal
    gene_res = np.zeros([num_subj,2],dtype=float)

    # go from saved order seqP=='A','T','G','C' to Enformer order 'ACGT'
    # (fix: the paternal reordering statement had been fused into a comment,
    # so the paternal haplotype was never converted)
    pat_current_gene = pat_current_gene[:, :, [0,3,2,1]]
    mat_current_gene = mat_current_gene[:, :, [0,3,2,1]]

    padded_pat_current_gene, padded_mat_current_gene = pad_genes(pat_current_gene, mat_current_gene, shift=0)

    for sub in range(num_subj):

        pat_single_sub = np.reshape(padded_pat_current_gene[sub,:,:], (1, padded_input_len, 4))
        mat_single_sub = np.reshape(padded_mat_current_gene[sub,:,:], (1, padded_input_len, 4))

        pat_out = model.predict_on_batch(pat_single_sub)['human'][0]
        mat_out = model.predict_on_batch(mat_single_sub)['human'][0]

        # sum the CAGE adult-brain track over the three center bins
        gene_res[sub,0] = np.sum(pat_out[center_bins, track_idx])
        gene_res[sub,1] = np.sum(mat_out[center_bins, track_idx])

    np.save(save_path + gene_id, gene_res)


if __name__ == '__main__':

    # load model (module-level global used by get_pred)
    model = Enformer(model_path)

    # load TSS info
    gene_win_info = pd.read_csv('/data/mostafavilab/bng/rosmapAD/data/expressionData/DLPFC/20220207-bulk-RNAseq/gene-ids-and-positions.tsv', sep = '\t', index_col = 1)

    parser = argparse.ArgumentParser()
    parser.add_argument('which_gpu', action="store", default='0')
    parser.add_argument('gene_file', action="store", default='test')
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.which_gpu
    genes = np.load(args.gene_file)

    for gene in genes:
        print(gene)
        # fix: take the scalar chromosome value; the original passed a pandas
        # Series, so str(curr_chr) produced a broken file path in get_pred
        chrom = gene_win_info[gene_win_info['gene_id'] == gene]['chr_hg38'].iloc[0]
        get_pred(gene, chrom)
cluster them on normalized standard deviation 2 | # python3 ../cluster_grad_attributions.py ../Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list ../Prediction_correlationsCageAdultBrain.tx 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import sys, os 7 | from sklearn.cluster import AgglomerativeClustering 8 | 9 | # List of genes 10 | genes = np.genfromtxt(sys.argv[1], dtype = str) 11 | # Correlation of genes with observed across individuals 12 | genecorr = np.genfromtxt(sys.argv[2], dtype = str) 13 | 14 | # sort and select geneset 15 | genes = np.sort(genes) 16 | genecorr = genecorr[np.argsort(genecorr[:,0])[np.isin(np.sort(genecorr[:,0]), genes)]] 17 | genecorr = genecorr[:,1].astype(float) > 0 18 | 19 | nts = list('ACGT') 20 | ntsar = np.array(nts) 21 | window = 128 # window size to compute std 22 | 23 | allstd = [] 24 | allatt = [] 25 | # read in gradient files and one-hot sequence encoding files and compute std and mena 26 | for g, gene in enumerate(genes): 27 | grad = np.load('../gradient_tensors/'+gene+'_complete_grad_at_ref.npy')[0] 28 | genehot = np.load('../ref_seqs/' +gene+'.npy').T 29 | genehot = genehot[:, [0,3,2,1]] 30 | 31 | lengene = np.shape(genehot)[0] 32 | 33 | # determine the center of the gradient and the sequence file 34 | centergrad = int(np.shape(grad)[0]/2) -1 35 | centergene = int(np.shape(genehot)[0]/2)-1 36 | # adjust the size of the gradient to the size of the sequence 37 | offset = centergrad - centergene 38 | grad = grad[offset:offset+lengene] 39 | 40 | attribution = np.sum(grad,axis = 1) 41 | 42 | attribution[np.sum(genehot,axis = 1) > 0] = grad[genehot==1] - (attribution[np.sum(genehot,axis = 1) > 0]-grad[genehot==1])/3 43 | attribution = attribution/np.std(attribution) 44 | allatt.append(attribution) 45 | stds = [] 46 | for i in range(0, lengene-int(window/2), window): 47 | stds.append(np.std(attribution[i:i+window])) 48 | allstd.append(stds) 49 | 50 | allstd = 
np.array(allstd) 51 | allatt = np.array(allatt) 52 | 53 | # Perform clustering and visualization of clusters 54 | minsize = 0 # minsize could be adjusted to only show clusters with at least minsize genes 55 | for nc in [2,5,10,20,40]: # Cluster data into 2,5,10,20,40 clusters with complete linkage 56 | clustering = AgglomerativeClustering(n_clusters = nc, affinity = 'euclidean', linkage = 'complete', distance_threshold= None).fit(allstd) 57 | clusters = clustering.labels_ 58 | unclust, unclustn = np.unique(clusters, return_counts = True) 59 | unclust, unclustn = unclust[np.argsort(unclustn)], np.sort(unclustn) 60 | print(len(unclust), int(np.sum(unclustn>minsize))) 61 | height = 0.8/int(np.sum(unclustn>minsize)) 62 | fig = plt.figure(figsize=(15,1*len(unclust[unclustn>minsize])), dpi = 100) 63 | axs = [] 64 | axbs = [] 65 | meanstds=[] 66 | maxattrib = [] 67 | maxnum = [] 68 | for i,c in enumerate(unclust[unclustn>minsize]): 69 | ax = fig.add_subplot(int(np.sum(unclustn>minsize)),1,i+1) 70 | ax.set_position([0.1, 0.9 - (i+0.95)*height, 0.8, height * 0.9]) 71 | ax.spines['top'].set_visible(False) 72 | ax.spines['right'].set_visible(False) 73 | cmean = np.mean(np.absolute(allatt[clusters == c]), axis = 0) 74 | smean = np.mean(allstd[clusters == c], axis = 0) 75 | maxattrib.append(np.amax(cmean)) 76 | meanstds.append(smean) 77 | ax.fill_between(np.arange(lengene), cmean, color = 'grey', alpha = 0.5) 78 | ax.plot(np.arange(window/2, lengene, window), smean) 79 | ax.plot([centergene, centergene],[np.amin(cmean), np.amax(cmean)], color= 'goldenrod', ls = '--') 80 | ax.plot([centergene-192, centergene-192],[np.amin(cmean), np.amax(cmean)/2], color= 'darkgoldenrod', ls = '--') 81 | ax.plot([centergene+192, centergene+192],[np.amin(cmean), np.amax(cmean)/2], color= 'darkgoldenrod', ls = '--') 82 | if i == int(np.sum(unclustn>minsize))-1: 83 | ax.set_xticks([centergene-20000, centergene, centergene +20000]) 84 | ax.set_xticklabels(['-20,000', 'TSS', '20,000'], rotation = 60) 
85 | else: 86 | ax.set_xticks([centergene-20000, centergene, centergene +20000]) 87 | ax.tick_params(bottom = True, labelbottom = False) 88 | ax.set_xlim([0,lengene]) 89 | axs.append(ax) 90 | axb = fig.add_subplot(int(np.sum(unclustn>minsize)),10,i+1) 91 | axb.set_position([0.91, 0.9 - (i+0.95)*height, 0.03, height * 0.9]) 92 | axb.spines['top'].set_visible(False) 93 | axb.spines['left'].set_visible(False) 94 | axb.bar([0,1],[np.sum(genecorr[clusters == c]), np.sum(~genecorr[clusters == c])], width = 0.6, color = 'navy') 95 | axb.set_xticks([0,1]) 96 | axb.set_xticklabels(['+', '-']) 97 | axb.tick_params(left = False, labelleft = False, right = True, labelright = True) 98 | maxnum.append(max(np.sum(~genecorr[clusters == c]), np.sum(genecorr[clusters ==c]))) 99 | axb.set_yticks([maxnum[-1]]) 100 | axbs.append(axb) 101 | print('ylim', np.median(maxattrib)) 102 | for i, ax in enumerate(axs): 103 | ax.set_ylim([0.05, np.median(maxattrib)]) 104 | axbs[i].set_ylim([0,np.amax(maxnum)]) 105 | fig.savefig('Gradientstd128bp_clusters'+str(nc)+'.jpg', dpi = 200, bbox_inches = 'tight') 106 | print('Gradientstd128bp_clusters'+str(nc)+'.jpg') 107 | meanstds = np.array(meanstds) 108 | for m, mcut in enumerate([2,4,6]): 109 | mask = np.where(np.sum(meanstds>mcut,axis = 0)>0)[0] 110 | print(mask[0]*window-centergene, mask[-1]*window-centergene) 111 | for i, ax in enumerate(axs): 112 | ax.set_xlim([mask[0]*window, mask[-1]*window]) 113 | fig.savefig('Gradientstd128bp_clusters'+str(nc)+'_zoomcut'+str(mcut)+'.jpg', dpi = 200, bbox_inches = 'tight') 114 | print('Gradientstd128bp_clusters'+str(nc)+'_zoomcut'+str(mcut)+'.jpg') 115 | for r in [10000, 5000, 2000, 1000]: 116 | for i, ax in enumerate(axs): 117 | ax.set_xlim([centergene-r, centergene+r]) 118 | if i == int(np.sum(unclustn>minsize))-1: 119 | ax.set_xticks([min(centergene-r+500, centergene-r+1000), centergene, max(centergene +r-1000,centergene +r-500)]) 120 | 121 | ax.set_xticklabels([str(int(min(centergene-r+500, 
# Concatenate per-gene eQTL "type" files into one table on stdout.
# Usage: python3 combine_eqtl_types.py <file_suffix>
# Each output row: <gene_id> <col0> <col1> <col2> from the per-gene file.

import numpy as np
import sys, os
import glob


files = np.sort(glob.glob('ENSG*' + sys.argv[1]))

for f, fil in enumerate(files):
    # fix: the original used fil.strip(sys.argv[1]), but str.strip removes a
    # *character set* from both ends, not a suffix, which can corrupt gene ids.
    # Drop the literal suffix instead.
    if fil.endswith(sys.argv[1]):
        gene = fil[:-len(sys.argv[1])]
    else:
        gene = fil
    dobj = open(fil, 'r').readlines()
    if len(dobj) > 0:
        for line in dobj:
            line = line.strip().split()
            print(gene, line[0], line[1], line[2])
# From observed gene expression across individuals and their genotypes, compute
# per-SNV eQTL effect sizes (linear-regression slopes) and Pearson correlations
# with observed expression. Writes <gene>_eqtl.txt and <gene>_corr.txt per gene.

import numpy as np
import sys, os
from scipy.stats import pearsonr
from sklearn.linear_model import LinearRegression

# Set of genes for which eqtls will be computed
tgenes = np.genfromtxt(sys.argv[1], dtype = str)

# load observed gene expression: first row = gene names, subsequent rows =
# one individual per row (individual id in column 0, expression values after)
obsfile = open('../Observed_gene_expressionENSG.tsv','r').readlines()
inds = []
obs_exp = []
for l, line in enumerate(obsfile):
    line = line.strip().split()
    if l == 0:
        genes = line
    else:
        inds.append(line[0])
        obs_exp.append(line[1:])

genes = np.array(genes)
inds = np.array(inds)
obs_exp = np.array(obs_exp, dtype = float)

for i, gene in enumerate(tgenes):
    print(gene)
    g = list(genes).index(gene)
    obs_e = obs_exp[:,g]

    # Load SNV file: 'snps' is an (n_snvs, n_individuals) genotype matrix,
    # 'rows' holds SNV identifiers, 'columns' holds individual ids
    snpfile = np.load('../variant_info_100k/'+gene+'snp_info.npz' )
    snp_info = snpfile['snps'].astype(float)
    snp_name = snpfile['rows'].astype(str)
    ind_names = snpfile['columns'].astype(str)

    # restrict expression values to individuals present in the genotype file
    sortobs = np.argsort(inds)[np.isin(np.sort(inds), ind_names)]
    obs_e = obs_e[sortobs]

    # restrict genotypes to individuals present in the expression file,
    # in the same order as obs_e
    sortvar = np.argsort(ind_names)[np.isin(np.sort(ind_names), inds)]
    snp_info = snp_info[:, sortvar]
    ind_names = ind_names[sortvar]
    # check sorting of individuals (should print True)
    print(np.array_equal(inds[sortobs], ind_names))

    obj = open(gene+'_eqtl.txt', 'w')
    obj2 = open(gene+'_corr.txt', 'w')

    lr = LinearRegression(fit_intercept=True)
    corr = []
    eqtl = []
    for s, sn in enumerate(snp_name):
        # Pearson correlation and regression slope of expression on genotype
        pear = pearsonr(snp_info[s], obs_e)[0]
        lr = lr.fit(snp_info[s].reshape(-1,1),obs_e)
        coef_ = lr.coef_
        obj2.write(sn+' '+str(round(pear,4))+'\n')
        obj.write(sn+' '+str(round(coef_[0],6))+'\n')
        corr.append(pear)
        eqtl.append(coef_[0])
    # report the strongest SNV by |correlation| and by |effect size|
    argmax = np.argmax(np.absolute(np.nan_to_num(corr)))
    print(corr[argmax], eqtl[argmax])
    argmax = np.argmax(np.absolute(np.nan_to_num(eqtl)))
    print(corr[argmax], eqtl[argmax])
    obj.close()
    obj2.close()
# Uses mean and std of absolute correlation from Null and computes p-value and corrected p-value for each gene
# Usage
# python3 compute_tstatpvalue.py ../Prediction_correlationsCageAdultBrain.txt GeneSpecific_CorrelationtoObsRandomNull.txt

import numpy as np
import scipy.stats as stats
import sys, os
from statsmodels.stats.multitest import multipletests

# Read file with correlation between observed and predicted expression for all individuals
corr = np.genfromtxt(sys.argv[1], dtype = str)
# Read file with mean and std of random null
null = np.genfromtxt(sys.argv[2], dtype = str)

# sort corr file so that genes match null file
cnames, csort = np.unique(corr[:,0], return_index = True)
csort = csort[np.isin(cnames, null[:,0])]
corr = corr[csort]
# sort null file so that genes match corr file
nsort = np.argsort(null[:,0])[np.isin(np.sort(null[:,0]), corr[:,0])]
null = null[nsort]

# check sorting process (should print True)
print(np.array_equal(corr[:,0], null[:,0]))

# select columns that contain the data:
# sm: predicted correlation (absolute value)
# m: mean of Null
# sv: std of Null
# n: number of random sets used to compute mean and std
sm, m, sv, n = np.absolute(corr[:,1].astype(float)), null[:,1].astype(float), null[:,2].astype(float), null[:,3].astype(float)
sm, m, sv = np.nan_to_num(sm), np.nan_to_num(m), np.nan_to_num(sv)

# NOTE(review): sv is documented above as the *std* of the null, but the
# denominator below is sqrt(sv/n), which is the standard error only if sv
# were a variance. If sv really holds a std, this should be sv/np.sqrt(n) --
# confirm against how GeneSpecific_CorrelationtoObsRandomNull.txt is produced.
tt = (sm-m)/np.sqrt(sv/n) # t-statistic for mean
pval = stats.t.sf(np.abs(tt), n-1)
if '--bothsided' in sys.argv:
    pval = pval *2
else:
    # one-sided test: genes below the null get p-value 1
    pval[tt < 0] = 1

# Correct for multiple testing with benjamini hochberg
issig, corr_pvals, calphasid, calphabonf = multipletests(np.nan_to_num(pval,nan=1), alpha=0.05, method='fdr_bh')
pval = -np.log10(pval)
corr_pvals = -np.log10(corr_pvals)

if '--bothsided' in sys.argv:
    # keep the direction of the effect in the signed log p-values
    pval = np.sign(tt)*pval
    corr_pvals = np.sign(tt)*corr_pvals

# Output columns: Gene, t-statistic, -log10 p-value, -log10 BH-corrected p-value
np.savetxt(os.path.splitext(sys.argv[1])[0]+os.path.splitext(os.path.split(sys.argv[2])[1])[0]+'_tstats.txt', np.concatenate([corr[:,[0]], np.around(tt,3).reshape(-1,1), np.around(pval,3).reshape(-1,1), np.around(corr_pvals,3).reshape(-1,1)], axis = 1).astype(str), fmt = '%s', header = 'Gene Tstat log10pvalue BH_corrected')
mean 34 | pval = stats.t.sf(np.abs(tt), n-1) 35 | if '--bothsided' in sys.argv: 36 | pval = pval *2 37 | else: 38 | pval[tt < 0] = 1 39 | 40 | # Correct for multiple testing with benjamini hochberg 41 | issig, corr_pvals, calphasid, calphabonf = multipletests(np.nan_to_num(pval,nan=1), alpha=0.05, method='fdr_bh') 42 | pval = -np.log10(pval) 43 | corr_pvals = -np.log10(corr_pvals) 44 | 45 | if '--bothsided' in sys.argv: 46 | pval = np.sign(tt)*pval 47 | corr_pvals = np.sign(tt)*corr_pvals 48 | 49 | 50 | np.savetxt(os.path.splitext(sys.argv[1])[0]+os.path.splitext(os.path.split(sys.argv[2])[1])[0]+'_tstats.txt', np.concatenate([corr[:,[0]], np.around(tt,3).reshape(-1,1), np.around(pval,3).reshape(-1,1), np.around(corr_pvals,3).reshape(-1,1)], axis = 1).astype(str), fmt = '%s', header = 'Gene Tstat log10pvalue BH_corrected') 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /enformer_analysis/count_drivers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys, os 3 | import glob 4 | 5 | ending = sys.argv[1] 6 | files = np.sort(glob.glob('ENSG*'+ending)) 7 | 8 | for f, fi in enumerate(files): 9 | flen = open(fi,'r').readlines() 10 | print(fi.strip(ending), len(flen)) 11 | 12 | -------------------------------------------------------------------------------- /enformer_analysis/driver_distance_to_tss.py: -------------------------------------------------------------------------------- 1 | # Generate file that computes the distance of the drivers to the TSS 2 | # python3 ../driver_distance_to_tss.py _ism_attributions_driversfw.txt ../geneTSS.txt > DistancetoTSS_ism_attributions_driversfw.txt 3 | 4 | import numpy as np 5 | import sys, os 6 | import glob 7 | 8 | dfiles = np.sort(glob.glob('ENSG*'+sys.argv[1])) 9 | tssfile = np.genfromtxt(sys.argv[2], dtype = str) 10 | 11 | for d, df in enumerate(dfiles): 12 | gene = df.strip(sys.argv[1]) 13 | dobj = 
open(df, 'r').readlines() 14 | if len(dobj) > 0: 15 | dfile = [line.strip().split() for line in dobj] 16 | dfile = np.array(dfile,dtype =float) 17 | tss = int(tssfile[list(tssfile[:,0]).index(gene),1]) 18 | for e, var in enumerate(dfile): 19 | print(gene, int(var[0]), int(var[0]-tss), var[-3]) 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /enformer_analysis/eqtl_attribution_plot.py: -------------------------------------------------------------------------------- 1 | # plot attributions versus eqtl values 2 | 3 | import numpy as np 4 | import sys, os 5 | import matplotlib.pyplot as plt 6 | 7 | # Read attribution values for all SNVs 8 | attributions = np.genfromtxt(sys.argv[1]) 9 | # Read eqtls 10 | eqtl = np.nan_to_num(np.genfromtxt(sys.argv[2])) 11 | # sort if necessary 12 | sort = np.argsort(attributions[:,0])[np.isin(np.sort(attributions[:,0]), eqtl[:,0])] 13 | attributions = attributions[sort] 14 | sort = np.argsort(eqtl[:,0])[np.isin(np.sort(eqtl[:,0]), attributions[:,0])] 15 | eqtl = eqtl[sort] 16 | #check sorting 17 | if not np.array_equal(attributions[:,0], eqtl[:,0]): 18 | print('eqtl and attributions dont match') 19 | sys.exit() 20 | 21 | xlabel = sys.argv[3] 22 | ylabel = sys.argv[4] 23 | 24 | if '--colors' in sys.argv: 25 | colors = np.genfromtxt(sys.argv[sys.argv.index('--colors')+1]) # snp_info file 26 | sort = np.argsort(colors[:,0])[np.isin(np.sort(colors[:,0]), attributions[:,0])] 27 | colors = colors[sort] 28 | if not np.array_equal(attributions[:,0], colors[:,0]): 29 | print('colors and attributions dont match') 30 | sys.exit() 31 | colors = colors[:,1] 32 | # select SNVs based on color assignment 33 | if '--minfreq' in sys.argv: 34 | mask = colors > float(sys.argv[sys.argv.index('--minfreq')+1]) 35 | attributions = attributions[mask] 36 | eqtl = eqtl[mask] 37 | colors = colors[mask] 38 | 39 | stdatt = np.std(attributions[:,1]) 40 | stdeqtl = np.std(eqtl[:,1]) 41 | 42 | maxatt = 
# Normalize attributions to the maximum absolute attribution if requested
if '--norm_attribution' in sys.argv:
    attributions[:,1] /= maxatt
    stdatt /= maxatt
    maxatt = 1
print(stdatt, stdeqtl)

fig = plt.figure(figsize = (4.,3.5))
ax = fig.add_subplot(111)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# cross hairs through the origin
ax.plot([0,0],[-1,1], c = 'silver', lw = 0.7)
ax.plot([-maxatt,maxatt],[0,0], c = 'silver', lw = 0.7)
# NOTE(review): `colors` is only defined when '--colors' is given; without
# that flag the two lines below raise a NameError -- confirm whether
# '--colors' is effectively mandatory for this script.
sort = np.argsort(colors)
cab = ax.scatter(attributions[sort,1], eqtl[sort,1], cmap = 'Blues', vmin = 0, vmax =1, c = colors[sort], edgecolor = 'grey')
fig.colorbar(cab, pad = 0., fraction = 0.09, shrink = 0.15, aspect = 2, anchor = (0.,0.99))

# Mark driver SNVs in plot with enlarged dots and red edgecolors
if '--drivers' in sys.argv:
    dobj = open(sys.argv[sys.argv.index('--drivers')+1], 'r').readlines()
    if len(dobj) > 0:
        dfile = [line.strip().split() for line in dobj]
        dfile = np.array(dfile,dtype =float)
        # keep only drivers whose position appears in the attribution file
        dfile = dfile[np.isin(dfile[:,0], attributions[:,0])]
        dloc = np.where(np.isin(attributions[:,0], dfile[:,0]))
        dloc = dloc[0]
        size0 = plt.rcParams['lines.markersize'] ** 2
        # dot size grows with the driver score in column -3 of the driver file
        ax.scatter(attributions[dloc,1], eqtl[dloc,1], s = size0 * (1+2*dfile[:,-3]), linewidths = 1.2, cmap = 'Blues', vmin = 0, vmax =1, c = colors[dloc], edgecolor = 'red')
        dtypefile = open(os.path.splitext(sys.argv[sys.argv.index('--drivers')+1])[0]+'_types.txt', 'w')
        for t in dloc:
            print(int(attributions[t,0]), attributions[t,1])
            if '--name_driver' in sys.argv:
                # NOTE(review): int(attributions[t,1]) truncates the float
                # attribution used as the text x-position; this looks like it
                # was meant to be attributions[t,1] -- confirm.
                ax.text(int(attributions[t,1]), eqtl[t,1], str(int(attributions[t,0])), ha = 'left')
            # Determine SNP type of drivers by looking at position in eqtl and attribution plot
            # For each driver determine where it's located based on the std of eqtls and attributions
            dtypefile.write(str(int(attributions[t,0]))+' '+str(round(attributions[t,1]/stdatt,2))+' '+str(round(eqtl[t,1]/stdeqtl,2))+'\n')

ax.set_xticks([-np.around(maxatt,3), 0, np.around(maxatt,3)])

ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
if '--show' in sys.argv:
    plt.show()
else:
    # default output: 250 dpi jpg next to the attribution file
    dpi, fmt = 250, '.jpg'
    if '--dpi' in sys.argv:
        dpi = int(sys.argv[sys.argv.index('--dpi')+1])
    if '--fmt' in sys.argv:
        fmt = sys.argv[sys.argv.index('--fmt')+1]
    fig.savefig(os.path.splitext(sys.argv[1])[0]+'_vs_'+sys.argv[5]+fmt, transparent = True, dpi = dpi, bbox_inches = 'tight')
#np.std(eqtl[:,1]) 11 | maxeqtl = np.amax(np.abs(eqtl[:,1])) 12 | # sort attribution and eqtl 13 | attributions = attributions[np.argsort(attributions[:,0])[np.isin(np.sort(attributions[:,0]), eqtl[:,0])]] 14 | eqtl = eqtl[np.argsort(eqtl[:,0])[np.isin(np.sort(eqtl[:,0]), attributions[:,0])]] 15 | if not np.array_equal(attributions[:,0], eqtl[:,0]): 16 | print('Not sorted correctly, check files') 17 | sys.exit() 18 | 19 | 20 | dobj = open(sys.argv[3], 'r').readlines() 21 | if len(dobj) > 0: 22 | dfile = [line.strip().split() for line in dobj] 23 | dfile = np.array(dfile,dtype =float) 24 | dloc = np.where(np.isin(attributions[:,0], dfile[:,0])) 25 | dloc = dloc[0] 26 | dtypefile = open(os.path.splitext(sys.argv[3])[0]+'_types.txt', 'w') 27 | for t in dloc: 28 | dtypefile.write(str(int(attributions[t,0]))+' '+str(round(attributions[t,1]/stdatt,3))+' '+str(round(eqtl[t,1]/stdeqtl,6))+'\n') 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /enformer_analysis/extract_ism_stats_around_drivers.py: -------------------------------------------------------------------------------- 1 | # Takes ISM attribution windows around the drivers and compares to "global" attributions stats from the entire sequence 2 | # Determines standardized effect of base change 3 | # Determine size of susequent motifs at drivers with certain impact 4 | # Determine distance to base with certain impact 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import sys, os 9 | 10 | # Read drivers 11 | drivers = np.genfromtxt(sys.argv[1], dtype = str) # file with all drivers 12 | 13 | # Select subset of genes 14 | if '--geneset' in sys.argv: 15 | geneset = np.genfromtxt(sys.argv[sys.argv.index('--geneset')+1], dtype = str) 16 | drivers = drivers[np.isin(drivers[:,0], geneset)] 17 | 18 | genes = drivers[:,0] 19 | gloc = drivers[:,1].astype(int) 20 | 21 | # Read in global stats for isms 22 | #globalstds = 
np.genfromtxt('baseline_ism/Baseline_stats.txt', dtype = str) 23 | globalstds = np.genfromtxt('../TSS_ISM/tss_attribution_stats.txt', dtype = str) 24 | 25 | nts = list('ACGT') 26 | ntsar = np.array(nts) 27 | 28 | ar = [] 29 | for g, gene in enumerate(genes): 30 | # load the attributions around the driver snps at the reference sequence 31 | std = float(globalstds[list(globalstds[:,0]).index(gene),2]) 32 | attatreffile = np.load('from_ref/'+gene+'_'+str(gloc[g])+'_attribution.npz') 33 | attatref, seqatref = attatreffile['attribution'], str(attatreffile['seq']) 34 | attatref/=std 35 | attatref = -np.sum(attatref, axis = 1)/3 36 | # load the attributions around the driver snps at the variant sequence 37 | attatvarfile = np.load('from_main_var/'+gene+'_'+str(gloc[g])+'_attribution.npz') 38 | attatvar, seqatvar = attatvarfile['attribution'], list(str(attatreffile['seq'])) 39 | attatvar/=std 40 | attatvar = -np.sum(attatvar, axis = 1)/3 41 | 42 | # check check direction to each other of zscore and motifs 43 | if '--checkmax' in sys.argv: 44 | thresh = np.array([0.05,0.1,0.2,0.5]) 45 | else: 46 | thresh = np.array([1.64,1.96,2.58,3.29]) 47 | print(thresh) 48 | loc = int(len(attatref)/2) 49 | latt = len(attatref) 50 | 51 | score = attatref[loc] 52 | size = (thresh<=abs(score)).astype(int) 53 | dist = (thresh>abs(score)).astype(int) 54 | checksizeneg = thresh <= abs(score) 55 | checksizepos = thresh <= abs(score) 56 | checkdist = thresh > abs(score) 57 | i = 1 58 | # check the size of subsequent bases with thresh or how far first base is from snp with thresh 59 | while True: 60 | #print(loc-i,loc+j, score, checkdist, checksizeneg, checksizepos, attribution[loc+i], attribution[loc-i], dist) 61 | if loc + i < latt: 62 | checksizepos = checksizepos * (thresh <= (np.sign(score) * attatref[loc+i])) 63 | size += checksizepos.astype(int) 64 | checkdist = checkdist * (thresh > abs(attatref[loc+i])) 65 | if loc - i >=0: 66 | checksizeneg = checksizeneg * (thresh <= (np.sign(score) * 
attatref[loc-i])) 67 | size += checksizeneg.astype(int) 68 | checkdist = checkdist * (thresh > abs(attatref[loc-i])) 69 | if loc - i <=0 or loc + i >= latt-1: 70 | break 71 | dist += checkdist.astype(int) 72 | if not checksizepos.any() and not checksizeneg.any() and not checkdist.any(): 73 | break 74 | i += 1 75 | 76 | distscore = np.argmax(np.absolute(np.array([attatref[loc+dist], attatref[loc-dist]])), axis = 0) 77 | distscore = np.array([attatref[loc+dist], attatref[loc-dist]])[distscore,np.arange(4,dtype = int)] 78 | add = np.concatenate([[gene, str(gloc[g]), str(score)], size.astype(str), dist.astype(str), distscore.astype(str)]) 79 | 80 | score = attatvar[loc] 81 | size = (thresh<=abs(score)).astype(int) 82 | dist = (thresh>abs(score)).astype(int) 83 | checksizeneg = thresh <= abs(score) 84 | checksizepos = thresh <= abs(score) 85 | checkdist = thresh > abs(score) 86 | i = 1 87 | while True: 88 | #print(loc-i,loc+j, score, checkdist, checksizeneg, checksizepos, attribution[loc+i], attribution[loc-i], dist) 89 | if loc + i < latt: 90 | checksizepos = checksizepos * (thresh <= (np.sign(score) * attatvar[loc+i])) 91 | size += checksizepos.astype(int) 92 | checkdist = checkdist * (thresh > abs(attatvar[loc+i])) 93 | if loc - i >=0: 94 | checksizeneg = checksizeneg * (thresh <= (np.sign(score) * attatvar[loc-i])) 95 | size += checksizeneg.astype(int) 96 | checkdist = checkdist * (thresh > abs(attatvar[loc-i])) 97 | if loc - i <=0 or loc + i >= latt-1: 98 | break 99 | dist += checkdist.astype(int) 100 | if not checksizepos.any() and not checksizeneg.any() and not checkdist.any(): 101 | break 102 | i += 1 103 | 104 | distscore = np.argmax(np.absolute(np.array([attatvar[loc+dist], attatvar[loc-dist]])), axis = 0) 105 | distscore = np.array([attatvar[loc+dist], attatvar[loc-dist]])[distscore,np.arange(4,dtype = int)] 106 | add = np.append(add,np.concatenate([[str(score)], size.astype(str), dist.astype(str), distscore.astype(str)])) 107 | 108 | ar.append(add) 109 | 
print(ar[-1]) 110 | 111 | ar = np.array(ar) 112 | print(os.path.splitext(os.path.split(sys.argv[1])[1])[0] + '_ism_significance_stats.txt') 113 | np.savetxt(os.path.splitext(os.path.split(sys.argv[1])[1])[0] + '_ism_significance_stats.txt', ar.astype(str), fmt='%s') 114 | 115 | 116 | -------------------------------------------------------------------------------- /enformer_analysis/from_main_var_drivers_ISM.py: -------------------------------------------------------------------------------- 1 | # get ISM within 20bp of each driver SNP 2 | # start with the most common variant inserted, then insert each possible variant 3 | # same as from_ref_drivers_ISM but with main variant inserted 4 | 5 | import time 6 | import pandas as pd 7 | import numpy as np 8 | import sparse 9 | import os 10 | import tensorflow as tf 11 | import tensorflow_hub as hub 12 | import argparse 13 | 14 | track_idx = 4980 # CAGE:brain, adult, 15 | center_bins = [447,448,449] # three center bins of enformer output 16 | 17 | save_path = '/data/aspiro17/enformer_res/' 18 | model_path = 'https://tfhub.dev/deepmind/enformer/1' 19 | 20 | padded_input_len = 393216 # full input to enformer 21 | input_len = 196608 22 | starting_seq_len = 200001 # starting length of reference seq 23 | pad_before = int((padded_input_len - input_len)/2) # "pad" sequence to padded_input_len 24 | 25 | mid_index = int((starting_seq_len-1)/2) 26 | start_index = int(mid_index - (input_len-1)/2) 27 | end_index = int(mid_index + (input_len-1)/2) + 1 28 | 29 | nuc_order =['A', 'C', 'G', 'T'] 30 | nuc_onehots = [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]] 31 | 32 | gene_win_info = pd.read_csv('/data/mostafavilab/bng/rosmapAD/data/expressionData/DLPFC/20220207-bulk-RNAseq/gene-ids-and-positions.tsv', sep = '\t', index_col = 1) # contains TSS position info 33 | 34 | driver_win = 20 35 | os.environ["CUDA_VISIBLE_DEVICES"]='0' 36 | 37 | # enformer helper functs 38 | class Enformer: 39 | 40 | def __init__(self, tfhub_url): 41 | self._model = 
def get_gene_drivers(gene_id, chrom):
    """Run a local ISM (+/- driver_win bp) around every driver SNP of gene_id,
    starting from the sequence with the most common alternate allele inserted
    at the driver position.

    Saves, per driver SNP:
      <gene>_<snp>_main_var_pred.npy -- prediction for the main-variant sequence
      <gene>_<snp>.npy               -- (2*driver_win+1, 4) ISM prediction matrix
    Relies on the module-level driver_dict built in __main__.
    """
    model = Enformer(model_path)

    snp_info = pd.read_csv('/data/aspiro17/enformer_res/variant_info_100k/' + gene_id + '.csv',header=None,encoding='latin1')
    current_gene = np.load(save_path+'ref_seqs/'+gene_id+'.npy')

    starting_seq_len = np.shape(current_gene)[1]
    current_tss = int(gene_win_info['tss_hg38'][gene_id])
    attrib_start_pos = current_tss - int(input_len/2) + 1  # genomic coordinate of the first input base

    # adjust sequence to be input to model
    current_gene = np.transpose(current_gene, (1,0))  # transpose to be seq_len x 4
    current_gene = current_gene[:, [0,3,2,1]]  # go from previous: seqP=='A',seqP=='T',seqP=='G',seqP=='C', to Enformer: 'ACGT'
    current_gene = np.reshape(current_gene, (1, starting_seq_len, 4))  # add a 1 dimen
    current_gene = np.pad(current_gene, pad_width=((0,0),(pad_before-start_index, pad_before-(starting_seq_len - end_index)), (0,0)))  # pad seq

    # get all drivers pos
    driver_pos = driver_dict[gene_id]

    for snp in driver_pos:

        snp = int(snp)
        # was hard-coded np.zeros([41,4]); use driver_win*2+1 for consistency
        # with from_ref_drivers_ISM.py
        driver_res = np.zeros([driver_win*2+1,4])
        ref = snp_info[snp_info[2] == snp][1].values[0]  # reference nucleotide at the SNP

        # get most common alt: count nucleotides across individuals (columns 3 on)
        sel_row = snp_info[snp_info[2] == snp].iloc[:,3:]
        sel_row = sel_row.astype(str).values.flatten().tolist()
        sel_row = ' '.join(sel_row)
        counts = [sel_row.count('A'), sel_row.count('C'), sel_row.count('G'), sel_row.count('T')]
        counts[nuc_order.index(ref)] = 0  # we want the most common allele that's not the ref
        most_common_alt_idx = np.argmax(counts)
        most_common_alt_onehot = nuc_onehots[most_common_alt_idx]

        snp_center_idx = (snp - attrib_start_pos) + pad_before  # position of the SNP in the padded array
        window_start_pos = snp_center_idx - driver_win

        current_gene_main_var_inserted = current_gene.copy()
        current_gene_main_var_inserted[:,snp_center_idx,:] = most_common_alt_onehot

        # prediction for the main-variant sequence itself
        main_var_out = model.predict_on_batch(current_gene_main_var_inserted)
        main_var_out = main_var_out['human'][0]
        # fixed: this line contained a stray token ("... track_idx]) bin") and
        # did not parse; sum over the three center bins as everywhere else
        main_var_val = np.sum(main_var_out[center_bins,track_idx])

        np.save(save_path + 'attrib_res/drivers_analysis/from_main_var/' +gene_id + '_' + str(snp) + '_main_var_pred', main_var_val)

        for i in range(driver_win*2+1):
            current_onehot = list(current_gene_main_var_inserted[:,window_start_pos+i,:][0])

            current_nuc_idx = nuc_onehots.index(current_onehot)
            driver_res[i,current_nuc_idx] = main_var_val  # this nuc represents no change from the main-variant sequence
            rel_nuc_idxs = list(range(4))
            rel_nuc_idxs.remove(current_nuc_idx)

            for idx in rel_nuc_idxs:  # the other 3 nucleotides
                inserted_seq = current_gene_main_var_inserted.copy()
                current_onehot = nuc_onehots[idx]
                inserted_seq[:,window_start_pos+i,:] = current_onehot

                var_out = model.predict_on_batch(inserted_seq)
                var_out = var_out['human'][0]
                var_val = np.sum(var_out[center_bins,track_idx])  # use 447, 448, and 449

                driver_res[i,idx] = var_val

        np.save(save_path + 'attrib_res/drivers_analysis/from_main_var/' +gene_id + '_' + str(snp), driver_res)
# get ISM within 20bp of each driver SNP
# start with the reference sequence, then insert each possible variant
# same as from_main_var_drivers_ISM but without main variant inserted

import time
import pandas as pd
import numpy as np
import sparse
import os
import tensorflow as tf
import tensorflow_hub as hub
import argparse

track_idx = 4980 # CAGE:brain, adult,
center_bins = [447,448,449] # three center bins of enformer output

save_path = '/data/aspiro17/enformer_res/'
model_path = 'https://tfhub.dev/deepmind/enformer/1'

padded_input_len = 393216 # full input to enformer
input_len = 196608
starting_seq_len = 200001 # starting length of reference seq
pad_before = int((padded_input_len - input_len)/2) # "pad" sequence to padded_input_len

mid_index = int((starting_seq_len-1)/2)
start_index = int(mid_index - (input_len-1)/2)
end_index = int(mid_index + (input_len-1)/2) + 1

nuc_order =['A', 'C', 'G', 'T']
nuc_onehots = [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]

gene_win_info = pd.read_csv('/data/mostafavilab/bng/rosmapAD/data/expressionData/DLPFC/20220207-bulk-RNAseq/gene-ids-and-positions.tsv', sep = '\t', index_col = 1) # contains TSS position info

driver_win = 20 # ISM window is +/- 20bp around each driver SNP
os.environ["CUDA_VISIBLE_DEVICES"]='0'

# enformer helper functs
class Enformer:
    """Thin wrapper around the TF-Hub Enformer model."""

    def __init__(self, tfhub_url):
        self._model = hub.load(tfhub_url).model

    def predict_on_batch(self, inputs):
        # run the model and convert every output head to numpy
        predictions = self._model.predict_on_batch(inputs)
        return {k: v.numpy() for k, v in predictions.items()}

def get_gene_drivers(gene_id, chrom):
    """ISM of every position within driver_win bp of each driver SNP of
    gene_id, starting from the reference sequence.

    Saves <gene>_ref_pred.npy (reference prediction) and, per driver SNP,
    <gene>_<snp>.npy with the (2*driver_win+1, 4) prediction matrix.
    Relies on the module-level driver_dict built in __main__.
    """
    model = Enformer(model_path)

    snp_info = pd.read_csv('/data/aspiro17/enformer_res/variant_info_100k/' + gene_id + '.csv',header=None,encoding='latin1')
    current_gene = np.load(save_path+'ref_seqs/'+gene_id+'.npy')

    starting_seq_len = np.shape(current_gene)[1]
    current_tss = int(gene_win_info['tss_hg38'][gene_id])
    attrib_start_pos = current_tss - int(input_len/2) + 1 # genomic coordinate of the first input base

    # adjust sequence to be input to model
    current_gene = np.transpose(current_gene, (1,0)) # transpose to be seq_len x 4
    current_gene = current_gene[:, [0,3,2,1]] # go from previous: seqP=='A',seqP=='T',seqP=='G',seqP=='C', to Enformer: 'ACGT'
    current_gene = np.reshape(current_gene, (1, starting_seq_len, 4)) # add a 1 dimen
    current_gene = np.pad(current_gene, pad_width=((0,0),(pad_before-start_index, pad_before-(starting_seq_len - end_index)), (0,0))) # pad seq

    # from ref
    ref_out = model.predict_on_batch(current_gene)
    ref_out = ref_out['human'][0]
    ref_val = np.sum(ref_out[center_bins,track_idx]) # sum over the three center bins

    np.save(save_path + 'attrib_res/drivers_analysis/from_ref/' +gene_id + '_ref_pred', ref_val)

    # get all drivers pos
    driver_pos = driver_dict[gene_id]

    for snp in driver_pos:

        snp = int(snp)
        driver_res = np.zeros([driver_win*2+1,4])
        ref = snp_info[snp_info[2] == snp][1].values[0] # get the single nuc

        snp_center_idx = (snp - attrib_start_pos) + pad_before # where in this array
        window_start_pos = snp_center_idx - driver_win

        for i in range(driver_win*2+1):
            current_onehot = list(current_gene[:,window_start_pos+i,:][0])

            current_nuc_idx = nuc_onehots.index(current_onehot)
            driver_res[i,current_nuc_idx] = ref_val # this nuc represents no change from the ref
            rel_nuc_idxs = list(range(4))
            rel_nuc_idxs.remove(current_nuc_idx)

            for idx in rel_nuc_idxs: # the other 3
                inserted_seq = current_gene.copy()
                current_onehot = nuc_onehots[idx]
                inserted_seq[:,window_start_pos+i,:] = current_onehot # get current seq

                var_out = model.predict_on_batch(inserted_seq)
                var_out = var_out['human'][0]
                var_val = np.sum(var_out[center_bins,track_idx]) # use 447,448,and 449

                driver_res[i,idx] = var_val

        np.save(save_path + 'attrib_res/drivers_analysis/from_ref/' +gene_id + '_' + str(snp), driver_res)

if __name__ == '__main__':

    drivers = np.load('genes_to_run/drivers_to_run.npy')

    # map gene id -> list of driver SNP positions
    driver_dict = {}
    for item in drivers:
        if item[0] not in driver_dict.keys():
            driver_dict[item[0]] = [item[1]]
        else:
            driver_dict[item[0]].append(item[1])

    for gene in driver_dict.keys():
        print(gene)
        # NOTE(review): all_gene_info is never defined in this script, so this
        # line raises a NameError when run -- confirm the intended source
        # (gene_win_info?) of the chromosome column.
        chrom = all_gene_info[all_gene_info['gene_id'] == gene]['chr_hg38']
        get_gene_drivers(gene, chrom)
def correlation(y1, y2, axis = 1, ctype = 'pearson', distance = False):
    """Correlation between y1 and y2 along `axis`, rounded to 4 decimals.

    ctype: 'pearson' (default), 'spearman' (rank-transform first), or
    'cosine' (skip mean centering). Pairs with a near-zero norm are assigned
    a correlation of -1. With distance=True, 1 - correlation is returned.
    """
    # Spearman correlation = Pearson correlation on ranks
    if ctype == 'spearman':
        y1 = np.argsort(np.argsort(y1, axis = axis), axis = axis)
        y2 = np.argsort(np.argsort(y2, axis = axis), axis = axis)
    # cosine similarity skips the centering step
    if ctype == 'cosine':
        c1, c2 = y1, y2
    else:
        c1 = y1 - np.expand_dims(np.mean(y1, axis = axis), axis = axis)
        c2 = y2 - np.expand_dims(np.mean(y2, axis = axis), axis = axis)
    norm1 = np.sqrt(np.sum(c1 * c1, axis = axis))
    norm2 = np.sqrt(np.sum(c2 * c2, axis = axis))
    denom = norm1 * norm2
    numer = np.sum(c1 * c2, axis = axis)
    if isinstance(numer, float):
        # scalar case: flag a near-constant input with correlation -1
        if denom / max(norm1, norm2) < 1e-8:
            denom, numer = 1., -1.
    else:
        # array case: same degenerate handling, applied element-wise
        degenerate = denom / np.amax(np.array([norm1, norm2]), axis = 0) < 1e-8
        numer[degenerate] = -1.
        denom[degenerate] = 1
    out = numer / denom
    if distance:
        out = 1. - out
    return np.around(out, 4)
# sort given inds to observed expression inds
isort = np.argsort(inds)[np.isin(np.sort(inds), obsind)]
inds = inds[isort]
# check if sorting was successful (should print True)
print(np.array_equal(obsind, inds), len(obsind), len(inds))

np.random.seed(1)  # fixed seed so the null is reproducible
nstat = 50  # number of random attribution sets per gene

# Second file that compares expression to predicted values can be generated
if '--savemeanexp' in sys.argv:
    enf = np.genfromtxt('Enformer_predictionsENSG.tsv', dtype = str, skip_header = 1)
    enfgenes = np.array(open('Enformer_predictionsENSG.tsv', 'r').readline().strip().split('\t'))
    enfind, enf = enf[:,0], enf[:,1:].astype(float)
    # align Enformer predictions to the gene and individual ordering used above
    esort = np.argsort(enfgenes)[np.isin(np.sort(enfgenes), genes)]
    enfgenes, enf = enfgenes[esort], enf[:,esort]
    esort = np.argsort(enfind)[np.isin(np.sort(enfind), inds)]
    enfind, enf = enfind[esort], enf[esort]
    print('enformer', np.array_equal(enfind, inds))


corrs = [] # per gene: mean, std, n of the absolute null correlations
exps = []
for g, gene in enumerate(genes):

    snpfile = np.load(gene+'snp_info.npz')
    indv = snpfile['columns'][isort]
    # only process genes whose individuals line up with the expression data
    if np.array_equal(indv, obsind):
        snp, snames = snpfile['snps'], snpfile['rows']
        snp = snp[:,isort]

        # random attribution per SNV and random set
        rat= np.random.normal(size = (len(snames), nstat))
        # Different option for Null with all SNV weights being positive
        if '--positive' in sys.argv:
            rat = np.absolute(rat)
        # predicted expression = sum over SNVs of genotype * random attribution
        ex = np.sum(snp.T[...,None]*rat[None,...], axis = 1)
        pears = correlation(ex, obs[:,np.where(obsgenes == gene)[0][[0]]], axis = 0)
        pears = np.absolute(np.nan_to_num(pears))
        corrs.append([gene, np.around(np.mean(pears),3), np.around(np.std(pears),5), nstat])
        print(corrs[-1])
        if '--savemeanexp' in sys.argv:
            # same null, but correlated against the Enformer predictions
            pears = correlation(ex, enf[:,np.where(enfgenes == gene)[0][[0]]], axis = 0)
            pears = np.absolute(np.nan_to_num(pears))
            exps.append([gene, np.around(np.mean(pears),3), np.around(np.std(pears),5), nstat])
            print(exps[-1])

corrs = np.array(corrs, dtype = str)
if '--savemeanexp' in sys.argv:
    exps = np.array(exps, dtype = str)

# output file names encode which null was generated
if '--positive' in sys.argv:
    np.savetxt('GeneSpecific_CorrelationtoObsRandomPosNull.txt', corrs, header = 'Gene MeanR StdR N', fmt = '%s')
    if '--savemeanexp' in sys.argv:
        np.savetxt('GeneSpecific_CorrelationtoObsRandomPosNulltoEnf.txt', exps, header = 'Gene MeanR StdR N', fmt = '%s')
else:
    np.savetxt('GeneSpecific_CorrelationtoObsRandomNull.txt', corrs, header = 'Gene MeanR StdR N', fmt = '%s')
    if '--savemeanexp' in sys.argv:
        np.savetxt('GeneSpecific_CorrelationtoObsRandomNulltoEnf.txt', exps, header = 'Gene MeanR StdR N', fmt = '%s')
end_index = int(mid_index + (input_len-1)/2) + 1 28 | 29 | nuc_order =['A', 'C', 'G', 'T'] 30 | nuc_onehots = [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]] 31 | 32 | gene_win_info = pd.read_csv('/data/mostafavilab/bng/rosmapAD/data/expressionData/DLPFC/20220207-bulk-RNAseq/gene-ids-and-positions.tsv', sep = '\t', index_col = 1) # contains TSS position info 33 | # enformer helper functs 34 | class Enformer: 35 | 36 | def __init__(self, tfhub_url): 37 | self._model = hub.load(tfhub_url).model 38 | 39 | def predict_on_batch(self, inputs): 40 | predictions = self._model.predict_on_batch(inputs) 41 | return {k: v.numpy() for k, v in predictions.items()} 42 | 43 | @tf.function 44 | def contribution_input_grad(self, input_sequence, 45 | target_mask, output_head='human'): 46 | input_sequence = input_sequence[tf.newaxis] 47 | 48 | target_mask_mass = tf.reduce_sum(target_mask) 49 | with tf.GradientTape() as tape: 50 | tape.watch(input_sequence) 51 | prediction = tf.reduce_sum( 52 | target_mask[tf.newaxis] * 53 | self._model.predict_on_batch(input_sequence)[output_head]) / target_mask_mass 54 | grad = tape.gradient(prediction, input_sequence) 55 | return grad 56 | 57 | 58 | def get_ref_attribs(gene_id, chrom): 59 | 60 | model = Enformer(model_path) 61 | 62 | snp_info = pd.read_csv('/data/aspiro17/enformer_res/variant_info_100k/' + gene_id + '.csv',header=None,encoding='latin1') 63 | 64 | most_common_alts = [] 65 | current_gene = np.load(save_path+'ref_seqs/'+gene_id+'.npy') 66 | 67 | starting_seq_len = np.shape(current_gene)[1] 68 | current_tss = int(gene_win_info['tss_hg38'][gene_id]) 69 | attrib_start_pos = current_tss - int(input_len/2) + 1 70 | 71 | snp_pos = np.load(save_path + 'snp_positions/' + gene_id + '_snp_positions.npy') 72 | 73 | # adjust sequence to be input to model 74 | current_gene = np.transpose(current_gene, (1,0)) # transpose to be seq_len x 4 75 | current_gene = current_gene[:, [0,3,2,1]] # go from previous: seqP=='A',seqP=='T',seqP=='G',seqP=='C', to 
Enformer: 'ACGT' 76 | current_gene = np.reshape(current_gene, (1, starting_seq_len, 4)) # add a 1 dimen 77 | current_gene = np.pad(current_gene, pad_width=((0,0),(pad_before-start_index, pad_before-(starting_seq_len - end_index)), (0,0))) # pad seq 78 | 79 | inserted_seq = current_gene.copy() 80 | 81 | for snp in snp_pos: 82 | 83 | ref = snp_info[snp_info[2] == snp][1].values[0] 84 | 85 | # get most common alt 86 | sel_row = snp_info[snp_info[2] == snp].iloc[:,3:] # 3 on 87 | sel_row = sel_row.astype(str).values.flatten().tolist() 88 | sel_row = ' '.join(sel_row) 89 | counts = [sel_row.count('A'), sel_row.count('C'), sel_row.count('G'), sel_row.count('T')] 90 | counts[nuc_order.index(ref)] = 0 # we want the most common allele that's not the ref 91 | alt_idx = np.argmax(counts) # now take the max 92 | allele_to_insert = nuc_onehots[alt_idx] 93 | 94 | most_common_alt = nuc_order[alt_idx] 95 | most_common_alts.append(most_common_alt) 96 | 97 | idx_to_insert = (snp - attrib_start_pos) + pad_before 98 | inserted_seq[:,idx_to_insert,:] = allele_to_insert 99 | 100 | np.save(save_path + 'most_common_alt_alleles/' +gene_id + '_alt_alleles', most_common_alts) 101 | 102 | # target mask to get relevant gradient 103 | target_mask = np.zeros([896,5313], dtype='float32') 104 | for idx in [447, 448, 449]: # central 3 bins 105 | target_mask[idx, track_idx] = 1 106 | 107 | # get ref grad 108 | ref_grad = model.contribution_input_grad(current_gene[0,:,:].astype(np.float32), target_mask).numpy() 109 | np.save(save_path + 'attrib_res/ref_attribs/' + gene_id + '_complete_grad_at_ref', ref_grad) 110 | 111 | # get (grad at ref) x (ref) 112 | grad_at_ref_times_ref = ref_grad * current_gene[0,:,:] 113 | grad_at_ref_times_ref = tf.squeeze(grad_at_ref_times_ref, axis=0) 114 | grad_at_ref_times_ref = tf.reduce_sum(grad_at_ref_times_ref, axis=-1).numpy() 115 | 116 | # get (grad at ref) x (var) 117 | grad_at_ref_times_var = ref_grad * inserted_seq 118 | grad_at_ref_times_var = 
tf.squeeze(grad_at_ref_times_var, axis=0) 119 | grad_at_ref_times_var = tf.reduce_sum(grad_at_ref_times_var, axis=-1).numpy() 120 | 121 | grad_times_ref_vals = [] 122 | grad_times_var_vals = [] 123 | for snp in snp_pos: 124 | adjusted_pos = (snp - attrib_start_pos) + pad_before # where in the array 125 | grad_times_ref_vals.append(grad_at_ref_times_ref[adjusted_pos]) 126 | grad_times_var_vals.append(grad_at_ref_times_var[adjusted_pos]) 127 | 128 | np.save(save_path + 'attrib_res/ref_attribs/' + gene_id + '_grad_at_ref_times_ref', grad_times_ref_vals) 129 | np.save(save_path + 'attrib_res/ref_attribs/' + gene_id + '_grad_at_ref_times_var', grad_times_var_vals) 130 | 131 | 132 | if __name__ == '__main__': 133 | 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument('which_gpu', action="store", default='0') 136 | parser.add_argument('gene_file', action="store", default='test') 137 | 138 | args = parser.parse_args() 139 | 140 | os.environ["CUDA_VISIBLE_DEVICES"]=args.which_gpu 141 | genes = np.load(args.gene_file) 142 | 143 | for gene in genes: 144 | print(gene) 145 | chrom = all_gene_info[all_gene_info['gene_id'] == gene]['chr_hg38'] 146 | get_ref_attribs(gene, chrom) 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /enformer_analysis/per_SNP_ISM.py: -------------------------------------------------------------------------------- 1 | # Get ISM results for Enformer on a given gene set 2 | # For this analysis, get Enformer output for reference and Enformer output for inserting the most common variant for each SNP position in a gene 3 | 4 | import time 5 | import pandas as pd 6 | import numpy as np 7 | import sparse 8 | import os 9 | import tensorflow as tf 10 | import tensorflow_hub as hub 11 | import argparse 12 | 13 | track_idx = 4980 # CAGE:brain, adult, 14 | center_bins = [447,448,449] # three center bins of enformer output 15 | 16 | save_path = '/data/aspiro17/enformer_res/' 17 | model_path = 
'https://tfhub.dev/deepmind/enformer/1' 18 | 19 | padded_input_len = 393216 # full input to enformer 20 | input_len = 196608 21 | starting_seq_len = 200001 # starting length of reference seq 22 | pad_before = int((padded_input_len - input_len)/2) # "pad" sequence to padded_input_len 23 | 24 | mid_index = int((starting_seq_len-1)/2) 25 | start_index = int(mid_index - (input_len-1)/2) 26 | end_index = int(mid_index + (input_len-1)/2) + 1 27 | 28 | nuc_order =['A', 'C', 'G', 'T'] 29 | nuc_onehots = [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]] 30 | 31 | gene_win_info = pd.read_csv('/data/mostafavilab/bng/rosmapAD/data/expressionData/DLPFC/20220207-bulk-RNAseq/gene-ids-and-positions.tsv', sep = '\t', index_col = 1) # contains TSS position info 32 | 33 | 34 | # enformer helper functs 35 | class Enformer: 36 | 37 | def __init__(self, tfhub_url): 38 | self._model = hub.load(tfhub_url).model 39 | 40 | def predict_on_batch(self, inputs): 41 | predictions = self._model.predict_on_batch(inputs) 42 | return {k: v.numpy() for k, v in predictions.items()} 43 | 44 | def run_ism(gene_id, chrom): 45 | # for a current gene_id (ENSG) and chromosome, save ISM results 46 | model = Enformer(model_path) 47 | 48 | snp_info = pd.read_csv('/data/aspiro17/enformer_res/variant_info_100k/' + gene_id + '.csv',header=None,encoding='latin1') # variant info for this gene 49 | current_gene = np.load(save_path+'ref_seqs/'+gene_id+'.npy') # reference sequence for this gene 50 | 51 | most_common_alts = [] # save most common alternate alleles for future analysis 52 | var_vals = [] # save output of enformer for inserting a given var (for each var position) 53 | 54 | starting_seq_len = np.shape(current_gene)[1] 55 | current_tss = int(gene_win_info['tss_hg38'][gene_id]) 56 | attrib_start_pos = current_tss - int(input_len/2) + 1 57 | 58 | # get snp positions for this gene (within the 100k window) 59 | all_snp_pos = list(snp_info[2]) 60 | snp_pos = [] 61 | for snp in all_snp_pos: 62 | adjusted_pos = snp - 
attrib_start_pos # make the attrib start pos like 0 63 | if adjusted_pos >= 0 and adjusted_pos <= input_len: # in range 64 | snp_pos.append(snp) 65 | np.save(save_path + 'snp_positions/' + gene_id + '_snp_positions', snp_pos) 66 | 67 | # adjust sequence to be input to model 68 | current_gene = np.transpose(current_gene, (1,0)) # transpose to be seq_len x 4 69 | current_gene = current_gene[:, [0,3,2,1]] # go from previous: seqP=='A',seqP=='T',seqP=='G',seqP=='C', to Enformer: 'ACGT' 70 | current_gene = np.reshape(current_gene, (1, starting_seq_len, 4)) # add a 1 dimen 71 | current_gene = np.pad(current_gene, pad_width=((0,0),(pad_before-start_index, pad_before-(starting_seq_len - end_index)), (0,0))) # pad seq 72 | 73 | ref_out = model.predict_on_batch(current_gene) 74 | ref_out = ref_out['human'][0] 75 | ref_val = np.sum(ref_out[center_bins,track_idx]) # sum over center 3 bins 76 | 77 | np.save(save_path + 'attrib_res/ism_res/' +gene_id + '_ref_pred', ref_val) 78 | 79 | for snp in snp_pos: 80 | print(snp) 81 | 82 | inserted_seq = current_gene.copy() 83 | ref = snp_info[snp_info[2] == snp][1].values[0] # single ref nuc 84 | 85 | # get most common alt 86 | sel_row = snp_info[snp_info[2] == snp].iloc[:,3:] # 3 on 87 | sel_row = sel_row.astype(str).values.flatten().tolist() 88 | sel_row = ' '.join(sel_row) 89 | counts = [sel_row.count('A'), sel_row.count('C'), sel_row.count('G'), sel_row.count('T')] 90 | counts[nuc_order.index(ref)] = 0 # we want the most common allele that's not the ref 91 | alt_idx = np.argmax(counts) # now take the max 92 | allele_to_insert = nuc_onehots[alt_idx] 93 | 94 | most_common_alt = nuc_order[alt_idx] 95 | most_common_alts.append(most_common_alt) 96 | 97 | idx_to_insert = (snp - attrib_start_pos) + pad_before # where to insert variant 98 | 99 | print(ref, current_gene[0,idx_to_insert,:]) # make sure the reference value is the same as what we find at this idx in the one-hot 100 | 101 | inserted_seq[:,idx_to_insert,:] = allele_to_insert # 
insert

        # variant prediction: same forward pass as the reference, summed over the
        # same three center bins of the CAGE adult-brain track
        var_out = model.predict_on_batch(inserted_seq)
        var_out = var_out['human'][0]
        var_val = np.sum(var_out[center_bins,track_idx])

        var_vals.append(var_val)

    # one prediction per SNP position, aligned with snp_pos / most_common_alts
    np.save(save_path + 'attrib_res/ism_res/' +gene_id + '_var_preds', var_vals)
    np.save(save_path + 'attrib_res/ism_res/' +gene_id + '_alt_alleles', most_common_alts)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('which_gpu', action="store", default='0')
    parser.add_argument('gene_file', action="store", default='test')
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"]=args.which_gpu
    genes = np.load(args.gene_file)

    for gene in genes:
        print(gene)
        # NOTE(review): gene_win_info was read with index_col = 1, so it is not
        # obvious that 'gene_id' is still a column here — confirm this lookup
        # works. `chrom` is also passed to run_ism but never used inside it.
        chrom = gene_win_info[gene_win_info['gene_id'] == gene]['chr_hg38']
        run_ism(gene, chrom)

# ------------------------------------------------------------------------------
# /enformer_analysis/plot_attribution_alongsequence.py
# ------------------------------------------------------------------------------
# Plot gradient plot for GSTM3 with drivers and other variants
#python3 ../plot_attribution_alongsequence.py ENSG00000134202 0 ../ism_res/ENSG00000134202_ism_attributions.txt ../variant_info_100k/ENSG00000134202_frequency.txt ../ism_res/ENSG00000134202_ism_attributions_driversfw.txt ../ism_res/ENSG00000134202_ism_attributions_driversfw_types.txt --savefig

import numpy as np
import matplotlib.pyplot as plt
import logomaker as lm
import sys, os
import pandas as pd

# Gene name
gene = sys.argv[1]
# Plot array between position [window:-window]
window = int(sys.argv[2])
# Attributions and position of SNVs
varatts = np.genfromtxt(sys.argv[3], dtype = float)
# Population frequency of SNVs
popfreq = np.genfromtxt(sys.argv[4], dtype = float)
# align driver attributions and population frequency of snp
varatts, popfreq = varatts[np.argsort(varatts[:,0])[np.isin(np.sort(varatts[:,0]),popfreq[:,0])]], popfreq[np.argsort(popfreq[:,0])[np.isin(np.sort(popfreq[:,0]),varatts[:,0])]] 20 | 21 | # normalize by maxmum attribution of any driver 22 | maxvaratt = np.amax(np.absolute(varatts[:,-1])) 23 | varatts[:,-1]/maxvaratt 24 | 25 | # List of drivers for the gene 26 | driverlist = open(sys.argv[5], 'r').readlines() 27 | driverlist = np.array([line.strip().split() for line in driverlist]) 28 | 29 | # Drivertype for the gene 30 | drivertype = open(sys.argv[6], 'r').readlines() # to select and assign type 31 | drivertype = np.array([line.strip().split() for line in drivertype]) 32 | # align driverlist with driver type 33 | driverlist = driverlist[np.argsort(driverlist[:,0])[np.isin(np.sort(driverlist[:,0]), drivertype[:,0])]] 34 | drivertype = drivertype[np.argsort(drivertype[:,0])[np.isin(np.sort(drivertype[:,0]), driverlist[:,0])]] 35 | 36 | loci = driverlist[:,0].astype(int) 37 | attlen = driverlist[:,1].astype(float) 38 | attlen = attlen/maxvaratt 39 | 40 | # remove drivers from varatts 41 | varatts = varatts[~np.isin(varatts[:,0],loci)] 42 | drfreq = popfreq[np.isin(popfreq[:,0],loci),1] 43 | popfreq = popfreq[~np.isin(popfreq[:,0],loci)] 44 | 45 | 46 | allctypes = drivertype[:,-2:].astype(float) 47 | bitypes = np.zeros((len(drivertype),2)) 48 | for c, ctypes in enumerate(allctypes): 49 | if (ctypes[0]<0 and ctypes[1]<0) or (ctypes[0]>0 and ctypes[1]>0): 50 | bitypes[c,0] = 1 51 | elif ctypes[0]>0 and ctypes[1]<0: 52 | bitypes[c,1] = 1 53 | elif ctypes[0]<0 and ctypes[1]>0: 54 | bitypes[c,1] = 1 55 | 56 | tsslist = np.genfromtxt('../geneTSS.txt', dtype = str) 57 | 58 | nts = list('ACGT') 59 | ntsar = np.array(nts) 60 | 61 | 62 | # Load gradient and one-hot encoded sequence for gene 63 | grad = np.load('../gradient_tensors/'+gene+'_complete_grad_at_ref.npy')[0] 64 | genehot = np.load('../ref_seqs/' +gene+'.npy').T 65 | genehot = genehot[:, [0,3,2,1]] 66 | 67 | lengene = 
np.shape(genehot)[0] 68 | 69 | # determine the center of the gradient and the sequence file 70 | centergrad = int(np.shape(grad)[0]/2) -1 71 | centergene = int(np.shape(genehot)[0]/2)-1 72 | # adjust the size of the gradient to the size of the sequence 73 | offset = centergrad - centergene 74 | grad = grad[offset:offset+lengene] 75 | 76 | # find tss for the gene 77 | tss = int(tsslist[list(tsslist[:,0]).index(gene),1]) 78 | # determine locations of drivers in the array 79 | varintens = centergene + loci - tss 80 | # determine locations of all other variants in the array 81 | locvars = centergene + varatts[:,0].astype(int) - tss 82 | 83 | grad /= np.amax(np.absolute(grad)) 84 | 85 | grad, genehot, varintens, centergene, locvars = grad[window:lengene-window], genehot[window:lengene-window], varintens - window, centergene - window, locvars - window 86 | # remove variants that are located outside the sequence, ie. smaller than 0 87 | varatts, popfreq, locvars = varatts[locvars>0], popfreq[locvars>0], locvars[locvars>0] 88 | 89 | # Compute mean attribution for changing ref base to any other base 90 | attribution = grad[genehot==1] - (np.sum(grad,axis = 1)-grad[genehot==1])/3 91 | 92 | 93 | # compute mean and std with sliding windows 94 | stepsize = 32 95 | wdsize = 128 96 | stdattribution = np.array([np.std(attribution[i:i+wdsize]) for i in range(0,len(attribution)-wdsize,stepsize)]) 97 | meanattribution = np.array([np.mean(attribution[i:i+wdsize]) for i in range(0,len(attribution)-wdsize,stepsize)]) 98 | stdx = np.arange(stepsize/2,len(attribution)-wdsize+stepsize/2,stepsize) 99 | 100 | fig = plt.figure(figsize = (15,1.5)) 101 | ax = fig.add_subplot(111) 102 | ax.spines['top'].set_visible(False) 103 | ax.spines['right'].set_visible(False) 104 | ax.fill_between(np.arange(len(attribution)), attribution, color = 'grey', label = 'Gradient attribution') 105 | ax.plot([centergene, centergene],[np.amin(attribution),np.amax(attribution)], color = 'goldenrod', ls = ':') 106 | 
if '--savefig' not in sys.argv:
    #ax.plot(np.arange(len(attribution)), attribution, color = 'grey', marker = '.')
    ax.scatter(np.arange(len(attribution)), attribution, color = 'grey', marker = '.')
    ax.plot([0,len(attribution)], [0.,0], color = 'k')
    ax.plot(stdx, stdattribution, color = 'navy', lw = 0.5)
    ax.plot(stdx, meanattribution, color = 'orange', lw = 0.5)
    ax.scatter(locvars, varatts[:,1], c = popfreq[:,1], vmin = 0, vmax = 1, cmap = 'Blues', edgecolor='silver')
else:
    ax.scatter(locvars, varatts[:,1], alpha = popfreq[:,1], vmin = 0, vmax = 1, color = 'b', edgecolor='none', marker = '.', label = 'SNPs')

# driver with the largest effect gets highlighted in magenta
maxdriver = np.argmax(driverlist[:,-3].astype(float))
for v, lo in enumerate(varintens):
    bounds = [0, attlen[v]]
    # BUG FIX: the original tested `types`, which is never defined in this
    # script (NameError); the per-driver one-hot type matrix built above is
    # `bitypes`. The final branch is now a plain `else` so `marker` is always
    # bound even if a row of bitypes is all-zero.
    if bitypes[v,0] == 1:
        marker = 's'
    elif bitypes[v,1] == 1:
        marker = '^'
    else:
        marker = 'v'
    colr = 'red'
    label = None
    if v == maxdriver:
        colr = 'magenta'
        label = 'Main driver'
    # ax.plot([lo, lo],bounds, color = colr)
    ax.scatter([lo],[bounds[1]], cmap = 'Blues', vmin = 0, vmax = 1, c = drfreq[v], edgecolor = colr, marker = marker, label = label)

ax.set_xticks([centergene-40000,centergene,centergene+40000])
ax.set_xticklabels([-40000,'TSS',40000], rotation = 0)
ax.grid(color = 'grey', axis = 'y')
ax.set_ylabel('Attribution')
ax.set_xlim([0,len(attribution)])
#ax.legend()

if '--savefig' in sys.argv:
    fig.savefig(os.path.splitext(os.path.split(sys.argv[1])[1])[0]+'_ingeneseq.jpg', dpi = 200, bbox_inches = 'tight')
    print(os.path.splitext(os.path.split(sys.argv[1])[1])[0]+'_ingeneseq.jpg')

    # zoomed view: +/-30kb around the TSS
    ax.set_xticks([centergene-20000,centergene,centergene+20000])
    ax.set_xticklabels([-20000,'TSS',20000],rotation = 0)
    ax.set_xlim([centergene-30000, centergene+30000])
fig.savefig(os.path.splitext(os.path.split(sys.argv[1])[1])[0]+'_ingeneseqcenter.jpg', dpi = 200, bbox_inches = 'tight') 148 | 149 | 150 | #ax.plot([0,len(attribution)], [0.,0], color = 'k') 151 | ax.set_xticks(np.append([centergene],varintens)) 152 | ax.set_xticklabels(np.append(['TSS'],loci.astype(str)), rotation = 90) 153 | ax.scatter(np.arange(len(attribution)), attribution, color = 'grey', marker = '.') 154 | #ax.plot([0,len(attribution)], [0.,0], color = 'k') 155 | #ax.plot(stdx, stdattribution, color = 'blue', lw = 1) 156 | #ax.plot(stdx, meanattribution, color = 'orange', lw = 1) 157 | ax.scatter(locvars, varatts[:,1], c = popfreq[:,1], vmin = 0, vmax = 1, cmap = 'Blues', edgecolor='k') 158 | for v, lo in enumerate(varintens): 159 | bounds = [0, attlen[v]] 160 | if types[v,0] == 1: 161 | marker = 's' 162 | elif types[v,1] == 1: 163 | marker = '^' 164 | elif types[v,2] == 1: 165 | marker = 'v' 166 | colr = 'red' 167 | if v == maxdriver: 168 | colr = 'magenta' 169 | #ax.plot([lo, lo],bounds, color = colr) 170 | ax.scatter([lo],[bounds[1]], cmap = 'Blues', vmin = 0, vmax = 1, c = drfreq[v], edgecolor = colr, marker = marker) 171 | 172 | ax.set_xlim([np.amin(varintens)-100, np.amax(varintens)+100]) 173 | fig.savefig(os.path.splitext(os.path.split(sys.argv[1])[1])[0]+'_ingeneseqclose.jpg', dpi = 200, bbox_inches = 'tight') 174 | print(os.path.splitext(os.path.split(sys.argv[1])[1])[0]+'_ingeneseqclose.jpg') 175 | 176 | ax.set_xticks([centergene-100,centergene,centergene+100]) 177 | ax.set_xticklabels([-100,'TSS',100],rotation = 0) 178 | ax.set_xlim([centergene-150, centergene+150]) 179 | fig.savefig(os.path.splitext(os.path.split(sys.argv[1])[1])[0]+'_ingeneseqtss.jpg', dpi = 200, bbox_inches = 'tight') 180 | print(os.path.splitext(os.path.split(sys.argv[1])[1])[0]+'_ingeneseqtss.jpg') 181 | else: 182 | plt.show() 183 | 184 | 185 | 186 | -------------------------------------------------------------------------------- /enformer_analysis/plot_common_motifs.py: 
-------------------------------------------------------------------------------- 1 | # Performs statistical test for sequence patterns to be enriched for unsupported or supported snp 2 | #python3 ../plot_common_motifs.py ALLgenes_ism_attributions_driversfw_types.txt ALLgenes_ism_attributions_driversfw_refseq_winsize13_clust_ms5-1_complete0.25 --savefig ALLgenes_ism_attributions_driversfw_refseq_winsize13_clust_ms5-1_complete0.2 3 | 4 | import numpy as np 5 | import sys, os 6 | import matplotlib.pyplot as plt 7 | import pandas as pd 8 | from statsmodels.stats.multitest import multipletests 9 | import logomaker as lm 10 | 11 | def read_meme(mfile): 12 | obj = open(mfile, 'r').readlines() 13 | pwms = [] 14 | pwmnames = [] 15 | pwm = [] 16 | t = -1000 17 | for l, line in enumerate(obj): 18 | line = line.strip().split() 19 | if len(line) > 0: 20 | if line[0] == 'ALPHABET=': 21 | nts = list(line[1]) 22 | if line[0] == 'MOTIF': 23 | pwmnames.append(line[1].split('_')[1]) 24 | t = 0 25 | t += 1 26 | if t > 2: 27 | pwm.append(line) 28 | else: 29 | if len(pwm) > 0: 30 | pwm = np.array(pwm, dtype = float) 31 | #pwm = np.log2(pwm/0.25) 32 | #pwm[pwm<0] = 0 33 | pwms.append(pd.DataFrame(data = np.array(pwm, dtype = float), columns = nts)) 34 | pwm = [] 35 | if len(pwm) > 0: 36 | pwm = np.array(pwm, dtype = float) 37 | #pwm = np.log2(pwm/0.25) 38 | #pwm[pwm<0] = 0 39 | pwms.append(pd.DataFrame(data = pwm, columns = nts)) 40 | return pwms, pwmnames 41 | 42 | # Read driver type 43 | drivertype = np.genfromtxt(sys.argv[1], dtype = str) # to select and assign type 44 | 45 | # select gene set 46 | if '--geneset' in sys.argv: 47 | geneset = np.genfromtxt(sys.argv[sys.argv.index('--geneset')+1], dtype = str) 48 | dmask = np.isin(drivertype[:,0], geneset) 49 | drivertype = drivertype[dmask] 50 | 51 | allctypes = drivertype[:,-2:].astype(float) 52 | drivers = np.array([d[0]+d[1] for d in drivertype]) 53 | mask = np.argsort(drivers) 54 | drivers, allctypes = drivers[mask], allctypes[mask] 
# One-hot driver classification: column 0 = attribution signs agree (supported),
# column 1 = signs disagree (unsupported); used below for the Fisher tests.
bitypes = np.zeros((len(drivers),2))
for c, ctypes in enumerate(allctypes):
    #print(ctypes)
    if (ctypes[0]<0 and ctypes[1]<0) or (ctypes[0]>0 and ctypes[1]>0):
        bitypes[c,0] = 1
    elif ctypes[0]>0 and ctypes[1]<0:
        bitypes[c,1] = 1
    elif ctypes[0]<0 and ctypes[1]>0:
        bitypes[c,1] = 1

# Start of file names that contain the information about the clustered sequences
clufile = sys.argv[2]

if '--savefig' in sys.argv:
    outname = sys.argv[sys.argv.index('--savefig')+1]

# Read pwms of clusters
pwms, pwmnames = read_meme(clufile+'_clusterpwms.txt')

# read cluster assignments
clusters = np.genfromtxt(clufile+'_clusters.txt', dtype = str)
clusternames, clusters = clusters[:,:2], clusters[:,2]
clusternames = np.array([d[0]+d[1] for d in clusternames])

# read location of snps in these pwms
snploc = np.genfromtxt(clufile+'_locclpwms.txt', dtype = str)
snpname, loc = snploc[:,:2], snploc[:,2].astype(int)
snpname = np.array([d[0]+d[1] for d in snpname])

if not np.array_equal(snpname, clusternames):
    print('CLUSTERFiles are different')
    sys.exit()
# align clusternames and snpnames to drivers
mask = np.argsort(snpname)[np.isin(np.sort(snpname), drivers)]
clusters, clusternames, loc, snpname = clusters[mask], clusternames[mask], loc[mask], snpname[mask]

#align drivers to clusternames
mask = np.argsort(drivers)[np.isin(np.sort(drivers), snpname)]
# BUG FIX: the original also assigned `types[mask]`, but `types` is never
# defined in this script (NameError). Only `drivers` and `bitypes` exist and
# are used downstream, so only those are subset here.
drivers, bitypes = drivers[mask], bitypes[mask]
print(np.array_equal(snpname, drivers), len(snpname), len(drivers))

cnames, cnum = np.unique(clusters, return_counts=True)
print('N clusters', len(cnames), 'in', len(clusters))
from scipy.stats import fisher_exact
pvalsneg = []
pvalspos = []
totalsize = []
# for each cluster, use fisher exact test to compute enrichment of supported and unsupported
snps 103 | for c, cna in enumerate(cnames): 104 | table = [[np.sum((clusters == cna) * (bitypes[:,1] == 1)), np.sum((clusters == cna) * (bitypes[:,1] == 0))],[np.sum((clusters != cna) * (bitypes[:,1] == 1)), np.sum((clusters != cna) * (bitypes[:,1] == 0))]] 105 | oddspos, p_valuepos = fisher_exact(table, 'less') # pvalue that positives are enriched 106 | oddsneg, p_valueneg = fisher_exact(table, 'greater') # pvalue that negatives are enriched 107 | pvalsneg.append(p_valueneg) 108 | pvalspos.append(p_valuepos) 109 | totalsize.append(np.sum(clusters == cna)) 110 | 111 | 112 | totalsize = np.array(totalsize) 113 | pvalsneg = np.array(pvalsneg) 114 | pvalspos = np.array(pvalspos) 115 | 116 | # adjsut p-value with benjamini hochberg 117 | ylabel = 'p_value' 118 | if '--adjust_pvalue' in sys.argv: 119 | is_, pvalsneg, a_, b_ = multipletests(pvalsneg, method='fdr_bh') 120 | is_, pvalspos, a_, b_ = multipletests(pvalspos, method='fdr_bh') 121 | ylabel = 'p_value_corrected_BH' 122 | pvalsneg = -np.log10(pvalsneg) 123 | pvalspos = -np.log10(pvalspos) 124 | 125 | sortneg = np.argsort(pvalsneg) 126 | sortpos = np.argsort(-pvalspos) 127 | 128 | fig = plt.figure(figsize = (6,4)) 129 | ax = fig.add_subplot(111) 130 | ax.spines['top'].set_visible(False) 131 | ax.spines['right'].set_visible(False) 132 | ax.scatter(np.arange(len(pvalsneg)), -pvalsneg[sortneg], s = totalsize[sortneg] * 5, label = 'unsupported') 133 | ax.scatter(np.arange(len(pvalspos)), pvalspos[sortpos], s = totalsize[sortpos] * 5, label = 'supported') 134 | ax.set_ylabel(ylabel) 135 | ax.set_xlabel('Motif cluster') 136 | ax.plot([0,len(pvalsneg)], [0,0],color = 'k') 137 | ax.plot([0,len(pvalsneg)], [-np.log10(0.01),-np.log10(0.01)],color = 'grey',ls = '--') 138 | ax.plot([0,len(pvalsneg)], [-np.log10(0.05),-np.log10(0.05)],color = 'grey',ls = '--') 139 | ax.plot([0,len(pvalsneg)], [-np.log10(0.05/len(pvalsneg)),-np.log10(0.05/len(pvalsneg))],color = 'red',ls = '--') 140 | ax.plot([0,len(pvalsneg)], 
[np.log10(0.05),np.log10(0.05)],color = 'grey',ls = '--') 141 | ax.plot([0,len(pvalsneg)], [np.log10(0.01),np.log10(0.01)],color = 'grey',ls = '--') 142 | ax.plot([0,len(pvalsneg)], [np.log10(0.05/len(pvalsneg)),np.log10(0.05/len(pvalsneg))],color = 'red',ls = '--') 143 | ax.set_yticks([-np.log10(0.05/len(pvalsneg)), -np.log10(0.01),-np.log10(0.05),0,np.log10(0.05),np.log10(0.01),np.log10(0.05/len(pvalsneg))]) 144 | ax.set_yticklabels(['FDR 0.05', '0.01', '0.05', '0', '0.05', '0.01', 'FDR 0.05']) 145 | ax.legend() 146 | if '--savefig' in sys.argv: 147 | fig.savefig(outname + '_clustenrich.jpg', dpi = 200, bbox_inches = 'tight') 148 | else: 149 | plt.show() 150 | 151 | fig2 = plt.figure(figsize = (4,4)) 152 | # plot distribution of clustersizes 153 | ax2 = fig2.add_subplot(111) 154 | ax2.spines['top'].set_visible(False) 155 | ax2.spines['right'].set_visible(False) 156 | bins = np.arange(np.amax(cnum)+2) 157 | ax2.hist(cnum, bins = bins) 158 | ax2.set_xlabel('Cluster size') 159 | if '--savefig' in sys.argv: 160 | fig2.savefig(outname + '_clustsizedist.jpg', dpi = 200, bbox_inches = 'tight') 161 | else: 162 | plt.show() 163 | 164 | 165 | # plot pwms of most significant clusters 166 | for s in sortneg[::-1][:20]: 167 | clu = cnames[s] 168 | clumask = clusters == clu 169 | pwm = pwms[pwmnames.index(clu)] 170 | ploc = loc[clumask] 171 | snpdir = bitypes[clumask,0] - bitypes[clumask, 1] 172 | dcounts = [] 173 | for sd in [-1,1]: 174 | dcounts.append(np.unique(ploc[snpdir == sd], return_counts = True)) 175 | figp = plt.figure(figsize = (len(pwm)*0.5,2)) 176 | axp = figp.add_subplot(111) 177 | axp.spines['top'].set_visible(False) 178 | axp.spines['right'].set_visible(False) 179 | lm.Logo(pwm, ax = axp) 180 | scale = -0.05 181 | axp.set_title('logp; '+str(round(pvalsneg[s],3))+'; numvar:'+str(int(cnum[s]))) 182 | axp.bar(dcounts[0][0],scale* dcounts[0][1], label = 'False', width = 0.4) 183 | bottom = np.zeros(len(dcounts[1][0])) 184 | bottom[np.isin(dcounts[1][0], 
dcounts[0][0])] = dcounts[0][1][np.isin(dcounts[0][0],dcounts[1][0])] 185 | axp.bar(dcounts[1][0], scale*dcounts[1][1], bottom = scale*bottom, label = 'Correct', width = 0.4) 186 | min0, min1 = 0,0 187 | if len(dcounts[0][1]) > 0: 188 | min0 = np.amin(scale*dcounts[0][1]) 189 | if len(dcounts[1][1]) > 0: 190 | min1 = np.amin(scale*dcounts[1][1]+bottom) 191 | 192 | axp.set_ylim([min(min0,min1)+scale, 1]) 193 | axp.set_yticks([min(min0,min1), 0.5,1]) 194 | axp.set_yticklabels([int(min(min0,min1)/scale), 0.5, 1]) 195 | if '--savefig' in sys.argv: 196 | figp.savefig(outname + '_clust'+str(clu)+'.jpg', dpi = 200, bbox_inches = 'tight') 197 | print(outname + '_clust'+str(clu)+'.jpg') 198 | else: 199 | plt.show() 200 | plt.close() 201 | 202 | -------------------------------------------------------------------------------- /enformer_analysis/plot_distribution_enformer_correlations.py: -------------------------------------------------------------------------------- 1 | # Plot the distribution of genes with a SuSie assigned SNP versus the ones without 2 | #python3 plot_distribution_enformer_correlations.py Full_analysis/Prediction_correlationsCageAdultBrain.txt susie_SNP_gene_CortexENSG.txt 3 | 4 | import numpy as np 5 | import sys, os 6 | import glob 7 | import matplotlib.pyplot as plt 8 | import seaborn as sns 9 | 10 | def vplot(corrs, ylabel = None, xlabel = None, names = None, outname = None): 11 | medians = [] 12 | quartile1 = [] 13 | quartile3 = [] 14 | for f, corr in enumerate(corrs): 15 | medians.append(np.median(corr)) 16 | quartile1.append(np.percentile(corr, 25)) 17 | quartile3.append(np.percentile(corr, 75)) 18 | fig = plt.figure(figsize = (len(corrs)*0.8,3.5)) 19 | ax = fig.add_subplot(111) 20 | ax.spines['top'].set_visible(False) 21 | ax.spines['right'].set_visible(False) 22 | sns.violinplot(data = corrs, ax = ax, width = 1., alphas = 0.5, palette = 'magma', colors = [0.1, 1.2], vmin = 0, vmax = 2, cut = 0) 23 | #sns.swarmplot(data = corrs, ax = ax, color = 
'k', size = 3, zorder = 2) 24 | #ax.vlines(np.arange(len(corrs)), quartile1, quartile3, color='silver', linestyle='-', lw=1, zorder = 3) 25 | #ax.scatter(np.arange(len(corrs)), medians, color = 'silver',zorder = 3) 26 | ax.set_xticks(np.arange(len(corrs))) 27 | if names is not None: 28 | ax.set_xticklabels(names) 29 | if ylabel is not None: 30 | ax.set_ylabel(ylabel) 31 | if xlabel is not None: 32 | ax.set_xlabel(xlabel) 33 | if outname is not None: 34 | fig.savefig(outname, transparent = True, dpi = 500, bbox_inches = 'tight') 35 | print(outname) 36 | else: 37 | plt.show() 38 | plt.close() 39 | 40 | 41 | if __name__ == '__main__': 42 | # Read correlation to observed data 43 | corr = np.genfromtxt(sys.argv[1], dtype = str) 44 | cnames, corr = corr[:,0], corr[:,1].astype(float) 45 | cnames, corr = cnames[~np.isnan(corr)], corr[~np.isnan(corr)] 46 | # Read list of gene set 47 | clist = np.genfromtxt(sys.argv[2], dtype = str)[:,0] 48 | 49 | corrs = [corr[np.isin(cnames, clist)], corr[~np.isin(cnames, clist)]] 50 | print('neg', np.sum(corrs[0]<0)) 51 | 52 | 53 | 54 | vplot(corrs, ylabel = 'R Enformer to obs.', names = ['Susie','Non-Susie'], outname = 'Distribution_EnformerCorrelation_susieset.jpg') 55 | vplot([np.absolute(corr) for corr in corrs], ylabel = 'Abs(R) Enformer to obs.', names = ['Susie','Non-Susie'], outname = 'Distribution_EnformerAbsCorrelation_susieset.jpg') 56 | 57 | 58 | -------------------------------------------------------------------------------- /enformer_analysis/plot_driver_motifstats.py: -------------------------------------------------------------------------------- 1 | # plots statistics of attributions around driver snps in comparison to global ism attributions 2 | #python3 ../plot_driver_motifstats.py ../ism_res/ALLgenes_ism_attributions_driversfw_types.txt ALLgenes_ism_attributions_driversfw_types_ism_significance_stats.txt 2,15 --absolute --combine_columns max --lim 10 --cumulative -1 --nbins 22 --savefig 
ISMstats_maxzscore_in_var_or_ref.jpg 3 | 4 | import numpy as np 5 | import sys, os 6 | import matplotlib.pyplot as plt 7 | 8 | # File with drivers for all genes 9 | drivers = np.genfromtxt(sys.argv[1], dtype = str) 10 | # File generated with extract_ism_stats_around_drivers.py 11 | motifs = np.genfromtxt(sys.argv[2], dtype = str) 12 | drnames = np.array([m[0]+m[1] for m in drivers]) 13 | motnames = np.array([m[0]+m[1] for m in motifs]) 14 | # sort files to match SNVs 15 | dsort = np.argsort(drnames)[np.isin(np.sort(drnames), motnames)] 16 | drnames, drivers = drnames[dsort], drivers[dsort] 17 | msort = np.argsort(motnames)[np.isin(np.sort(motnames), drnames)] 18 | motnames, motifs = motnames[msort], motifs[msort] 19 | 20 | # If selected, only main drivers will be selected 21 | if '--maindrivers' in sys.argv: 22 | mainfile = np.genfromtxt(sys.argv[sys.argv.index('--maindrivers')+1], dtype = str) 23 | mcol = int(sys.argv[sys.argv.index('--maindrivers')+2]) 24 | mains = [] 25 | for g, gen in enumerate(np.unique(mainfile[:,0])): 26 | mask = np.where(mainfile[:,0] == gen)[0] 27 | if len(mask) > 1: 28 | argm = np.argmax(mainfile[mask,mcol].astype(float)) 29 | mains.append(mainfile[mask[argm],0]+mainfile[mask[argm],1]) 30 | else: 31 | mains.append(mainfile[mask[0],0]+mainfile[mask[0],1]) 32 | mask = np.isin(drnames, mains) 33 | print('From', len(drnames)) 34 | drnames, drivers, motnames, motifs = drnames[mask], drivers[mask], motnames[mask], motifs[mask] 35 | print('to', len(drnames)) 36 | #print(drnames) 37 | 38 | # check if rows are aligned 39 | if np.array_equal(drnames, motnames): 40 | 41 | # determine driver type 42 | driverdir = drivers[:,-2].astype(float) 43 | drloc = drivers[:,1].astype(int) 44 | drivertype = [] 45 | for d, dr in enumerate(drivers): 46 | drivertype.append(np.sign(float(dr[-1])*float(dr[-2]))) 47 | drivers = drivers[:,0] 48 | drivertype = np.array(drivertype) 49 | 50 | # Potentially combine the entries of two columns, take the mean, max, min of 
these columns 51 | distarg = None 52 | if '--combine_columns' in sys.argv: 53 | column = np.array(sys.argv[3].split(','), dtype = int) 54 | ctype = sys.argv[sys.argv.index('--combine_columns')+1] 55 | dist = motifs[:, column].astype(float) 56 | if '--absolute' in sys.argv: 57 | dist = np.absolute(dist) 58 | if ctype == 'min': 59 | # Decide if you would like to split the histogram between data points taken from the ref and var sequence 60 | if '--splitrefvar' in sys.argv: 61 | distarg = np.argmin(dist, axis =1) 62 | dist = np.amin(dist, axis =1) 63 | elif ctype == 'max': 64 | if '--splitrefvar' in sys.argv: 65 | distarg = np.argmax(dist, axis =1) 66 | dist = np.amax(dist, axis =1) 67 | elif ctype == 'mean': 68 | dist = np.mean(dist, axis =1) 69 | 70 | else: 71 | column = int(sys.argv[3]) 72 | dist = motifs[:, column].astype(float) 73 | 74 | # transform the data more 75 | if '--absolute' in sys.argv: 76 | dist = np.absolute(dist) 77 | 78 | if '--scale' in sys.argv: 79 | dist *= float(sys.argv[sys.argv.index('--scale')+1]) 80 | 81 | if '--lim' in sys.argv: 82 | lim = int(sys.argv[sys.argv.index('--lim')+1]) 83 | dist[np.absolute(dist)>lim] = np.sign(dist[np.absolute(dist)>lim])*lim 84 | 85 | if '--log' in sys.argv: 86 | dist = np.sign(dist)*np.log10(1+np.absolute(dist)) 87 | if '--symlog' in sys.argv: 88 | dist[np.absolute(dist)>1] = np.sign(dist[np.absolute(dist)>1])*np.log10(10+np.absolute(dist[np.absolute(dist)>1])) 89 | 90 | # define umber of bins 91 | nbins = 31 92 | if '--nbins' in sys.argv: 93 | nbins = int(sys.argv[sys.argv.index('--nbins')+1]) 94 | rnge = np.amax(dist)-np.amin(dist) 95 | over = rnge/nbins/2 96 | if '--nooverhang' in sys.argv: 97 | bins = np.linspace(np.amin(dist), np.amax(dist),nbins) 98 | # determine limits of x axis 99 | elif '--xlim' in sys.argv: 100 | blim0, blim1 = sys.argv[sys.argv.index('--xlim')+1].split(',') 101 | bins = np.linspace(float(blim0), float(blim1), nbins) 102 | else: 103 | bins = np.linspace(np.amin(dist)-over, 
np.amax(dist)+over,nbins) 104 | print(bins) 105 | 106 | # define if histogram is cumulative 107 | cum = False 108 | if '--cumulative' in sys.argv: 109 | cum = int(sys.argv[sys.argv.index('--cumulative')+1]) 110 | 111 | # define if density or numbers are shown in histogram 112 | dens = True 113 | if '--numbers' in sys.argv: 114 | dens = False 115 | 116 | htype = 'bar' 117 | if '--histtype' in sys.argv: 118 | htype =sys.argv[sys.argv.index('--histtype')+1] 119 | 120 | # print some stats 121 | if '--printout' in sys.argv: 122 | for dt in [1,-1]: 123 | print('Drivertype', dt) 124 | for d in np.where(drivertype == dt)[0]: 125 | print(motifs[d,0], motifs[d,1], dist[d]) 126 | 127 | fig = plt.figure(figsize = (3.5, 3.5), dpi = 200) 128 | ax = fig.add_subplot(111) 129 | ax.spines['top'].set_visible(False) 130 | ax.spines['right'].set_visible(False) 131 | if distarg is not None: 132 | nc, bi = np.histogram(dist[(drivertype == 1)*(distarg == 0)], bins = bins, density = dens) 133 | nc = np.cumsum(nc[::cum])[::cum]*np.sum((drivertype == 1)*(distarg == 0))/np.sum(drivertype == 1) 134 | xc = np.mean(np.array([bi[:-1], bi[1:]]),axis = 0) 135 | ax.bar(xc, nc, alpha = 0.6, color = 'navy', label = 'Supportive ref',zorder = 0,width= 1) 136 | ncv, bi = np.histogram(dist[(drivertype == 1)*(distarg == 1)], bins = bins, density = dens) 137 | ncv = np.cumsum(ncv[::cum])[::cum]*np.sum((drivertype == 1)*(distarg == 1))/np.sum(drivertype == 1) 138 | xc = np.mean(np.array([bi[:-1], bi[1:]]),axis = 0) 139 | ax.bar(xc, ncv, alpha = 0.6, bottom = nc , color = 'slateblue', label = 'Supportive var',zorder = -2, width = 1) 140 | nc, bi = np.histogram(dist[(drivertype == -1)*(distarg == 0)], bins = bins, density = dens) 141 | nc = np.cumsum(nc[::cum])[::cum]*np.sum((drivertype == -1)*(distarg == 0))/np.sum(drivertype == -1) 142 | xc = np.mean(np.array([bi[:-1], bi[1:]]),axis = 0) 143 | ax.bar(xc, nc, alpha = 0.6, color = 'darkgoldenrod', label = 'Unsupportive ref',zorder = 1, width= 1) 144 | ncv, bi 
= np.histogram(dist[(drivertype == -1)*(distarg == 1)], bins = bins, density = dens) 145 | ncv = np.cumsum(ncv[::cum])[::cum]*np.sum((drivertype == -1)*(distarg == 1))/np.sum(drivertype == -1) 146 | xc = np.mean(np.array([bi[:-1], bi[1:]]),axis = 0) 147 | ax.bar(xc, ncv, alpha = 0.6, bottom = nc , color = 'goldenrod', label = 'Unsupportive var',zorder = -1, width = 1) 148 | 149 | else: 150 | nc, bi, t_ = ax.hist(dist[(drivertype == 1)], bins = bins, alpha = 0.6, color = 'navy', histtype = htype, density = dens, cumulative = cum, label = 'Supportive') 151 | print(nc) 152 | print(dist[(drivertype == 1)]) 153 | nc, bi, t_ = ax.hist(dist[(drivertype == -1)], bins = bins, alpha = 0.6, color = 'goldenrod', histtype =htype, density = dens, cumulative = cum, label = 'Unsupportive') 154 | print(nc) 155 | print(dist[(drivertype == -1)]) 156 | 157 | if '--xlabel' in sys.argv: 158 | ax.set_xlabel(sys.argv[sys.argv.index('--xlabel')+1]) 159 | if '--ylabel' in sys.argv: 160 | ax.set_ylabel(sys.argv[sys.argv.index('--ylabel')+1]) 161 | 162 | legloc = 'best' 163 | if '--legendpos' in sys.argv: 164 | legloc = sys.argv[sys.argv.index('--legendpos')+1] 165 | 166 | ax.legend(loc=legloc) 167 | if '--savefig' in sys.argv: 168 | fig.savefig(sys.argv[sys.argv.index('--savefig')+1], dpi = 250, bbox_inches = 'tight') 169 | 170 | else: 171 | plt.show() 172 | 173 | -------------------------------------------------------------------------------- /enformer_analysis/plot_drivercounts.py: -------------------------------------------------------------------------------- 1 | # plot the distribution of the number of drivers per investigated genes 2 | #python3 ../plot_drivercounts.py Counts_ism_attributions_driversfw.txt ../Prediction_correlationsCageAdultBrain.txt ALL_genes_ism_attributions_sumpersonal_mp_vs_Enformer_predictions_CageAdultBrain.txt --genelist ../Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list --split_sets 3 | 4 | import numpy as np 5 | 
import matplotlib.pyplot as plt 6 | import sys, os 7 | from functools import reduce 8 | 9 | # Read driver counts 10 | counts = np.genfromtxt(sys.argv[1], dtype = str) 11 | # Read correlation of predictions to observed 12 | predictions = np.genfromtxt(sys.argv[2], dtype = str) 13 | # Read correlation between prediction and linear approximation 14 | approx_control = np.genfromtxt(sys.argv[3], dtype = str) 15 | 16 | # sort files to match rows 17 | commons = reduce(np.intersect1d, [counts[:,0], predictions[:,0], approx_control[:,0]]) 18 | a_, sort = np.unique(approx_control[:,0], return_index = True) 19 | sort = sort[np.isin(a_, commons)] 20 | approx_control = approx_control[sort] 21 | a_, sort = np.unique(predictions[:,0], return_index = True) 22 | sort = sort[np.isin(a_, commons)] 23 | predictions = predictions[sort] 24 | a_, sort = np.unique(counts[:,0], return_index = True) 25 | sort = sort[np.isin(a_, commons)] 26 | counts = counts[sort] 27 | 28 | # Check if sorting was successful 29 | if not np.array_equal(counts[:,0], approx_control[:,0]) or not np.array_equal(counts[:,0], predictions[:,0]): 30 | print('Files not alinged') 31 | print(np.array_equal(counts[:,0], approx_control[:,0]), np.array_equal(counts[:,0], predictions[:,0]), len(predictions)) 32 | sys.exit() 33 | 34 | # select subset of genes 35 | if '--genelist' in sys.argv: 36 | genelist = np.genfromtxt(sys.argv[sys.argv.index('--genelist')+1], dtype =str) 37 | mask = np.isin(counts[:,0], genelist) 38 | counts, predictions, approx_control = counts[mask], predictions[mask], approx_control[mask] 39 | 40 | control = approx_control[:,1].astype(float) > 0.2 41 | counts, predictions = counts[control], predictions[control] 42 | 43 | 44 | # split into positively and negatively correlated genes 45 | if '--split_sets' in sys.argv: 46 | pos = predictions[:,1].astype(float) > 0.2 47 | neg = predictions[:,1].astype(float) < -0.2 48 | else: 49 | pos = np.ones(len(predictions)) ==1 50 | neg = np.ones(len(predictions)) 
==1 51 | 52 | 53 | 54 | fig = plt.figure(figsize = (8,4)) 55 | axp = fig.add_subplot(121) 56 | bins = np.arange(np.amax(counts[pos|neg,1].astype(float))+1)-0.5 57 | a,b,c = axp.hist(counts[pos,1].astype(float), color = 'navy', alpha = 0.5, bins = bins) 58 | print(a) 59 | print(b) 60 | if '--split_sets' in sys.argv: 61 | axp.set_title('Positive correlated') 62 | axp.spines['top'].set_visible(False) 63 | axp.spines['right'].set_visible(False) 64 | axp.set_xlabel('Number drivers') 65 | axp.set_ylabel('Number genes') 66 | axp.set_xlim([0,20]) 67 | 68 | if '--split_sets' in sys.argv: 69 | axn = fig.add_subplot(122) 70 | a,b,c = axn.hist(counts[neg,1].astype(float), color = 'navy', alpha = 0.5, bins = bins) 71 | axn.set_title('Negative correlated') 72 | axn.spines['top'].set_visible(False) 73 | axn.spines['right'].set_visible(False) 74 | axn.set_xlabel('Number drivers') 75 | axn.set_ylabel('Number genes') 76 | axn.set_xlim([0,20]) 77 | 78 | 79 | figc = plt.figure(figsize = (8,4)) 80 | axpc = figc.add_subplot(121) 81 | 82 | a_, b_, c_ = axpc.hist(counts[pos,1].astype(float), color = 'navy', density = True, bins = bins, cumulative = 1, alpha = 0.5) 83 | print(a_) 84 | print(b_) 85 | if '--split_sets' in sys.argv: 86 | axpc.set_title('Positive correlated') 87 | axpc.spines['top'].set_visible(False) 88 | axpc.spines['right'].set_visible(False) 89 | axpc.set_xlabel('Number drivers') 90 | axpc.set_ylabel('Cumulative Number genes') 91 | axpc.set_xlim([0,20]) 92 | axpc.grid() 93 | 94 | if '--split_sets' in sys.argv: 95 | axnc = figc.add_subplot(122) 96 | axnc.hist(counts[neg,1].astype(float), color = 'navy', density = True, bins = bins, cumulative = 1, alpha = 0.5) 97 | axnc.set_title('Negative correlated') 98 | axnc.spines['top'].set_visible(False) 99 | axnc.spines['right'].set_visible(False) 100 | axnc.set_xlabel('Number drivers') 101 | axnc.set_ylabel('Cumulative Number genes') 102 | axnc.set_xlim([0,20]) 103 | axnc.grid() 104 | 105 | 106 | 
fig.savefig(os.path.splitext(sys.argv[1])[0]+'_distribution.jpg', dpi = 250, bbox_inches = 'tight') 107 | figc.savefig(os.path.splitext(sys.argv[1])[0]+'_cumdistribution.jpg', dpi = 250, bbox_inches = 'tight') 108 | print(os.path.splitext(sys.argv[1])[0]+'_distribution.jpg') 109 | 110 | -------------------------------------------------------------------------------- /enformer_analysis/plot_driverdistance.py: -------------------------------------------------------------------------------- 1 | # Plot histogram of the distances between drivers and the TSS 2 | # 3 | #python3 ../plot_driverdistance.py DistancetoTSS_ism_attributions_driversbw.txt ../Prediction_correlationsCageAdultBrain.txt ALL_genes_ism_attributions_sumpersonal_mp_vs_Enformer_predictions_CageAdultBrain.txt --genelist ../Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list --plot_main_in_all 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import sys, os 8 | from functools import reduce 9 | 10 | # Read Distances to TSS 11 | counts = np.genfromtxt(sys.argv[1], dtype = str) 12 | # Read correletions of predictions 13 | predictions = np.genfromtxt(sys.argv[2], dtype = str) 14 | # Read in correlation of sum of attributions to prediction with full model 15 | approx_control = np.genfromtxt(sys.argv[3], dtype = str) 16 | 17 | # sort files to match 18 | commons = np.sort(reduce(np.intersect1d, [counts[:,0], predictions[:,0], approx_control[:,0]])) 19 | a_, sort = np.unique(approx_control[:,0], return_index = True) 20 | sort = sort[np.isin(a_, commons)] 21 | approx_control = approx_control[sort] 22 | a_, sort = np.unique(predictions[:,0], return_index = True) 23 | sort = sort[np.isin(a_, commons)] 24 | predictions = predictions[sort] 25 | counts = counts[np.argsort(counts[:,0])] 26 | 27 | # check if sorting worked 28 | print(np.array_equal(predictions[:,0], approx_control[:,0]), np.array_equal(commons, predictions[:,0]), len(predictions)) 29 | 30 | # 
Select subset of genes 31 | if '--genelist' in sys.argv: 32 | genelist = np.genfromtxt(sys.argv[sys.argv.index('--genelist')+1], dtype =str) 33 | mask = np.isin(predictions[:,0], genelist) 34 | predictions, approx_control, commons = predictions[mask], approx_control[mask], commons[mask] 35 | 36 | # Only work with genes that can be linearly approximated 37 | control = approx_control[:,1].astype(float) > 0.2 38 | approx_control, predictions, commons = approx_control[control], predictions[control], commons[control] 39 | print('Commons fine', np.array_equal(commons, approx_control[:,0])) 40 | 41 | counts = counts[np.isin(counts[:,0], approx_control[:,0])] 42 | bcountmask = [] 43 | for c, com in enumerate(commons): 44 | isgene = np.where(counts[:,0] == com)[0] 45 | if len(isgene) > 0: 46 | bcountmask.append(isgene[np.argmax(counts[isgene,-1].astype(float))]) 47 | bcounts = counts[bcountmask] 48 | 49 | # Split genes in genes with postive and negative correlation to observed data 50 | bpos = predictions[:,1].astype(float) > 0.2 51 | bneg = predictions[:,1].astype(float) < -0.2 52 | pos = np.isin(counts[:,0], commons[bpos]) 53 | neg = np.isin(counts[:,0], commons[bneg]) 54 | 55 | 56 | # Generate histograms 57 | fig = plt.figure(figsize = (6,4)) 58 | axp = fig.add_subplot(121) 59 | axn = fig.add_subplot(122) 60 | 61 | bins = np.concatenate([-np.arange(10000,10000,10000)[::-1],-np.arange(2000,10000,2000)[::-1],-np.arange(500,2000,250)[::-1],-np.arange(0,500,100)[::-1], np.arange(0,500,100),np.arange(500,2000,250),np.arange(2000,10000,2000), np.arange(10000,10000,10000)]) 62 | 63 | bins = np.arange(-70000,70000,1000) 64 | 65 | a,b,c = axp.hist(counts[pos,2].astype(float), color = 'forestgreen', alpha = 0.9, bins = bins) 66 | axp.set_title('Positively correlated genes') 67 | axp.spines['top'].set_visible(False) 68 | axp.spines['right'].set_visible(False) 69 | axp.set_xlabel('Distance to TSS') 70 | axp.set_ylabel('Number drivers') 71 | if '--plot_main_in_all' in sys.argv: 72 | 
axp.hist(bcounts[bpos,2].astype(float), color = 'orange', bins = bins, alpha = 0.9, label = 'Main\ndrivers') 73 | axp.legend(loc = 'upper right', fontsize = 'small') 74 | axp.plot([0,0],[0,np.amax(a)], c = 'goldenrod', ls = '--') 75 | #axp.set_xscale('symlog') 76 | 77 | 78 | a,b,c = axn.hist(counts[neg,2].astype(float), color = 'indigo', alpha = 0.9, bins = bins) 79 | axn.set_title('Negatively correlated genes') 80 | axn.spines['top'].set_visible(False) 81 | axn.spines['right'].set_visible(False) 82 | axn.set_xlabel('Distance to TSS') 83 | axn.set_ylabel('Number drivers') 84 | 85 | # Show main drivers in different color 86 | if '--plot_main_in_all' in sys.argv: 87 | axn.hist(bcounts[bneg,2].astype(float), color = 'mediumvioletred', bins = bins, alpha = 0.9, label = 'Main\ndrivers') 88 | axn.legend(loc = 'upper right', fontsize = 'small') 89 | axn.plot([0,0],[0,np.amax(a)], c = 'goldenrod', ls = '--') 90 | #axn.set_xscale('symlog') 91 | 92 | 93 | # Generate extra figure main drivers 94 | figc = plt.figure(figsize = (6,4)) 95 | axpc = figc.add_subplot(121) 96 | axnc = figc.add_subplot(122) 97 | 98 | axpc.hist(bcounts[bpos,2].astype(float), color = 'forestgreen', bins = bins, alpha = 0.5) 99 | axpc.set_title('Positively correlated genes') 100 | axpc.spines['top'].set_visible(False) 101 | axpc.spines['right'].set_visible(False) 102 | axpc.set_xlabel('Distance to TSS') 103 | axpc.set_ylabel('Number main drivers') 104 | 105 | axnc.hist(bcounts[bneg,2].astype(float), color = 'indigo', bins = bins, alpha = 0.5) 106 | axnc.set_title('Negatively correlated genes') 107 | axnc.spines['top'].set_visible(False) 108 | axnc.spines['right'].set_visible(False) 109 | axnc.set_xlabel('Distance to TSS') 110 | axnc.set_ylabel('Number main drivers') 111 | 112 | fig.savefig(os.path.splitext(sys.argv[1])[0]+'_distribution.jpg', dpi = 450, bbox_inches = 'tight') 113 | axp.set_xlim([-19999,19999]) 114 | axn.set_xlim([-19999,19999]) 115 | 
fig.savefig(os.path.splitext(sys.argv[1])[0]+'_distributionzoom.jpg', dpi = 450, bbox_inches = 'tight') 116 | figc.savefig(os.path.splitext(sys.argv[1])[0]+'_maindistribution.jpg', dpi = 450, bbox_inches = 'tight') 117 | 118 | -------------------------------------------------------------------------------- /enformer_analysis/plot_drivertype.py: -------------------------------------------------------------------------------- 1 | # plots fraction of drivertypes per gene against correlation to observed data 2 | #python3 ../plot_drivertype.py ALLgenes_ism_attributions_driversfw_types.txt ../Prediction_correlationsCageAdultBrain.txt ALL_genes_ism_attributions_sumpersonal_mp_vs_Enformer_predictions_CageAdultBrain.txt --genelist ../Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats_set0.2.list -weighted ALLgenes_ism_attributions_driversfw.txt -3 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import sys, os 7 | from functools import reduce 8 | import seaborn as sns 9 | from scipy.stats import mannwhitneyu as mwu 10 | 11 | 12 | # Read combined drivertype file that contains drivers for all genes 13 | counts = np.genfromtxt(sys.argv[1], dtype = str) 14 | allctypes = counts[:,-2:].astype(float) 15 | bitypes = np.zeros((len(counts),2)) 16 | # Assign is-supported and is unsupported driver 17 | for c, ctypes in enumerate(allctypes): 18 | if (ctypes[0]<0 and ctypes[1]<0) or (ctypes[0]>0 and ctypes[1]>0): 19 | bitypes[c,0] = 1 20 | elif ctypes[0]>0 and ctypes[1]<0: 21 | bitypes[c,1] = 1 22 | elif ctypes[0]<0 and ctypes[1]>0: 23 | bitypes[c,1] = 1 24 | # Read prediction values 25 | predictions = np.genfromtxt(sys.argv[2], dtype = str) 26 | # Only investigate genes that can be approximated with the sum of ISMs > 0.2 to prediction 27 | approx_control = np.genfromtxt(sys.argv[3], dtype = str) 28 | 29 | # Sort files to align 30 | commons = np.sort(reduce(np.intersect1d, [counts[:,0], predictions[:,0], approx_control[:,0]])) 31 | a_, 
sort = np.unique(approx_control[:,0], return_index = True) 32 | sort = sort[np.isin(a_, commons)] 33 | approx_control = approx_control[sort] 34 | a_, sort = np.unique(predictions[:,0], return_index = True) 35 | sort = sort[np.isin(a_, commons)] 36 | predictions = predictions[sort] 37 | 38 | # Select subset of genes 39 | if '--genelist' in sys.argv: 40 | genelist = np.genfromtxt(sys.argv[sys.argv.index('--genelist')+1], dtype = str) 41 | print(np.array_equal(approx_control[:,0], commons)) 42 | mask = np.isin(commons, genelist) 43 | approx_control, predictions, commons = approx_control[mask], predictions[mask], commons[mask] 44 | 45 | # Weight drivers by the attribution to the prediction for mean within each gene 46 | weights = np.ones(len(counts)) 47 | if '--weighted' in sys.argv: 48 | wfile = np.genfromtxt(sys.argv[sys.argv.index('--weighted')+1], dtype = str) 49 | wl = int(sys.argv[sys.argv.index('--weighted')+2]) 50 | cnames = np.array([cnt[0]+'_'+cnt[1] for cnt in counts]) 51 | wnames = np.array([cnt[0]+'_'+cnt[1] for cnt in wfile]) 52 | csort = np.argsort(cnames)[np.isin(np.sort(cnames), wnames)] 53 | counts, types, bitypes = counts[csort], types[csort], bitypes[csort] 54 | weights = wfile[np.argsort(wnames)[np.isin(np.sort(wnames), cnames)],wl].astype(float) 55 | 56 | # Generate mean per gene from individual drivers 57 | bipgcounts = np.zeros((len(commons),2)) 58 | keep = np.zeros(len(commons)) == 0 59 | for c, com in enumerate(commons): 60 | if com in counts[:,0]: 61 | cmask = counts[:,0]==com 62 | bipgcounts[c] = np.sum(weights[cmask][:,None]*bitypes[cmask] , axis = 0)/np.sum(weights[cmask]) 63 | else: 64 | keep[c] = False 65 | 66 | approx_control, predictions, commons, bipgcounts = approx_control[keep], predictions[keep], commons[keep], bipgcounts[keep] 67 | 68 | #Only investigate genes that can be approximated with the sum of ISMs > 0.2 to prediction 69 | control = approx_control[:,1].astype(float) > 0.2 70 | approx_control, predictions, commons, pgcounts, 
bipgcounts = approx_control[control], predictions[control], commons[control], pgcounts[control], bipgcounts[control] 71 | 72 | print('Commons fine', np.array_equal(commons, approx_control[:,0])) 73 | 74 | # Divide sets into sets with significant positive and negative correlation 75 | bpos = predictions[:,1].astype(float) > 0.2 76 | bneg = predictions[:,1].astype(float) < -0.2 77 | 78 | # Assign p-values to differences between two sets and replace value with * symbols 79 | pvalcut = np.array([0.05, 0.01,0.001]) 80 | pvalsign = np.array(['*', '**', '***']) 81 | 82 | # Print some stats 83 | print('Positive') 84 | for bpo in np.where(bpos)[0]: 85 | print(commons[bpo], bipgcounts[bpo]) 86 | 87 | print('\nNegative') 88 | for bpo in np.where(bneg)[0]: 89 | print(commons[bpo], bipgcounts[bpo]) 90 | 91 | # Generate figures for fraction of supported and unsupported drivers 92 | muttype = ['supported', 'unsupported'] 93 | for t, mut in enumerate(muttype): 94 | fig = plt.figure(figsize = (1.5,4)) 95 | ax = fig.add_subplot(111) 96 | parts = ax.violinplot([bipgcounts[bpos,t], bipgcounts[bneg,t]], widths = 0.95, showmeans=False, showmedians=False, showextrema=True) 97 | fcolors = ['forestgreen', 'indigo'] 98 | for p, pc in enumerate(parts['bodies']): 99 | pc.set_facecolor(fcolors[p]) 100 | pc.set_edgecolor('black') 101 | pc.set_alpha(1) 102 | for partname in ('cbars','cmins','cmaxes'): 103 | vp = parts[partname] 104 | vp.set_edgecolor('k') 105 | quartile1, medians, quartile3 = [np.percentile(bipgcounts[bpos,t], [25]),np.percentile(bipgcounts[bneg,t], [25])], [np.percentile(bipgcounts[bpos,t], [50]),np.percentile(bipgcounts[bneg,t], [50])],[np.percentile(bipgcounts[bpos,t], [75]),np.percentile(bipgcounts[bneg,t], [75])] 106 | ax.scatter([1,2], medians, marker='o', color='grey', s=30, zorder=3) 107 | ax.vlines([1,2], quartile1, quartile3, color='k', linestyle='-', lw=10) 108 | pval = mwu(bipgcounts[bpos,t], bipgcounts[bneg,t])[1]*len(muttype)/2 109 | ps = np.where(pvalcut>pval)[0] 
110 | if len(ps) > 0: 111 | ax.plot([1,2], [1.1, 1.1], c = 'k') 112 | ax.plot([1,1], [1.07,1.1], c = 'k') 113 | ax.plot([2,2], [1.07, 1.1], c = 'k') 114 | ax.text(1.5, 1.11, pvalsign[ps[-1]], ha = 'center', va = 'bottom') 115 | ax.spines['top'].set_visible(False) 116 | ax.spines['right'].set_visible(False) 117 | ax.set_xticks([1,2]) 118 | ax.set_xticklabels(['Positive', 'Negative'], rotation = 45) 119 | ax.set_xlabel('Correlation to Enformer') 120 | ax.set_ylabel('Fraction of '+mut+' drivers') 121 | fig.savefig(os.path.splitext(sys.argv[1])[0]+'_boxplot_negvspos'+mut+'.svg', transparent = True, dpi = 450, bbox_inches = 'tight') 122 | 123 | 124 | figx = plt.figure(figsize = (3.,3.5)) 125 | axx = figx.add_subplot(111) 126 | axx.spines['top'].set_visible(False) 127 | axx.spines['right'].set_visible(False) 128 | axx.plot([0,0],[0,1], color = 'grey', zorder = -1) 129 | axx.scatter(predictions[bpos,1].astype(float), bipgcounts[bpos,t], color = 'forestgreen') 130 | axx.scatter(predictions[bneg,1].astype(float), bipgcounts[bneg,t], color = 'indigo') 131 | axx.scatter(predictions[~bneg*~bpos,1].astype(float), bipgcounts[~bneg*~bpos,t], color = 'grey') 132 | axx.set_xlabel('Correlation to Enformer') 133 | axx.set_ylabel('Fraction of '+mut+' drivers') 134 | 135 | figx.savefig(os.path.splitext(sys.argv[1])[0]+'_scatter_negvspos'+mut+'.jpg', transparent = True, dpi = 450, bbox_inches = 'tight') 136 | 137 | 138 | -------------------------------------------------------------------------------- /enformer_analysis/plot_individual_gene_scatter.py: -------------------------------------------------------------------------------- 1 | # Plots scatter plot between observed and predicted expression values across same set of individuals for a single gene 2 | # $python plot_individual_gene_scatter.py Observed_gene_expression.txt Enformer_predictions.txt DDX11 --figsize 4 3 3 | 4 | import numpy as np 5 | import sys, os 6 | from scipy.stats import pearsonr 7 | import matplotlib.pyplot as plt 
8 | 9 | def read(file): 10 | genes, indv, exp = [],[],[] 11 | for l, line in enumerate(open(file,'r').readlines()): 12 | if l == 0: 13 | genes = line.strip().split('\t') 14 | else: 15 | line = line.strip().split('\t') 16 | indv.append(line[0]) 17 | exp.append(line[1:]) 18 | return np.array(genes), np.array(indv), np.array(exp, dtype = float) 19 | 20 | # Read observed expression values 21 | obgenes, obindv, obexp = read(sys.argv[1]) 22 | # Read predicted expression values 23 | enfgenes, enfindv, enfexp = read(sys.argv[2]) 24 | # Define gene (column) 25 | chg = sys.argv[3] 26 | 27 | # Sort enfgenes and obgenes and the data matrices to match 28 | e_, s1a = np.unique(enfgenes, return_index = True) 29 | o_, s2a = np.unique(obgenes, return_index = True) 30 | s1a, s1b = s1a[np.isin(e_, obgenes)], np.argsort(enfindv)[np.isin(np.sort(enfindv), obindv)] 31 | s2a, s2b = s2a[np.isin(o_, enfgenes)], np.argsort(obindv)[np.isin(np.sort(obindv), enfindv)] 32 | enfgenes, enfindv, enfexp = enfgenes[s1a], enfindv[s1b], enfexp[s1b][:,s1a] 33 | obgenes, obindv, obexp = obgenes[s2a], obindv[s2b], obexp[s2b][:,s2a] 34 | 35 | # Check if rows and columns match from sorting 36 | print(np.array_equal(enfgenes, obgenes), np.array_equal(enfindv, obindv), np.shape(enfexp), np.shape(obexp)) 37 | 38 | figsize = (4,3.8) 39 | if '--figsize' in sys.argv: 40 | figsize = (float(sys.argv[sys.argv.index('--figsize')+1]), float(sys.argv[sys.argv.index('--figsize')+2])) 41 | 42 | 43 | # Select column of gene 44 | ic = list(enfgenes).index(chg) 45 | fig = plt.figure(figsize=figsize, dpi = 200) 46 | ax = fig.add_subplot(111) 47 | r = round(pearsonr(obexp[:,ic], enfexp[:,ic])[0], 2) 48 | ax.scatter(obexp[:,ic], enfexp[:,ic], alpha = 0.8, label = 'R='+str(r),color = 'darkred') 49 | ax.spines['top'].set_visible(False) 50 | ax.spines['right'].set_visible(False) 51 | ax.set_xlabel('Observed') 52 | ax.set_ylabel('Predicted') 53 | ax.legend() 54 | ax.set_title(chg) 55 | if '--setxlim' in sys.argv: 56 | 
ax.set_xlim(np.array(sys.argv[sys.argv.index('--setxlim')+1].split(','), dtype = float)) 57 | 58 | if '--setylim' in sys.argv: 59 | ax.set_ylim(np.array(sys.argv[sys.argv.index('--setylim')+1].split(','), dtype = float)) 60 | 61 | fig.savefig(chg+'scatter.jpg', dpi = 350, bbox_inches = 'tight') 62 | if '--show' in sys.argv: 63 | plt.show() 64 | 65 | 66 | -------------------------------------------------------------------------------- /enformer_analysis/plot_refandvar_attribution.py: -------------------------------------------------------------------------------- 1 | # Plot the mean attribution and attribution for base to base changes in a specific window 2 | #python3 ../plot_refandvar_attribution.py ENSG00000134202_109741038_attribution.npz --squaresize 0.12 --include_snvs 109741038 ../ism_res/ENSG00000134202_ism_attributions.txt ../eqtl/ENSG00000134202_corr.txt --markdrivers ../ism_res/ENSG00000134202_ism_attributions_driversfw.txt --dpi 350 --include_conservation ../PhyloP100/ENSG00000134202_1000tss_in_hg38.phyloP100way.txt --excludeheatmap --enlargepwm 1.8 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import logomaker as lm 7 | import sys, os 8 | import pandas as pd 9 | 10 | nts = list('ACGT') 11 | ntsar = np.array(nts) 12 | 13 | # Load npz with attributions 14 | f = np.load(sys.argv[1]) 15 | outname = os.path.splitext(sys.argv[1])[0] 16 | gene, locge = os.path.split(outname)[1].split('_')[:2] 17 | attribution, seq = f['attribution'], str(f['seq']) 18 | ohseq = np.array(list(seq))[:,None] == ntsar 19 | 20 | # Norm attribuitons with external values, f.e. 
the max or a std 21 | if '--norm_attributions' in sys.argv: 22 | norms = np.genfromtxt(sys.argv[sys.argv.index('--norm_attributions')+1], dtype = str) 23 | ngene = list(norms[:,0]).index(sys.argv[sys.argv.index('--norm_attributions')+2]) 24 | norm = float(norms[ngene,int(sys.argv[sys.argv.index('--norm_attributions')+3])]) 25 | attribution/=norm 26 | maxg = np.amax(np.absolute(attribution)) 27 | 28 | # adjust figsize 29 | if '--figsize' in sys.argv: 30 | l,h = sys.argv[sys.argv.index('--figsize')+1].split(',') 31 | fig = plt.figure(figsize = (float(l), float(h))) 32 | sqsize = float(l)/len(attribution) 33 | # or adjust size per unit of attribution 34 | elif '--squaresize' in sys.argv: 35 | sqsize = float(sys.argv[sys.argv.index('--squaresize')+1]) 36 | fig = plt.figure(figsize = (sqsize*len(attribution),sqsize*9)) 37 | else: 38 | fig = plt.figure(figsize = (0.4*len(attribution),3.5)) 39 | sqsize = 0.4 40 | 41 | # Plot the attributions of individual SNVs into the mean attribution plot 42 | if '--include_snvs' in sys.argv: 43 | center = int(sys.argv[sys.argv.index('--include_snvs')+1]) 44 | snvs = np.genfromtxt(sys.argv[sys.argv.index('--include_snvs')+2], dtype = str) 45 | freqs = np.genfromtxt(sys.argv[sys.argv.index('--include_snvs')+3], dtype = str) 46 | loc = snvs[:,0].astype(int)-center + int(len(attribution)/2) 47 | lmask = (loc < len(attribution))*(loc >=0 ) 48 | loc, snvatt, snvs = loc[lmask], snvs[lmask,1].astype(float), snvs[lmask] 49 | fmask = (freqs[:,0].astype(int)-center + int(len(attribution)/2) < len(attribution)) * (freqs[:,0].astype(int)-center + int(len(attribution)/2) >= 0) 50 | if not np.array_equal(freqs[fmask, 0], snvs[:,0]): 51 | print('Freqs and attributions dont match') 52 | print(freqs[fmask, 0], snvs[:,0]) 53 | sys.exit() 54 | freqs = np.nan_to_num(freqs[fmask,1].astype(float)) 55 | if '--norm_attributions' in sys.argv: 56 | snvatt /=norm 57 | # also want frequency and common variant to show in the heatmap 58 | edgecolors = 
np.chararray(len(snvatt), itemsize = 10, unicode = True) 59 | edgecolors[:] = 'k' 60 | msizes = np.ones(len(snvatt)) * sqsize * 450 61 | if '--markdrivers' in sys.argv: 62 | driverfile = open(sys.argv[sys.argv.index('--markdrivers')+1], 'r').readlines() 63 | if len(driverfile) > 0: 64 | drivers = np.array([line.strip().split() for line in driverfile]) 65 | maindriver = np.argmax(drivers[:,-3].astype(float)) 66 | driversinside = np.where(np.isin(drivers[:,0], snvs[:,0]))[0] 67 | if len(driversinside) > 0: 68 | edgecolors[np.isin(snvs[:,0], drivers[:,0])] = 'tomato' 69 | msizes[np.isin(snvs[:,0], drivers[:,0])] *= 2.5 70 | if maindriver in driversinside: 71 | edgecolors[np.isin(snvs[:,0], drivers[maindriver,0])] = 'r' 72 | # Include a plot with conservation values 73 | naxes = 2 74 | if '--include_conservation' in sys.argv: 75 | consfile = np.genfromtxt(sys.argv[sys.argv.index('--include_conservation')+1], dtype = str) 76 | clocs, cons = consfile[:,1].astype(int)-center+ int(len(attribution)/2), consfile[:,2].astype(float) 77 | clmask = (clocs>=0)*(clocs1.3] = 'blue' 90 | cons[cons < -0.5] = -0.5 91 | conspwm = np.zeros(np.shape(attribution)) 92 | conspwm[ohseq] = cons 93 | conspwm[conspwm<1.3] = 0 94 | cons[cons>=1.3] = 0 95 | axcons.bar(clocs, cons, color = colors) 96 | conspwm = pd.DataFrame({'A':conspwm[:,0],'C':conspwm[:,1], 'G':conspwm[:,2], 'T':conspwm[:,3]}) 97 | lm.Logo(conspwm, ax = axcons) 98 | axcons.plot([0,len(attribution)], [0,0], color = 'grey') 99 | axcons.set_xlim([-0.5,len(attribution)-0.5]) 100 | axcons.set_ylabel('PhyloP') 101 | 102 | # Exclude the heatmap that shows attribuitons for individual base to base changes 103 | if not '--excludeheatmap' in sys.argv: 104 | ax = fig.add_subplot(naxes,1,naxes) 105 | hm = ax.imshow(attribution.T, aspect = 'auto', vmin = -maxg, vmax = maxg, cmap = 'RdBu_r') 106 | ax.set_xticks(np.arange(len(attribution))) 107 | ax.set_yticks(np.arange(len(nts))) 108 | ax.set_yticklabels(nts) 109 | 
ax.set_xticklabels(list(seq)) 110 | ax.set_xticks(np.arange(len(attribution))+0.5, minor = True) 111 | ax.set_yticks(np.arange(len(nts))+0.5, minor = True) 112 | ax.grid(which = 'minor', color = 'k') 113 | #fig.colorbar(hm, pad = 0., fraction = 0.09, shrink = 0.15, aspect = 2, anchor = (0.,0.99), ax = ax) 114 | ax.set_ylabel('ISM') 115 | 116 | 117 | axpwm = fig.add_subplot(naxes, 1, 1) 118 | # Increase the height of the mean attribuition that is shown as a pwm 119 | if '--enlargepwm' in sys.argv: 120 | pos1 = axpwm.get_position() 121 | enl = float(sys.argv[sys.argv.index('--enlargepwm')+1]) 122 | axpwm.set_position([pos1.x0,pos1.y0,pos1.width,enl]) 123 | axpwm.set_title(gene+' - '+locge) 124 | axpwm.spines['top'].set_visible(False) 125 | axpwm.spines['right'].set_visible(False) 126 | axpwm.spines['bottom'].set_visible(False) 127 | axpwm.tick_params(bottom = False, labelbottom = False) 128 | pwm = np.zeros(np.shape(attribution)) 129 | pwm[ohseq] = -np.mean(attribution, axis = 1) 130 | axpwm.set_ylabel('Mean Attribution') 131 | pwm = pd.DataFrame({'A':pwm[:,0],'C':pwm[:,1], 'G':pwm[:,2], 'T':pwm[:,3]}) 132 | lm.Logo(pwm, ax = axpwm) 133 | ylim = axpwm.get_ylim() 134 | if '--include_snvs' in sys.argv: 135 | axpwm.vlines(loc, np.zeros(len(snvatt)), snvatt, color = 'k', zorder = -1) 136 | axpwm.scatter(loc, snvatt, cmap = 'bwr', c = freqs, vmin = -0.5, vmax =0.5, marker = 'o', s = msizes, edgecolor = edgecolors, lw = 1.5) 137 | axpwm.set_ylim([min(ylim[0],np.amin(snvatt))*1.08,max(ylim[1],np.amax(snvatt))*1.08]) 138 | 139 | axpwm.plot([int(len(attribution)/2), int(len(attribution)/2)],ylim, c= 'darkgoldenrod', ls = '--') 140 | if '--setylimpwm' in sys.argv: 141 | ylim = sys.argv[sys.argv.index('--setylimpwm')+1].split(',') 142 | axpwm.set_ylim([float(ylim[0]), float(ylim[1])]) 143 | 144 | dpi = 150 145 | if '--dpi' in sys.argv: 146 | dpi = float(sys.argv[sys.argv.index('--dpi')+1]) 147 | 148 | fig.savefig(outname+'.jpg', dpi = dpi, bbox_inches = 'tight') 149 | 
print(outname+'.jpg') 150 | #plt.show() 151 | 152 | 153 | 154 | -------------------------------------------------------------------------------- /enformer_analysis/predict_with_augmented_data.py: -------------------------------------------------------------------------------- 1 | # Test effect of data augmentation (shifts and reverse complement) on Enformer output 2 | 3 | import time 4 | import pandas as pd 5 | import numpy as np 6 | import sparse 7 | import os 8 | 9 | import tensorflow as tf 10 | import tensorflow_hub as hub 11 | 12 | import time as time 13 | 14 | import kipoiseq 15 | from kipoiseq import Interval 16 | 17 | save_path = '/data/aspiro17/enformer_res/' 18 | model_path = 'https://tfhub.dev/deepmind/enformer/1' 19 | data_path = '/data/mostafavilab/bng/rosmapAD/projects/insilicoMutagenesis/extractSequence/results/sequence100K/' 20 | 21 | track_idx = 4980 # CAGE:brain, adult, 22 | center_bins = [447,448,449] # three center bins of enformer output 23 | 24 | save_path = '/data/aspiro17/enformer_res/' 25 | model_path = 'https://tfhub.dev/deepmind/enformer/1' 26 | 27 | padded_input_len = 393216 # full input to enformer 28 | input_len = 196608 29 | starting_seq_len = 200001 # starting length of reference seq 30 | pad_before = int((padded_input_len - input_len)/2) # "pad" sequence to padded_input_len 31 | 32 | mid_index = int((starting_seq_len-1)/2) 33 | start_index = int(mid_index - (input_len-1)/2) 34 | end_index = int(mid_index + (input_len-1)/2) + 1 35 | 36 | shifts = [-3, -2, -1, 0, 1, 2, 3] # how to shift sequence 37 | 38 | # enformer helper functs 39 | class Enformer: 40 | 41 | def __init__(self, tfhub_url): 42 | self._model = hub.load(tfhub_url).model 43 | 44 | def predict_on_batch(self, inputs): 45 | predictions = self._model.predict_on_batch(inputs) 46 | return {k: v.numpy() for k, v in predictions.items()} 47 | 48 | 49 | # other helper functs 50 | 51 | def pad_genes(pat_current_gene, mat_current_gene,shift=0): 52 | # pad genes to create a shift 53 | 
54 | padded_pat_current_gene = np.pad(pat_current_gene, pad_width=((0,0),(pad_before-start_index+shift, pad_before-(starting_seq_len - end_index)-shift), (0,0))) 55 | padded_mat_current_gene = np.pad(mat_current_gene, pad_width=((0,0),(pad_before-start_index+shift, pad_before-(starting_seq_len - end_index)-shift), (0,0))) 56 | 57 | return padded_pat_current_gene, padded_mat_current_gene 58 | 59 | 60 | def rev_comp(pat_current_gene, mat_current_gene): 61 | 62 | pat_comp = pat_current_gene[:,:,[3,2,1,0]] 63 | mat_comp = mat_current_gene[:,:,[3,2,1,0]] 64 | 65 | rc_pat = np.flip(pat_comp, axis=1) 66 | rc_mat = np.flip(mat_comp, axis=1) 67 | 68 | return rc_pat, rc_mat 69 | 70 | 71 | def get_aug_pred(gene_id, curr_chr): 72 | 73 | # data for specific subjects 74 | curr_gene_path = data_path + 'chr' + str(curr_chr) + '/' + gene_id + '.npz' 75 | current_sparse = sparse.load_npz(curr_gene_path) 76 | current_gene = current_sparse.todense() 77 | 78 | # if not running for all subj 79 | current_gene = np.take(current_gene, final_idx, axis=2) 80 | num_subj = current_gene.shape[2] 81 | 82 | # init results 83 | gene_res = np.zeros([num_subj,2,8],dtype=float) 84 | 85 | pat_current_gene = current_gene[:4,:] 86 | mat_current_gene = current_gene[4:,:] 87 | 88 | # transpose to be num_sub x seq_len x 4 89 | pat_current_gene = np.transpose(pat_current_gene, (2,1,0)) 90 | mat_current_gene = np.transpose(mat_current_gene, (2,1,0)) 91 | 92 | # go from previous: seqP=='A',seqP=='T',seqP=='G',seqP=='C', to Enformer: 'ACGT' pat_current_gene = pat_current_gene[:, :, [0,3,2,1]] 93 | mat_current_gene = mat_current_gene[:, :, [0,3,2,1]] 94 | 95 | # get output for shifted sequence 96 | for shift in range(len(shifts)): 97 | padded_pat_current_gene, padded_mat_current_gene = pad_genes(pat_current_gene, mat_current_gene,shift=shift) 98 | 99 | for sub in range(num_subj): 100 | 101 | pat_single_sub = np.reshape(padded_pat_current_gene[sub,:,:], (1, padded_input_len, 4)) 102 | mat_single_sub = 
np.reshape(padded_mat_current_gene[sub,:,:], (1, padded_input_len, 4)) 103 | 104 | pat_out = model.predict_on_batch(pat_single_sub) 105 | mat_out = model.predict_on_batch(mat_single_sub) 106 | 107 | pat_out = pat_out['human'][0] 108 | mat_out = mat_out['human'][0] 109 | 110 | pat_bins_sum = pat_out[center_bins[0]][track_idx] + pat_out[center_bins[1]][track_idx] + pat_out[center_bins[2]][track_idx] 111 | mat_bins_sum = mat_out[center_bins[0]][track_idx] + mat_out[center_bins[1]][track_idx] + mat_out[center_bins[2]][track_idx] 112 | 113 | gene_res[sub,0,shift] = pat_bins_sum 114 | gene_res[sub,1,shift] = mat_bins_sum 115 | 116 | # get output for rc sequence 117 | rc_pat, rc_mat = rev_comp(pat_current_gene, mat_current_gene) 118 | padded_pat_current_gene, padded_mat_current_gene = pad_genes(rc_pat, rc_mat,shift=0) 119 | 120 | for sub in range(num_subj): 121 | pat_single_sub = np.reshape(padded_pat_current_gene[sub,:,:], (1, 393216, 4)) 122 | mat_single_sub = np.reshape(padded_mat_current_gene[sub,:,:], (1, 393216, 4)) 123 | 124 | pat_out = model.predict_on_batch(pat_single_sub) 125 | mat_out = model.predict_on_batch(mat_single_sub) 126 | 127 | pat_out = pat_out['human'][0] 128 | mat_out = mat_out['human'][0] 129 | 130 | pat_bins_sum = pat_out[center_bins[0]][track_idx] + pat_out[center_bins[1]][track_idx] + pat_out[center_bins[2]][track_idx] 131 | mat_bins_sum = mat_out[center_bins[0]][track_idx] + mat_out[center_bins[1]][track_idx] + mat_out[center_bins[2]][track_idx] 132 | 133 | gene_res[sub,0,7] = pat_bins_sum 134 | gene_res[sub,1,7] = mat_bins_sum 135 | 136 | np.save(save_path + 'data_aug_pred/' + gene_id, gene_res) 137 | 138 | 139 | if __name__ == '__main__': 140 | 141 | # deal with subject indexing 142 | 143 | # all subj 144 | chosen_subs = np.load('/data/aspiro17/enformer_res/sub_lists/all_exp_and_seq_subs.npy') 145 | 146 | # select a random 100 subjs 147 | #chosen_subs = np.load(save_path+ 'sub_lists/second_100_subs.npy') 148 | 149 | vcf_sub_list = 
np.load(save_path + 'sub_lists/vcf_subs.npy') 150 | 151 | # get indices for randomly chosen subs within the 1161 subs we have vcf data for 152 | idx_in_vcf = [] 153 | for item in chosen_subs: 154 | idx_in_vcf.append(np.where(vcf_sub_list == item)) 155 | 156 | final_idx = [] 157 | for item in idx_in_vcf: 158 | final_idx.append(item[0][0]) 159 | 160 | # load model 161 | model = Enformer(model_path) 162 | 163 | # load TSS info 164 | gene_win_info = pd.read_csv('/data/mostafavilab/bng/rosmapAD/data/expressionData/DLPFC/20220207-bulk-RNAseq/gene-ids-and-positions.tsv', sep = '\t', index_col = 1) 165 | 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument('which_gpu', action="store", default='0') 168 | parser.add_argument('gene_file', action="store", default='test') 169 | args = parser.parse_args() 170 | 171 | os.environ["CUDA_VISIBLE_DEVICES"]=args.which_gpu 172 | genes = np.load(args.gene_file) 173 | 174 | for gene in genes: 175 | print(gene) 176 | chrom = gene_win_info[gene_win_info['gene_id'] == gene]['chr_hg38'] 177 | get_aug_pred(gene, chrom) 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /enformer_analysis/replace_genename.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys, os 3 | 4 | 5 | a = np.genfromtxt('gene-ids-and-positions.tsv', dtype = str, delimiter = '\t', skip_header = 1) 6 | repfile = open(sys.argv[1],'r').readlines() 7 | header = repfile[0].strip().split('\t') 8 | print(header) 9 | names = list(a[:,0]) 10 | nid = [] 11 | for h, head in enumerate(header): 12 | newname = a[names.index(head),1] 13 | if newname in nid: 14 | print(newname) 15 | print(header[nid.index(newname)], head) 16 | nid.append(newname) 17 | 18 | obj = open(os.path.splitext(os.path.split(sys.argv[1])[1])[0]+'ENSG.tsv', 'w') 19 | obj.write('\t'+'\t'.join(np.array(nid))+'\n') 20 | for r in repfile[1:]: 21 | obj.write(r) 22 | 23 | 24 | 
-------------------------------------------------------------------------------- /enformer_analysis/scatter_correlations.py: -------------------------------------------------------------------------------- 1 | # Plots a scatter plot between the correlation of two predictors to observed expression 2 | #python3 scatter_correlations.py --colors Full_analysis/Prediction_correlationsCageAdultBrainGeneSpecific_CorrelationtoObsRandomNull_tstats.txt --colorscut 1.301 --vlim -1.5,1 3 | 4 | import numpy as np 5 | import sys, os 6 | import matplotlib.pyplot as plt 7 | from matplotlib import cm 8 | from scipy.stats import gaussian_kde, pearsonr 9 | 10 | # Read in correlation of Enformer and PrediXcan 11 | pred = np.genfromtxt('Full_analysis/Prediction_correlationsCageAdultBrain_Allstats.txt', dtype = str) 12 | predX = np.genfromtxt('PrediXcanCorrelationWithExpressionENSG.tsv', dtype = str) 13 | predX = predX[predX[:,1]!= 'NaN'] 14 | 15 | # sort genes of both files so that rows match 16 | a_, sortp = np.unique(pred[:,0], return_index = True) 17 | pred = pred[sortp[np.isin(a_, predX[:,0])]] 18 | a_, sortx = np.unique(predX[:,0], return_index = True) 19 | predX = predX[sortx[np.isin(a_, pred[:,0])]] 20 | # check sorting 21 | print(np.array_equal(pred[:,0],predX[:,0])) 22 | 23 | # Select rows that are plotted against each other 24 | measured, predicted, colors = pred[:,3].astype(float), pred[:,1].astype(float), pred[:,4].astype(float) 25 | mask = ~np.isnan(predicted) 26 | measured, predicted, colors, pred, predX = measured[mask], predicted[mask], colors[mask], pred[mask], predX[mask] 27 | 28 | # set cutoff as R>0.2 ~ pvalue 0.05 29 | #cutoff = np.amin(np.absolute(predicted[pred[:,2].astype(float)<0.01])) 30 | cutoff = 0.2 31 | predX = predX[:,1].astype(float) 32 | 33 | outname = 'PearsonR_PredixCan_Enformer_scatter.jpg' 34 | outname2 = 'PearsonR_PredixCan_absEnformer_scatter.jpg' 35 | 36 | 37 | colors = np.chararray(len(colors), itemsize = 20, unicode = True) 38 | colors[:] = 
'indigo' 39 | # Choose colors 40 | if '--colors' in sys.argv: 41 | colorfile = sys.argv[sys.argv.index('--colors')+1] 42 | colors = np.genfromtxt(colorfile, dtype = str) 43 | ic_, csort = np.unique(colors[:,0], return_index = True) 44 | colors = colors[csort[np.isin(colors[csort,0], pred[:,0] )]] 45 | colors = colors[colors[:,-1]!='nan'] 46 | 47 | outname = os.path.splitext(outname)[0] + os.path.splitext(os.path.split(colorfile)[1])[0] + '.jpg' 48 | outname2 = os.path.splitext(outname2)[0] + os.path.splitext(os.path.split(colorfile)[1])[0] + '.jpg' 49 | 50 | pnames, mask = np.unique(pred[:,0], return_index = True) 51 | mask = mask[np.isin(np.sort(pnames),colors[:,0])] 52 | measured, predicted, pred, predX = measured[mask], predicted[mask], pred[mask], predX[mask] 53 | 54 | print(np.array_equal(colors[:,0], pred[:,0]), len(colors), len(pred)) 55 | colors = colors[:,-1].astype(float) 56 | # make color binary 57 | if '--colorscut' in sys.argv: 58 | colcut = float(sys.argv[sys.argv.index('--colorscut')+1]) 59 | colors[np.absolute(colors)< colcut] = 0 60 | colors[colors>= colcut] = 1 61 | colors[colors<= -colcut] = -1 62 | print(np.sum(colors == 1), np.sum(colors == -1)) 63 | print(np.sum((np.absolute(colors)==1)*(predicted >0) )) 64 | # cutoff high color values 65 | elif '--colorsmax' in sys.argv: 66 | colcut = float(sys.argv[sys.argv.index('--colorsmax')+1]) 67 | colors[np.abs(colors)>= colcut] = colcut 68 | elif '--logcolors' in sys.argv: 69 | colors = np.sign(colors)*np.log(np.abs(colors)+1) 70 | 71 | # generate colors from density 72 | elif '--density' in sys.argv: 73 | colors = gaussian_kde(np.array([measured, np.abs(predicted)]))(np.array([measured, predicted])) 74 | 75 | # Set other matplotlib parameters for the scatter plot 76 | vmin, vmax = -1, 1 77 | if '--vlim' in sys.argv: 78 | vmin, vmax = sys.argv[sys.argv.index('--vlim')+1].split(',') 79 | vmin, vmax = float(vmin), float(vmax) 80 | 81 | lw = 0 82 | if '--linewidth' in sys.argv: 83 | lw = 
float(sys.argv[sys.argv.index('--linewidth')+1]) 84 | 85 | alpha = 0.7 86 | if '--alpha' in sys.argv: 87 | alpha = float(sys.argv[sys.argv.index('--alpha')+1]) 88 | 89 | 90 | cmap = 'Purples' 91 | if '--cmap' in sys.argv: 92 | cmap = sys.argv[sys.argv.index('--cmap')+1] 93 | 94 | # print some stats 95 | print('\nHighest Enformer') 96 | for s in np.argsort(predicted)[::-1][:10]: 97 | print(pred[s,0], predX[s], predicted[s]) 98 | 99 | print('\nLowest Enformer') 100 | for s in np.argsort(predicted)[:10]: 101 | print(pred[s,0], predX[s], predicted[s]) 102 | 103 | print('\nHighest PrediXcan') 104 | for s in np.argsort(predX)[::-1][:20]: 105 | print(pred[s,0], predX[s], predicted[s]) 106 | 107 | print('\nMean', np.mean(predicted), np.mean(predX)) 108 | print('Above', cutoff, np.sum(predicted>cutoff), np.sum(predX>cutoff), 'out of', len(predicted)) 109 | 110 | bothmask = (np.abs(predicted)>cutoff) * (predX>cutoff) 111 | onlyx = (np.abs(predicted)cutoff) 112 | nonex = (np.abs(predicted)=-np.log10(0.05)])) 34 | print(cut, 10**-cut) 35 | 36 | 37 | mask = ~np.isnan(predicted) * ~np.isnan(colors) * ~np.isnan(measured) 38 | measured, predicted, colors, pred = measured[mask], predicted[mask], colors[mask], pred[mask] 39 | 40 | print(len(measured)) 41 | 42 | for s in np.argsort(predicted)[::-1][:20]: 43 | print(pred[s,0], measured[s], predicted[s]) 44 | print('\n') 45 | for s in np.argsort(predicted)[:20]: 46 | print(pred[s,0], measured[s], predicted[s]) 47 | 48 | cmap = 'inferno' 49 | vmin, vmax = 0, int(np.amax(np.abs(colors))) 50 | print(np.amin(colors)) 51 | if np.amin(colors) < 0: 52 | vmin = -vmax 53 | 54 | # Choose to color a set of genes darker than the unselected set 55 | if '--genesetcolor' in sys.argv: 56 | colors = np.zeros(len(measured))+0.2 57 | geneset = np.genfromtxt(sys.argv[sys.argv.index('--genesetcolor')+1], dtype = str) 58 | colors[np.isin(pred[:,0], geneset)] = 1.15 59 | outname += 
os.path.splitext(os.path.split(sys.argv[sys.argv.index('--genesetcolor')+1])[1])[0] 60 | vmin, vmax = 0, int(np.ceil(np.amax(colors))) 61 | # Or color scatters based on a continuous value 62 | elif '--colors' in sys.argv: 63 | colors = np.genfromtxt(sys.argv[sys.argv.index('--colors')+1], dtype = str) 64 | colors = colors[np.argsort(colors[:,0])[np.isin(np.sort(colors[:,0]), pred[:,0])]] 65 | print(np.array_equal(colors[:,0], pred[:,0])) 66 | print(len(colors), len(pred)) 67 | colors = np.nan_to_num(colors[:, int(sys.argv[sys.argv.index('--colors')+2])].astype(float)) 68 | outname += os.path.splitext(os.path.split(sys.argv[sys.argv.index('--colors')+1])[1])[0]+sys.argv[sys.argv.index('--colors')+2] 69 | if np.amin(colors) < 0: 70 | vmin, vmax = -np.amax(np.abs(colors)), np.amax(np.abs(colors)) 71 | cmap = 'BrBG' 72 | else: 73 | vmin, vmax = 0, np.amax(np.abs(colors)) 74 | # Set amin and amax based on 99 quartile 75 | if '--quartilecolor' in sys.argv: 76 | if np.amin(colors) < 0: 77 | vmin, vmax = -np.percentile(np.abs(colors),99), np.percentile(np.abs(colors), 99) 78 | cmap = 'BrBG' 79 | else: 80 | vmin, vmax = 0, np.percentile(np.abs(colors),99) 81 | print(vmin,vmax) 82 | 83 | 84 | 85 | 86 | # Adjust different pyplot features 87 | if '--colormap' in sys.argv: 88 | cmap = sys.argv[sys.argv.index('--colormap')+1] 89 | 90 | if '--markersize' in sys.argv: 91 | plt.rcParams['lines.markersize'] = float(sys.argv[sys.argv.index('--markersize')+1]) 92 | lw = 0 93 | if '--linewidth' in sys.argv: 94 | lw = 0.3 95 | alpha = 1 96 | if '--alpha' in sys.argv: 97 | alpha = float(sys.argv[sys.argv.index('--alpha')+1]) 98 | 99 | fig = plt.figure(figsize = (4,4.), dpi = 200) 100 | ax = fig.add_subplot(111) 101 | ax.spines['top'].set_visible(False) 102 | ax.spines['right'].set_visible(False) 103 | sort = np.argsort(np.abs(colors)) 104 | print('Mean', np.mean(predicted[sort])) 105 | print('PosSig', int(np.sum((predicted>0) * (measured>cut))), 'NegSig', int(np.sum((predicted<0) * 
(measured>cut)))) 106 | print('PosSig0.1', int(np.sum((predicted>0.1) * (measured>cut))), 'NegSig', int(np.sum((predicted<-0.1) * (measured>cut)))) 107 | print('PosSig0.2', int(np.sum((predicted>0.2) * (measured>cut))), 'NegSig', int(np.sum((predicted<-0.2) * (measured>cut)))) 108 | 109 | a = ax.scatter(measured[sort], predicted[sort], cmap = cmap, vmin = vmin, vmax = vmax, c=colors[sort], alpha = alpha, marker = 'o', lw = lw, edgecolor ='silver' ) 110 | ax.plot([cut,cut],[np.amin(predicted), np.amax(predicted)], color = 'r', ls = '--') 111 | ax.plot([np.amin(measured), np.amax(measured)], [0, 0], color = 'grey', ls = '--') 112 | ax.set_xlabel('Log10(p-value)') 113 | ax.set_ylabel('Pearson R') 114 | if '--genesetcolor' not in sys.argv: 115 | fig.colorbar(a, pad = -0.1, fraction = 0.09, shrink = 0.25, aspect = 2, anchor = (0.,0.99), ticks = [0, int(vmax)], label = None) 116 | print(outname) 117 | fig.savefig(outname + '.jpg', dpi = 450, bbox_inches = 'tight') 118 | 119 | # Print out set of genes that are above cut and also have abs(y-axis) > pcut 120 | if '--printset' in sys.argv: 121 | pcut = float(sys.argv[sys.argv.index('--printset')+1]) 122 | for s in np.where((measured>cut) * (np.abs(predicted) > pcut))[0]: 123 | print(pred[s,0], measured[s], predicted[s], colors[s]) 124 | 125 | 126 | -------------------------------------------------------------------------------- /enformer_analysis/scatterplot.py: -------------------------------------------------------------------------------- 1 | # script to plot to values against each other and color color in different ways 2 | # python3 scatterplot.py Prediction_correlationsCageAdultBrain_Allstats.txt ism_res/Refpred.txt "CAGE,brain,adult,MeanSumlog10+-2indv" "CAGE,brain,adult,log10sum+-1ref" --columns -2 -1 --density --alpha 0.5 --label --filternan --linewidth 0.1 --log10y 3 | 4 | import numpy as np 5 | import sys, os 6 | import matplotlib.pyplot as plt 7 | from scipy.stats import pearsonr 8 | from scipy.spatial.distance 
import cdist 9 | 10 | def read(file, column = -1): 11 | indv, val = [], [] 12 | for l, line in enumerate(open(file, 'r')): 13 | if line[0] != '#': 14 | line = line.strip().split() 15 | indv.append(line[0]) 16 | val.append(float(line[column])) 17 | indv, val = np.array(indv), np.array(val) 18 | _, sort = np.unique(indv, return_index = True) 19 | indv, val = indv[sort], val[sort] 20 | return indv, val 21 | 22 | # Columns that will be selected from the txt files 23 | column0, column1 = -1, -1 24 | if '--columns' in sys.argv: 25 | column0, column1 = int(sys.argv[sys.argv.index('--columns')+1]), int(sys.argv[sys.argv.index('--columns')+2]) 26 | # Read two text files 27 | print(sys.argv[1], column0) 28 | indv1, val1 = read(sys.argv[1], column0) 29 | indv2, val2 = read(sys.argv[2], column1) 30 | 31 | # Determinne labels of x and y-axis 32 | xname = sys.argv[3] 33 | yname = sys.argv[4] 34 | 35 | outname = os.path.splitext(sys.argv[1])[0]+'_vs_'+os.path.splitext(os.path.split(sys.argv[2])[1])[0]+'_scatter' 36 | if column0 != -1 or column1 != -1: 37 | outname += str(column0)+'-'+str(column1) 38 | 39 | # sort and align values to each other 40 | sort1 = np.isin(indv1, indv2) 41 | sort2 = np.isin(indv2, indv1) 42 | 43 | indv1, val1 = indv1[sort1], val1[sort1] 44 | indv2, val2 = indv2[sort2], val2[sort2] 45 | 46 | # check if sorting was successful 47 | print(np.array_equal(indv1, indv2), len(indv1), len(indv2)) 48 | 49 | # remove nans 50 | if '--filternan' in sys.argv: 51 | mask = ~np.isnan(val1) * ~np.isnan(val2) 52 | indv1, indv2, val1, val2 = indv1[mask], indv2[mask], val1[mask], val2[mask] 53 | 54 | if '--minx' in sys.argv: 55 | mask = val1 > float(sys.argv[sys.argv.index('--minx')+1]) 56 | indv1, indv2, val1, val2 = indv1[mask], indv2[mask], val1[mask], val2[mask] 57 | 58 | if '--miny' in sys.argv: 59 | mask = val2 > float(sys.argv[sys.argv.index('--miny')+1]) 60 | indv1, indv2, val1, val2 = indv1[mask], indv2[mask], val1[mask], val2[mask] 61 | 62 | 63 | if '--absy' in 
sys.argv: 64 | val2 = np.abs(val2) 65 | if '--absx' in sys.argv: 66 | val1 = np.abs(val1) 67 | if '--log10y' in sys.argv: 68 | val2 = np.log10(val2+1) 69 | if '--log10x' in sys.argv: 70 | val1 = np.log10(val1+1) 71 | 72 | 73 | if '--percentilex' in sys.argv: 74 | perc = float(sys.argv[sys.argv.index('--percentilex')+1]) 75 | pmin, pmax = np.percentile(val1, [100.-perc, perc]) 76 | mask = (val1 < pmax) * (val1 > pmin) 77 | indv1, indv2, val1, val2 = indv1[mask], indv2[mask], val1[mask], val2[mask] 78 | 79 | if '--percentiley' in sys.argv: 80 | perc = float(sys.argv[sys.argv.index('--percentiley')+1]) 81 | pmin, pmax = np.percentile(val2, [100.-perc, perc]) 82 | mask = (val2 < pmax) * (val2 > pmin) 83 | indv1, indv2, val1, val2 = indv1[mask], indv2[mask], val1[mask], val2[mask] 84 | 85 | 86 | colors = np.ones(len(indv1))*0.6 87 | edgecolors = np.chararray(len(indv1), itemsize = 20, unicode = True) 88 | edgecolors[:] = 'grey' 89 | markers = np.array(['o' for i in range(len(indv1))]) 90 | 91 | # assign marker types based on drivers or other features 92 | if '--assigntype' in sys.argv: 93 | drivertype = open(sys.argv[sys.argv.index('--assigntype')+1], 'r').readlines() # to select and assign type 94 | if len(drivertype) > 0: 95 | drivertype = np.array([line.strip().split() for line in drivertype]) 96 | allctypes = drivertype[:,-2:].astype(float) 97 | drivertype = drivertype[:,0] #.astype(int) 98 | # types: 99 | # 0: correct (ar<0&eq<0 or ar>0&eq>0) 100 | # 1: false positive (ar>0&eq<0) 101 | # 2: false negative (ar<0&eq>0) 102 | for c, ctypes in enumerate(allctypes): 103 | if (ctypes[0]<0 and ctypes[1]<0) or (ctypes[0]>0 and ctypes[1]>0): 104 | edgecolors[list(indv1).index(drivertype[c])] = 'r' 105 | markers[list(indv1).index(drivertype[c])] = 's' 106 | elif ctypes[0]>0 and ctypes[1]<0: 107 | edgecolors[list(indv1).index(drivertype[c])] = 'r' 108 | markers[list(indv1).index(drivertype[c])] = '^' 109 | elif ctypes[0]<0 and ctypes[1]>0: 110 | 
edgecolors[list(indv1).index(drivertype[c])] = 'r' 111 | markers[list(indv1).index(drivertype[c])] = 'v' 112 | 113 | # assign main driver 114 | if '--assignmain' in sys.argv: 115 | maindriver = open(sys.argv[sys.argv.index('--assignmain')+1], 'r').readlines() # to select and assign type 116 | if len(drivertype) > 0: 117 | maindriver = np.array([line.strip().split() for line in maindriver]) 118 | edgecolors[list(indv1).index(maindriver[np.argmax(maindriver[:,-3].astype(float)),0])] = 'magenta' 119 | 120 | vmin, vmax = 0, 1 121 | cmap = 'Blues' 122 | # assign color to dots 123 | if '--colors' in sys.argv: 124 | colors = np.genfromtxt(sys.argv[sys.argv.index('--colors')+1], dtype = str) 125 | colors = colors[np.argsort(colors[:,0])[np.isin(np.sort(colors[:,0]),indv1)]] 126 | colors = colors[:,int(sys.argv[sys.argv.index('--colors')+2])].astype(float) 127 | if np.amin(colors) < 0: 128 | vmin, vmax = -np.amax(np.absolute(colors)), np.amax(np.absolute(colors)) 129 | cmap = 'RdBu_r' 130 | # color based on density of dots 131 | if '--density' in sys.argv: 132 | from scipy.stats import gaussian_kde 133 | colors = gaussian_kde(np.array([val1, val2]))(np.array([val1, val2])) 134 | vmin, vmax = np.amin(colors), np.amax(colors) 135 | cmap = 'viridis' 136 | # If clusters are loaded, all dots in a cluster will be connected by a line 137 | if '--connect_clusters' in sys.argv: 138 | clusters = np.genfromtxt(sys.argv[sys.argv.index('--connect_clusters')+1], dtype = str) 139 | clusters = clusters[np.argsort(clusters[:,0])[np.isin(np.sort(clusters[:,0]),indv1)]] 140 | 141 | fig = plt.figure(figsize = (3.8,3.5), dpi = 200) 142 | ax = fig.add_subplot(111) 143 | ax.spines['top'].set_visible(False) 144 | ax.spines['right'].set_visible(False) 145 | 146 | if np.amin(val1) < 0: 147 | ax.plot([0,0],[np.amin(val2), np.amax(val2)], color = 'k', ls = '--') 148 | if np.amin(val2) < 0: 149 | ax.plot([np.amin(val1), np.amax(val1)],[0,0], color = 'k', ls = '--') 150 | 151 | if '--connect_clusters' 
in sys.argv: 152 | ucluster, nucluster = np.unique(clusters, return_counts = True) 153 | ucluster = ucluster[nucluster > 1] 154 | for c in ucluster: 155 | allgood = True 156 | mask = np.where(clusters == c)[0] 157 | if '--assigntype' in sys.argv: 158 | if (edgecolors[mask] != 'grey').any(): 159 | allgood = True 160 | else: 161 | allgood = False 162 | if len(mask) > 2 and allgood: 163 | distmat = cdist(np.array([val1[mask], val2[mask]]).T,np.array([val1[mask], val2[mask]]).T ,'euclidean') 164 | maxdist = np.amax(distmat)+1 165 | np.fill_diagonal(distmat, maxdist) 166 | count = np.array(np.where(distmat == np.amin(distmat))).T[0] 167 | hascon = np.zeros(len(distmat)) 168 | ax.plot(val1[mask][count], val2[mask][count], color = 'royalblue', lw = 0.5) 169 | distmat[count] = distmat[count[::-1]] = maxdist 170 | hascon[count] += 1 171 | while True: 172 | conn = np.array(np.where(distmat == np.amin(distmat[count][hascon[count]<2]))).T 173 | connmask = np.sum(np.isin(conn, count),axis =1) == 1 174 | conn = conn[connmask] 175 | conn = conn[np.sum(np.isin(conn,np.where(hascon < 2)[0]),axis = 1)==2] 176 | conn = conn[0] 177 | ax.plot(val1[mask][conn], val2[mask][conn], color = 'royalblue', lw = 0.5) 178 | distmat[conn] = distmat[conn[::-1]] = maxdist 179 | hascon[conn] += 1 180 | count = np.unique(np.append(count, conn)) 181 | if len(count) == len(mask): 182 | break 183 | elif allgood: 184 | ax.plot(val1[mask], val2[mask], color = 'royalblue', lw = 0.5) 185 | 186 | 187 | lw = 0.5 188 | if '--linewidth' in sys.argv: 189 | lw = float(sys.argv[sys.argv.index('--linewidth')+1]) 190 | 191 | alpha = 1. 
192 | if '--alpha' in sys.argv: 193 | alpha = float(sys.argv[sys.argv.index('--alpha')+1]) 194 | 195 | label = None 196 | if '--label' in sys.argv: 197 | pears, pval = pearsonr(val1, val2) 198 | label = 'R='+str(round(pears,3))+'\np='+str(round(pval,4)) 199 | marksort, marksortn = np.unique(markers, return_counts = True) 200 | marksort = marksort[np.argsort(-marksortn)] 201 | for mark in marksort: 202 | mask = np.where(markers == mark)[0] 203 | mask = mask[np.argsort(np.absolute(colors[mask]))] 204 | ax.scatter(val1[mask], val2[mask], cmap = cmap, vmin = vmin, vmax = vmax, c = colors[mask], alpha = alpha, lw = lw, edgecolor = list(edgecolors[mask]), marker = mark, label = label) 205 | 206 | ax.set_xlabel(xname) 207 | ax.set_ylabel(yname) 208 | if label is not None: 209 | ax.legend() 210 | 211 | # Print stats from every quadrant 212 | if '--print_quadrant1' in sys.argv: 213 | cuts = sys.argv[sys.argv.index('--print_quadrant1')+1].split(',') 214 | mask = (val1float(cuts[1])) 215 | for i in np.where(mask)[0]: 216 | print(indv1[i], val1[i], val2[i]) 217 | 218 | if '--print_quadrant2' in sys.argv: 219 | cuts = sys.argv[sys.argv.index('--print_quadrant2')+1].split(',') 220 | mask = (val1>float(cuts[0])) * (val2>float(cuts[1])) 221 | for i in np.where(mask)[0]: 222 | print(indv1[i], val1[i], val2[i]) 223 | 224 | if '--print_quadrant3' in sys.argv: 225 | cuts = sys.argv[sys.argv.index('--print_quadrant3')+1].split(',') 226 | mask = (val1>float(cuts[0])) * (val2 0, axis = 1).astype(int) 70 | snpper = snpcount/len(snp[0]) 71 | snpper = np.around(100*snpper,1) 72 | 73 | # compute the sum linear approximated prediction from attributions times genotype 74 | imp = snp.T * avals 75 | impm = np.sum(imp, axis = 1) 76 | if '--logsum' in sys.argv: 77 | impm = np.log10(impm + 1 - np.amin(impm)) 78 | # compute the pearson correlation between the linear approximation and the full prediction 79 | peartot = pearsonr(impm, pexp)[0] 80 | 81 | print(gene, round(peartot,2)) 82 | 83 | # 
Forward selection adds SNVs to sum in order of their absolute attribution 84 | # A driver is an SNV that significantly correlates with the predicted values and also changes the correlation of the partial approximation significantly 85 | if '--forward' in sys.argv: 86 | obj = open(os.path.splitext(sys.argv[1])[0]+'_'+addname+'driversfw.txt','w') 87 | old_pearsub = 0 88 | previous = np.zeros(len(impm)) 89 | for i in range(len(s)): 90 | impn = np.sum(imp[:,:i+1], axis = 1) 91 | if '--logsum' in sys.argv: 92 | impn = np.log10(impn + 1-np.amin(impn)) 93 | pearsub = pearsonr(impn, pexp)[0] 94 | pearexp = pearsub/peartot 95 | impi = imp[:,i] 96 | if '--logsum' in sys.argv: 97 | impi = np.log10(1+impi-np.amin(impi)) 98 | peari = pearsonr(impi, pexp) 99 | peari, pearisig = peari 100 | isdriver = False 101 | if (pearexp - old_pearsub) > 0.05 and peari > 0 and pearisig < 0.01/len(s): 102 | impj = np.sum(imp[:,np.arange(np.shape(imp)[1]) != i],axis = 1) 103 | if '--logsum' in sys.argv: 104 | impj = np.log10(1+impj-np.amin(impj)) 105 | pearj = (peartot - pearsonr(impj, pexp)[0])/peartot 106 | print(i, str(int(avars[i]))+' '+str(round(avals[i],5))+' '+str(round(peari,2))+' '+ str(round(pearsub,2))+' '+ str(round(pearexp-old_pearsub,2))+' '+ str(snpcount[i])+' '+str(snpper[i])) 107 | obj.write(str(int(avars[i]))+' '+str(round(avals[i],5))+' '+str(round(peartot,2))+' '+str(round(peari,2))+' '+ str(round(pearsub,2))+' '+str(round(pearj,2))+' '+ str(round(pearexp-old_pearsub,2))+' '+ str(snpcount[i])+' '+str(snpper[i])+'\n') 108 | isdriver = True 109 | 110 | # For understanding: Each selection step can be plotted 111 | if '--plot_test' in sys.argv: 112 | tid = int(sys.argv[sys.argv.index('--plot_test')+1]) 113 | if i < tid: 114 | fig = plt.figure(figsize = (4,4)) 115 | ax = fig.add_subplot(111) 116 | ax.spines['top'].set_visible(False) 117 | ax.spines['right'].set_visible(False) 118 | ax.set_ylabel('Enformer prediction') 119 | ax.set_xlabel('Sum prediction') 120 | 
ax.set_title(gene+'-'+str(int(avars[i]))+'('+str(i)+')\n'+str(round(avals[i],2))+' Drv=:'+str(isdriver)[0]) 121 | ax.scatter(impm,pexp,color = 'grey', label = 'Full sum '+str(round(peartot,2))) 122 | ax.scatter(impi,pexp, color = 'goldenrod', label = 'Only SNP '+str(round(peari,2)), alpha = 0.4) 123 | #ax.scatter(previous, pexp,color='purple', label = 'Sum until SNP-1 '+str(round(old_pearsub,2)), alpha = 0.5) 124 | ax.scatter(impn, pexp,color='navy', label = 'Sum until SNP '+str(round(pearsub,2)), alpha = 0.5) 125 | ax.legend() 126 | fig.savefig(os.path.splitext(sys.argv[1])[0]+'_'+addname+'driversfw'+str(i)+'.jpg', dpi = 200, bbox_inches = 'tight') 127 | print(os.path.splitext(sys.argv[1])[0]+'_'+addname+'driversfw'+str(i)+'.jpg') 128 | plt.close() 129 | old_pearsub = pearexp 130 | previous = impn 131 | 132 | # Reverse selection removes SNVs form the sum and computes the change of correlation to predicted values 133 | # In every round the SNV with largest reduction of the correlation is selected and the SNV permanently removed from the sum 134 | # In the next round the impact of SNVs is computed on the residual sum without drivers from the previous rounds. 
135 | # The process is repeated until the correlation to the predicted values is none-significant 136 | else: 137 | obj = open(os.path.splitext(sys.argv[1])[0]+'_'+addname+'driversbw.txt','w') 138 | drivers = [] 139 | mask = ~np.isin(np.arange(len(s), dtype = int),drivers) 140 | pearcomp = pearsonr(impm, pexp)[0] 141 | pearc = pearsonr(impm, pexp) 142 | indvimpact = np.ones(len(s)) 143 | indvpear = np.ones(len(s)) 144 | indvpval = np.zeros(len(s)) 145 | removed = 0 146 | rounds = 0 147 | while pearc[0] > 0 and pearc[1] < 0.01/len(s): 148 | pdriver = [] 149 | for i in np.where(mask)[0]: 150 | if indvpear[i] > 0 and indvpval[i] < 0.01/len(s): 151 | nmask = np.copy(mask) 152 | nmask[i] = False 153 | if rounds == 0: 154 | impi = imp[:,i] 155 | if '--logsum' in sys.argv: 156 | impi = np.log10(1+impi-np.amin(impi)) 157 | peari = pearsonr(impi, pexp) 158 | peari, pearisig = peari 159 | indvpear[i] = peari 160 | indvpval[i] = pearisig 161 | impn = np.sum(imp[:,nmask], axis = 1) 162 | if '--logsum' in sys.argv: 163 | impn = np.log10(impn + 1-np.amin(impn)) 164 | pearsub = pearsonr(impn, pexp)[0] 165 | indvimpact[i] = 1.-(pearsub/pearcomp) 166 | if pearisig < 0.01/len(s) and peari > 0: 167 | pdriver.append([i,1.-(pearsub/pearcomp),peari]) 168 | else: 169 | impn = np.sum(imp[:,nmask], axis = 1) 170 | if '--logsum' in sys.argv: 171 | impn = np.log10(impn + 1-np.amin(impn)) 172 | pearsub = pearsonr(impn, pexp)[0] 173 | pdriver.append([i,1.-(pearsub/pearcomp)-removed,indvpear[i]]) 174 | 175 | if len(pdriver) == 0: 176 | break 177 | pdriver = np.array(pdriver) 178 | maindriver = np.argmax(pdriver[:,1]) 179 | drivers.append(pdriver[maindriver]) 180 | removed += pdriver[maindriver][1] 181 | rounds += 1 182 | mask[int(drivers[-1][0])] = False 183 | impn = np.sum(imp[:,mask], axis = 1) 184 | if '--logsum' in sys.argv: 185 | impn = np.log10(impn + 1-np.amin(impn)) 186 | pearc = pearsonr(impn, pexp) 187 | print('Round', rounds, pearc, pearcomp) 188 | 189 | mask = np.ones(len(s)) == 1 
190 | for dr in drivers: 191 | i, pearsub, peari = dr 192 | i = int(i) 193 | print(i,str(int(avars[i]))+' '+str(round(avals[i],5))+' '+str(round(peari,2))+' '+ str(round(pearsub,2))+' '+ str(round(indvimpact[i],2))+' '+ str(snpcount[i])+' '+str(snpper[i])) 194 | obj.write(str(int(avars[i]))+' '+str(round(avals[i],5))+' '+str(round(peari,2))+' '+ str(round(pearsub,2))+' '+ str(round(indvimpact[i],2))+' '+ str(snpcount[i])+' '+str(snpper[i])+'\n') 195 | # This process can visualized for every round 196 | if '--plot_test' in sys.argv: 197 | print(i) 198 | mask[i] = False 199 | impn = np.sum(imp[:,mask], axis = 1) 200 | if '--logsum' in sys.argv: 201 | impn = np.log10(impn + 1-np.amin(impn)) 202 | fig = plt.figure(figsize = (4,4)) 203 | ax = fig.add_subplot(111) 204 | ax.spines['top'].set_visible(False) 205 | ax.spines['right'].set_visible(False) 206 | ax.set_ylabel('Enformer prediction') 207 | ax.set_xlabel('Sum prediction') 208 | ax.set_title(gene+'-'+str(int(avars[i]))+'('+str(i)+')\n'+str(round(avals[i],2))+' Drv='+str(isdriver)[0]) 209 | ax.scatter(impm, pexp,color = 'grey', label = 'Full sum '+str(round(peartot,2))) 210 | ax.scatter(impn, pexp,color='navy', label = 'Sum without SNP '+str(round(pearsub,2))) 211 | ax.scatter(impi,pexp, color = 'goldenrod', label = 'Only SNP '+str(round(peari,2))) 212 | ax.legend() 213 | fig.savefig(os.path.splitext(sys.argv[1])[0]+'_'+addname+'driversbw'+str(i)+'.jpg', dpi = 150) 214 | print(os.path.splitext(sys.argv[1])[0]+'_'+addname+'driversbw'+str(i)+'.jpg') 215 | plt.close() 216 | 217 | 218 | -------------------------------------------------------------------------------- /fine_tuned_model/README.md: -------------------------------------------------------------------------------- 1 | This directory contains the parameters and Matlab script that was used to predict fine-tuned gene expression using the Enformer model. 
2 | 3 | The Enformer model used can be found in https://github.com/deepmind/deepmind-research/tree/master/enformer. 4 | We separately predict gene expression for the paternal and maternal sequence for each individual. 5 | The script enformerFinetuned.m is used to finetune the predicted gene expression by assigning a weight to each of Enformer's 5313 output tracks. 6 | -------------------------------------------------------------------------------- /fine_tuned_model/enformerFinetuned.m: -------------------------------------------------------------------------------- 1 | %% Finetune Enformer output with GTEx mean expression model 2 | % Input: enformPat = #subjects x 5313 tracks enformer output matrix of a gene for paternal sequences 3 | % enformMat = #subjects x 5313 tracks enformer output matrix of a gene for maternal sequences 4 | % modelPath = Path where GTEx finetuned model is stored 5 | % Output: exprPred = #subjects x 1 predicted gene expression levels 6 | % Note: Assumes enformPat and enformMat = log2 of enformer outputs AFTER summing over output windows e.g. 
+/-2 bins from center 7 | function exprPred = enformerFinetuned(enformPat,enformMat,modelPath) 8 | %% Load mean expression model 9 | fid = fopen(modelPath); 10 | header = regexp(fgetl(fid),',','split'); 11 | temp = textscan(fid,'%s%f','Delimiter',','); 12 | fclose(fid); 13 | wt = temp{2}; 14 | beta = wt(1:end-1); 15 | a0 = wt(end); 16 | 17 | %% Enformer outputs for Paternal 18 | exprPredPat = enformPat*beta+a0; 19 | 20 | %% Enformer outputs for Maternal 21 | exprPredMat = enformMat*beta+a0; 22 | 23 | %% Combine paternal and maternal effect 24 | exprPred = exprPredPat+exprPredMat; 25 | -------------------------------------------------------------------------------- /process_genomic_data/Genotype_AMPAD_WGS_04_phasing_Eagle_script.sh: -------------------------------------------------------------------------------- 1 | wget https://data.broadinstitute.org/alkesgroup/Eagle/downloads/Eagle_v2.4.1.tar.gz 2 | tar xzvf Eagle_v2.4.1.tar.gz 3 | 4 | mkdir input 5 | cd input 6 | gsutil cp gs://shinya_test/Genotype/AMPAD_WGS/01_qc/output/pca_outliers.txt ./ 7 | 8 | PLINK_HOME=/data/shinya/temp/ 9 | BFILE=../../temp/NIA_JG_1898_samples_GRM_WGS_b37_JointAnalysis01_2017-12-08.recalibrated_variants_hg38_dn.QCed1 10 | $PLINK_HOME/plink --bfile $BFILE --remove pca_outliers.txt --keep-allele-order --split-x hg38 --make-bed --out plink 11 | for i in `seq 1 23`; 12 | do 13 | #$PLINK_HOME/plink --bfile plink --remove pca_outliers.txt --keep-allele-order --real-ref-alleles --chr $i --maf 0.0004 --hwe 1e-50 --geno 0.1 --recode vcf --out chr$i 14 | bgzip -c chr$i.vcf > chr$i.vcf.gz 15 | done 16 | 17 | for i in $(seq 1 23); do 18 | echo "../Eagle_v2.4.1/eagle --vcf=chr$i.vcf --geneticMapFile=../Eagle_v2.4.1/tables/genetic_map_hg38_withX.txt.gz --numThreads 8 --vcfOutFormat=z --outPrefix=../output/chr$i.phased > ../output/chr$i.phased.log" >> myCommands.txt 19 | done 20 | 21 | mkdir ../output 22 | less myCommands.txt | parallel -j8 {} 23 | 24 | 25 | # combine 26 | 
/home/shinya/Resource/GENETICS_Resource/softwares/bcftools-1.10.2/bcftools concat -Oz /data/shinya/temp_eagle/output/*.vcf.gz > /data/shinya/temp_eagle/output/chrAll.phased.vcf.gz 27 | /home/shinya/Resource/GENETICS_Resource/softwares/bcftools-1.10.2/bcftools index /data/shinya/temp_eagle/output/chrAll.phased.vcf.gz --threads 8 28 | ~/Resource/NGS_Resource/softwares/htslib-1.9/tabix -p vcf /data/shinya/temp_eagle/output/chrAll.phased.vcf.gz 29 | 30 | # plink 31 | PLINK_HOME=/data/shinya/temp/ 32 | $PLINK_HOME/plink --vcf /data/shinya/temp_eagle/output/chrAll.phased.vcf.gz --keep-allele-order --make-bed --out /data/shinya/temp_eagle/output/chrAll.phased 33 | 34 | -------------------------------------------------------------------------------- /process_genomic_data/README.md: -------------------------------------------------------------------------------- 1 | This directory contains scripts used to process the genomic data used to make and analyze predictions for differential gene expression between 839 individuals using the Enformer model. 
2 | 3 | ## Genotype_AMPAD_WGS_04_phasing_Eagle_script.sh 4 | - create phased VCF files 5 | 6 | ## save_ref_seqs.py 7 | - save onehot-encoded reference genomic sequences within a window (here, 100k) around the TSS from fasta files 8 | - fasta files can be downloaded from UCSC genomc browser, for ex, rsync -avzP rsync://hgdownload.cse.ucsc.edu/goldenPath/hg38/chromosomes/chr1.fa.gz ./ 9 | 10 | ## save_snp_info.py 11 | - from vcf, save variant information by gene for each SNP within a window (here, 100k) of the TSS 12 | 13 | ## extractSeq.py 14 | - using reference genomic fasta files (downloaded as described in save_ref_seqs.py) and variant information by gene (saved as described in save_snp_info.py), create personalized one-hot encodings, saved by gene as sparse .npz files using the package sparse (https://anaconda.org/conda-forge/sparse) 15 | - variant information by gene is accessed in snp = pd.read_csv(filepath+'variantNucleotide'+winSiz+'/'+data.iloc[j,0]+'.csv',header=None,encoding='latin1') 16 | - when these one hot encodings are loaded using sparse (as in EnformerAssessment/enformer_analysis/basic_pred_gene_expr.py), they are of shape [8,200001,1161], where the first dimension of size 8 contains the 4 bases for paternal and the 4 bases for maternal sequences, the second dimension is the length, and the third 1161 individuals 17 | 18 | -------------------------------------------------------------------------------- /process_genomic_data/extractSeq.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import gzip 4 | import sparse 5 | import sys 6 | 7 | pittPath = '/bgfs/mchikina/byng/' 8 | filepath = pittPath+'rosmapAD/projects/insilicoMutagenesis/extractSequence/results/' 9 | refpath = pittPath+'humanRefGenome/data/hg38/' 10 | 11 | # Parameters 12 | #ch = 22 # Select chromosome 13 | ch = int(sys.argv[1]) 14 | gList = sys.argv[2] 15 | nSym = 8 # ATGCx2 for paternal and maternal 16 
| nSub = 1161 17 | win = 1e5 18 | if win==1e4: 19 | winSiz = '10K' 20 | elif win==1e5: 21 | winSiz = '100K' 22 | wpl = 50 # Number of words per line in hg38 23 | 24 | # Load gene names and windows 25 | #data = pd.read_csv(filepath+'geneWin'+winSiz+'.txt',sep='\t',header=None) 26 | data = pd.read_csv(filepath+'geneWin'+winSiz+gList+'.txt',sep='\s+',header=None) 27 | data.columns = ['ensg','chr','winS','winE'] 28 | #data = data.sort_values(by='chr',axis=0) 29 | data = data.loc[data.chr==ch] 30 | nGene = len(data.index) 31 | print('geneWin'+winSiz+gList) # For checking if all genes ran 32 | 33 | # Loop through genes 34 | for j in np.arange(nGene): 35 | print(j) 36 | 37 | # Load reference sequence 38 | with gzip.open(refpath+'chr'+str(ch)+'.fa.gz','rt') as f: 39 | f.readline() # Remove non-sequence line 40 | #ref = f.read().replace('\n','') # Read file as a single string with \n removed but uses too much RAM 41 | start = np.mod(data.iloc[j,2],wpl)-1 42 | nLine = np.floor_divide(data.iloc[j,2],wpl) 43 | if start==-1: # gene j at the end of a line 44 | nLine -= 1 45 | for l in np.arange(nLine): 46 | f.readline() # Get to line where gene j is located 47 | if start==-1: 48 | f.read(wpl-1) # Get to end of the current line 49 | ref = f.read(int(2*win+1+np.floor_divide(2*win+1,wpl)+1)) # Read 2*win+1 bases + \n's + 1 50 | else: 51 | f.read(start) # Get to start of gene j location 52 | ref = f.read(int(2*win+1+np.floor_divide(2*win+1,wpl))) # Read 2*win+1 bases + \n's 53 | 54 | # Extract reference sequence within window 55 | #seqR = np.array(list(ref[data.iloc[j,2]-1:data.iloc[j,3]].upper())) 56 | seqR = np.array(list(ref.replace('\n','').upper())) 57 | 58 | # Check if variants exist 59 | f = open(filepath+'variantNucleotide'+winSiz+'/'+data.iloc[j,0]+'.csv') 60 | line = f.readline() 61 | f.close() 62 | if line != '': 63 | # Load variant nucleotide 64 | snp = pd.read_csv(filepath+'variantNucleotide'+winSiz+'/'+data.iloc[j,0]+'.csv',header=None,encoding='latin1') 65 | 66 | # 
Extract paternal variants 67 | snpP = snp.replace(['A|A','A|T','A|G','A|C'],'A') 68 | snpP = snpP.replace(['T|A','T|T','T|G','T|C'],'T') 69 | snpP = snpP.replace(['G|A','G|T','G|G','G|C'],'G') 70 | snpP = snpP.replace(['C|A','C|T','C|G','C|C'],'C') 71 | 72 | # Extract maternal variants 73 | snpM = snp.replace(['A|A','T|A','G|A','C|A'],'A') 74 | snpM = snpM.replace(['A|T','T|T','G|T','C|T'],'T') 75 | snpM = snpM.replace(['A|G','T|G','G|G','C|G'],'G') 76 | snpM = snpM.replace(['A|C','T|C','G|C','C|C'],'C') 77 | 78 | # Insert variants to reference sequence 79 | nVar = len(snp.index) 80 | 81 | # Loop over subjects 82 | onehot = np.zeros((nSym,int(2*win+1),nSub),dtype='i8') 83 | for k in np.arange(nSub): 84 | seqP = seqR.copy() 85 | seqM = seqR.copy() 86 | 87 | # Loop over variants 88 | for i in np.arange(nVar): 89 | seqP[snpP.iloc[i,1]-data.iloc[j,2]] = snpP.iloc[i,k+2] 90 | seqM[snpM.iloc[i,1]-data.iloc[j,2]] = snpM.iloc[i,k+2] 91 | 92 | # Convert to one-hot encoding 93 | onehot[:,:,k] = [seqP=='A',seqP=='T',seqP=='G',seqP=='C',seqM=='A',seqM=='T',seqM=='G',seqM=='C'] 94 | else: 95 | print('no variants') 96 | 97 | # Loop over subjects 98 | onehot = np.zeros((nSym,int(2*win+1),nSub),dtype='i8') 99 | for k in np.arange(nSub): 100 | # Convert to one-hot encoding 101 | onehot[:,:,k] = [seqR=='A',seqR=='T',seqR=='G',seqR=='C',seqR=='A',seqR=='T',seqR=='G',seqR=='C'] 102 | 103 | # Convert to sparse and save 104 | onehot = sparse.COO.from_numpy(onehot) 105 | sparse.save_npz(filepath+'/sequence'+winSiz+'/chr'+str(ch)+'/'+data.iloc[j,0],onehot) 106 | -------------------------------------------------------------------------------- /process_genomic_data/save_ref_seqs.py: -------------------------------------------------------------------------------- 1 | # For relevant genes, save reference sequence data from hg38 within 100k of of the TSS into one-hot encoded form 2 | 3 | import pandas as pd 4 | import numpy as np 5 | import gzip 6 | import sparse 7 | import os 8 | import 
argparse 9 | 10 | WIN = 1e5 11 | NUM_SYM = 4 12 | 13 | win_path='/data/mostafavilab/bng/rosmapAD/projects/insilicoMutagenesis/extractSequence/results/geneWin100K.txt' 14 | save_path = '/data/aspiro17/enformer_res/ref_seqs/' 15 | ref_path = '/data/aspiro17/seqtoexp/hg38/' # using data downloaded from UCSC genome browser 16 | 17 | gene_win_info = pd.read_csv('/data/mostafavilab/bng/rosmapAD/projects/insilicoMutagenesis/extractSequence/results/geneWin100K.txt', sep='\t', header=None) 18 | all_gene_ids = np.loadtxt('/data/aspiro17/enformer_res/gene_lists/genes_to_run.txt',dtype=str) # genes to process 19 | 20 | for CHR in range(1,23): 21 | print(CHR) 22 | 23 | # load ref seq 24 | with gzip.open(ref_path+'chr'+str(CHR)+'.fa.gz','rt') as f: 25 | f.readline() # Remove non-sequence line 26 | ref = f.read().replace('\n','') 27 | seqR = np.array(list(ref.upper())) 28 | 29 | print('chrom loaded') 30 | 31 | chrom_len = len(seqR) 32 | 33 | # for relevant genes on this chrom 34 | rel_genes = gene_win_info[gene_win_info[1] == CHR] 35 | num_genes = rel_genes.shape[0] 36 | 37 | for i in range(num_genes): 38 | print(i) 39 | current_gene = seqR[int(rel_genes.iloc[i][2]):int(rel_genes.iloc[i][3])+1] # take the relevant window 40 | current_onehot = np.zeros((NUM_SYM,int(2*WIN+1)),dtype='i8') 41 | current_onehot[:,:] = [current_gene=='A',current_gene=='T',current_gene=='G',current_gene=='C'] # note: this order is different from what is used in Enformer, we adjust the onehot later to match that order 42 | 43 | np.save(save_path+rel_genes.iloc[i,0]+'.npy',current_onehot) 44 | 45 | 46 | -------------------------------------------------------------------------------- /process_genomic_data/save_snp_info.txt: -------------------------------------------------------------------------------- 1 | # From vcf, save variant information by gene for each SNP within 100k of the TSS 2 | 3 | base_path=/data/mostafavilab/bng/rosmapAD/projects/insilicoMutagenesis/extractSequence/results/ 4 | 
vcf_path=/data/mostafavilab/bng/rosmapAD/data/wholeGenomeSeq/ 5 | save_dir=/data/aspiro17/enformer_res/variant_info_100k/ 6 | all_rel_genes_path=/data/aspiro17/enformer_res/gene_lists/ # save information for these genes 7 | 8 | readarray -t all_rel_genes < ${all_rel_genes_path}to_process.txt 9 | 10 | # Creates an array from geneWin*K.txt 11 | ensg=(`cut -f1 ${base_path}geneWin100K.txt`) 12 | #echo $ensg 13 | chr=(`cut -f2 ${base_path}geneWin100K.txt`) 14 | #echo $chr 15 | winStart=(`cut -f3 ${base_path}geneWin100K.txt`) 16 | #echo $winStart 17 | winEnd=(`cut -f4 ${base_path}geneWin100K.txt`) 18 | #echo $winEnd 19 | 20 | # Extract bases for each gene 21 | for i in ${!ensg[@]} # ! returns the indices and @ returns all elements of an array 22 | do 23 | 24 | start=`date +%s` 25 | 26 | echo ${ensg[$i]} 27 | if [[ " ${all_rel_genes[*]} " =~ " ${ensg[$i]} " ]]; then 28 | bcftools view -v snps -r ${chr[$i]}:${winStart[$i]}-${winEnd[$i]} ${vcf_path}chrAll.phased.vcf.gz | bcftools query -f '%CHROM,%REF,%POS[,%TGT]\n' > ${save_dir}${ensg[$i]}.csv 29 | 30 | end=`date +%s` 31 | runtime=$((end-start)) 32 | echo ${runtime} 33 | fi 34 | done 35 | 36 | -------------------------------------------------------------------------------- /runPrediXcan.txt: -------------------------------------------------------------------------------- 1 | # Exemplar PrediXcan script adapted from https://github.com/hakyimlab/MetaXcan/wiki/Individual-level-PrediXcan:-introduction,-tutorials-and-manual 2 | # GTEx PrediXcan models can be found in https://predictdb.org/post/2021/07/21/gtex-v8-models-on-eqtl-and-sqtl/ 3 | 4 | metaxcanPath= 5 | modelPath= 6 | dataPath= 7 | outputPath= 8 | 9 | python $metaxcanPath/Predict.py \ 10 | --model_db_path $modelPath \ 11 | --model_db_snp_key varID \ 12 | --vcf_genotypes $dataPath \ 13 | --vcf_mode genotyped \ 14 | --on_the_fly_mapping METADATA "chr{}_{}_{}_{}_b38" \ 15 | --prediction_output ${outputPath}/prediXcanOutput.txt \ 16 | --prediction_summary_output 
${outputPath}/prediXcanSummary.txt \ 17 | --verbosity 9 \ 18 | --throw 19 | 20 | --------------------------------------------------------------------------------