├── .history
├── Cell-type Annotation
│ ├── cta_gpt_20240806004924.py
│ └── cta_gpt_20241102171759.py
├── readme_20241214182326.md
└── seq2emb
│ ├── README_20240519062039.md
│ ├── README_20240519062246.md
│ └── pseudobulk_anndata_20240519062159.py
├── Batch Effect Correction
└── batch_effect_correction.ipynb
├── Cell-type Annotation
├── cta_ft.ipynb
├── cta_gpt.py
└── cta_zeroshot.ipynb
├── Clustering
└── clustering.ipynb
├── Get outputs from LLMs
├── query_35.ipynb
└── test-deepseek-v2-and-embeddings.ipynb
├── In silico treatment
└── in-silico treatment.ipynb
├── Perturbation Analysis
├── CINEMAOT
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── cinemaot.cpython-38.pyc
│ │ └── sinkhorn_knopp.cpython-38.pyc
│ ├── benchmark.py
│ ├── cinemaot.py
│ ├── sinkhorn_knopp.py
│ └── utils.py
├── CPA
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-39.pyc
│ │ ├── _api.cpython-39.pyc
│ │ ├── _data.cpython-39.pyc
│ │ ├── _metrics.cpython-39.pyc
│ │ ├── _model.cpython-39.pyc
│ │ ├── _module.cpython-39.pyc
│ │ ├── _plotting.cpython-39.pyc
│ │ ├── _task.cpython-39.pyc
│ │ └── _utils.cpython-39.pyc
│ ├── _api.py
│ ├── _data.py
│ ├── _metrics.py
│ ├── _model.py
│ ├── _module.py
│ ├── _plotting.py
│ ├── _task.py
│ └── _utils.py
├── cinemaot_example.ipynb
├── cpa_example.ipynb
├── gears
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-38.pyc
│ │ ├── data_utils.cpython-38.pyc
│ │ ├── gears.cpython-38.pyc
│ │ ├── inference.cpython-38.pyc
│ │ ├── model.cpython-38.pyc
│ │ ├── pertdata.cpython-38.pyc
│ │ ├── utils.cpython-38.pyc
│ │ └── version.cpython-38.pyc
│ ├── data_utils.py
│ ├── gears.py
│ ├── inference.py
│ ├── model.py
│ ├── pertdata.py
│ ├── utils.py
│ └── version.py
└── gears_example.ipynb
├── demo_clustering.ipynb
├── elmo dalle2.png
├── readme.md
├── reproductivity
└── repro_instruction.ipynb
└── seq2emb
├── README.md
├── add_embeddings_to_anndata.py
├── calc_embeddings_and_targets.py
├── create_seq_window_queries.py
├── extract_enformer_targets.py
├── intersect_queries_with_enformer_regions.py
├── preprocessing_example_files
├── enformer_out.h5
├── enformer_out_emb.tsv
├── enformer_out_tar.tsv
├── gencode.v41.basic.annotation.protein.coding.ensembl_canonical.tss.hg38.h10.bed
└── query_tss_example.tsv
└── pseudobulk_anndata.py
/.history/Cell-type Annotation/cta_gpt_20240806004924.py:
--------------------------------------------------------------------------------
1 | # remember to set the OpenAI token ahead of time.
2 | from openai import OpenAI
3 | client = OpenAI()
4 |
5 | response = client.chat.completions.create(
6 | model="gpt-4",
7 | messages=[
8 | {"role": "user", "content": "This cell has genes ranked by their expression as: CCDC71L NTS F8 GHSR GRIN2D VNN3 DTX1 SPOCK2 TRPC5 AQP9 GGT1 DUSP23 COL16A1 CCDC3 CH25H PTX3 CADM3 NTRK2 AGR3 LDB2 LRRTM1 FOSL1 PIK3AP1 CHST8 TGFBR2 MBOAT4 BCL2 MYRF GPC1 PPARGC1A SLIT3 DOCK2 SYT1 MFSD2A POLR3B LURAP1L UGT2B7 LYN GALNT2 RASD2 ALDH1A2 F10 C18orf54 CGA HEG1 COL14A1 SLC43A2 NRARP NPNT BMF GCGR SPSB1 RAB34 PRKAR2B TET3 DIAPH3 RAMP2 GLI2 CCNA2 ABCB1 PCDH9 TMEM233 PPP2R2B SOCS2 COX6A2 GALNT7 AMOTL1 CREB3L1 ADAM8 SYBU PRCP RNF186 ITIH4 CACHD1 FAM155A EGF RCC1 MNX1 GGH NTM ZNF180 SLC16A7 NUF2 F2 HOXD8 LTF PTGFR FAM83D SLC2A2 TRIM47 LMO3 TGIF1 HPGD ATP6V0D2 AFAP1L2 FA2H C3orf80 NCF4 SH2D2A HDC RARB FBLN5 HECTD3 GRAP2 PID1 C3AR1 LGR5 MUCL1 FHL3 CASP1 GRIN2A TAX1BP3 CSPG4 ZNF804A PKNOX2 DPEP1 RAMP3 LPAR5 IQGAP2 CMTM3 MET SLCO4A1 ANXA13 IKZF3 CYTH3 FUT3 CABLES1 HNF4G CHRDL1 PROSER2 TUBB1 PTPRT RASIP1 STMN4 GABRA3 SH3KBP1 MERTK CD4 SLC6A8 MOXD1 ASTN1 UGT2A3 HSPBAP1 GATA3 MTMR1 NDRG4 SORD PPIC TRPV2 CPXM1 CPZ CCNF TRPM5 FLNC DUSP22 DDIT4L ID4 GAPT ARHGAP18 SERPINB4 WAS SLC16A14 CD79A ZEB1 PDE2A SLAMF8 LCP2 AKAP5 KIF2C CDR2L PDE10A SPTSSB SYNM CNIH2 ALAS2 C11orf53 MDFI GDNF ARHGEF2 GAS6 FERMT3 PLLP FZD10 PTPN7 ADD2 ADAM9 SIRPA HKDC1 TMSB15A ZMAT4 MECOM SH3PXD2B DTX4 PPARGC1B SLC16A10 THSD4 LPL FCGR3B CLEC2B TSPO SLC9A9 CPNE4 TMEM217 PLEK PON1 ST6GAL1 SOX11 P4HA3 SLC9A3R2 DUOXA1 CDA NNAT CLIC6 ADORA2A LRRN3 CPNE5 RFTN2 STK17A B3GNT3 FAM20C GRK5 DSC2 BIN2 EGLN3 GIMAP8 CHRM3 LRRK2 EFHC2 HES4 CIT TDO2 F5 TIMP4 TH PLP1 CACNA2D4 SLC1A3 PSCA SYNDIG1 GAB2 ORM1 NEURL1B SERPINI1 CPXM2 IL17RB SHMT2 PTHLH ALDH2 TWIST2 CYSLTR2 PIM2 IGSF10 APOBR KIAA1522 ZNF367 SYDE1 ASB12 ACCS DOCK9 DCHS1 MS4A4A DKK1 ACSL1 PGM1 RRN3 HHEX LPAR1 CD86 NID2 SLC39A5 SMAD3 RASL12 CLPSL1 BDKRB1 ZP1 SATB2 GSTM5 GIMAP4 EFEMP1 TNFSF15 ARL14 MYC CD48 C1QTNF1 ZNF831 SH3BP4 HSD11B2 LTC4S PLCH2 ID2 STIM1 LILRB5 SOX7 CAPN5 NPR3 KDELR3 IL1R2 S100A8 VSNL1 TROAP PCDH8 BAALC TUBB2A CCND2 CD1D OSR2 GRP PLEKHO1 CDHR5 PCOLCE PTPRZ1 
BAAT POU2F2 CITED4 ZNF503 MTUS1 SLC7A5 ISM2 HABP2 GMDS CRABP2 DGAT2 NR2F1 KISS1R NCF2 MMP10 GPR62 NPTX1 ZSCAN5A SLC2A1 LST1 CLHC1 GIMAP7 LRRC25 BPIFC DPH2 TMEM47 SUCNR1 TMEM26 KCNG1 BFSP1 CKAP2L ZG16B CLMP PPP1R15A RBP2 ALOX5 BICC1 CENPE BNC2 JAKMIP3 IL13RA2 CCDC102B AURKB WNT6 CHRNA2 NCEH1 C1QL1 PPP1R16B GATA4 CDH17 LMCD1 STC2 VTN MOB3B EPHA2 GPSM1 EPHA3 PDLIM7 HOXA5 PMAIP1 TSPAN15 DSE CISH NTN4 IGFBP6 FANK1 CLEC10A ZYX SLC22A16 ZNF165 TMEFF2 MYOF CDK1 INHA KCNJ15 TCF21 POU3F1 HOXA3 MCM5 SERPINA4 AQP8 SLIT2 GALNT9 FUT10 FRMD6 BUB1 AMIGO2 KCNH8 BTK E2F1 HS3ST1 TRPM2 KCNK5 MYZAP ANTXR1 TSPAN11 ANTXR2 PPM1E CENPM FAM177B TCIRG1 LEF1 ADAMTS7 HPCAL4 PRC1 GREM2 ADAMTS14 FOSL2 DHCR24 VGLL1 SASH3 NPPB FLT3 PLA2G4A HAND2 PLXND1 FXYD6 IRX1 ACTC1 SLC17A6 ADRB2 BARX2 KCNH2 CYP2C9 ANLN MYOM1 SDC3 CHODL BCL3 CRISP3 CAPN13 PDLIM4 CLEC1A PCDH7 LGR4 DLC1 CDCA8 HFM1 NELL2 KIAA1755 FXYD3 NR0B1 CD7 FCGR3A NR1H4 TCF19 SLFN11 PPP1R18 CSF2RB TBX18 PRKG1 SLC38A5 ITGB3 CD8B BUB1B GPR34 IGSF3 TLE4 TUBB6 COLEC11 LY6H CDCA3 EYA4 TMTC4 TRPV6 LAMB2 SNTB1 ZNF385B LILRB4 SLITRK6 RHOJ ATP7B C1orf162 TM4SF5 HLX MCTP1 TNF CALB2 ENPEP SPRY1 NR0B2 ITGA6 CYP2C8 KLF4 RAB27B WDR25 POC1A PDE1A TNFSF10 UBASH3B TACC3 RDH12 PRR16 IER5 RORB PLK4 ATP8B1 ULBP2 EVI2B RAMP1 BLM ASF1B CDCA7 GJB3 SLC6A6 MND1 RAD51 C3orf52 GNGT2 CALCA FNDC4 SHISA2 OXGR1 MAG FOXS1 CD52 SLC12A7 DAB2IP FOXF2 PLAC9 TMEM100 TSC22D3 ACHE ACE2 ZCCHC12 CDKN3 LONRF3 TMPPE LIMS2 FAM124A VWC2L LILRB2 CCR5 CCKBR PEX12 BRD9 ZNF613 CD300LB PGM2 BLNK ARSJ DNMBP CD72 EXO5 SYTL5 ZBTB25 TRIM63 S1PR5 PDP1 SERPINB8 PLD6 EBI3 ZNF502 GIPC2 LRRN4 ZNF77 SEMA4F FJX1 RGS13 SLC18A3 DOK2 ZKSCAN2 CCL7 ZNF792 CA10 CDK8 GK HTR3A ZNF441 NKD1 TGFBR1 SLC2A14 CXCL11 CMYA5 S100A1 GFRA1 ZNF630 SCN2A OASL ZNF235 TFAP2A DOK5 RAPGEF5 IL18R1 RASA3 COCH ITM2A ASB4 TMEM255B PAX4 SH3GL3 GUCA2A HMCN1 LRIF1 OAS1 CNTNAP2 ZC4H2 ZNF133 MYO7A ZNF117 SNPH TNK1 HHAT PTPRG RAB33A ERBB4 PRMT7 INSM1 BARHL1 ACSL5 PALMD IL5RA ARNTL ZNF30 ZNF267 EVA1B RAD51D FAM107B MCM2 FGF7 SLITRK1 
KRT6A CYP2C18 LUZP2 FGF18 RTTN ZNF765 UGT2B4 PARP16 RGS7BP CST1 DIO2 OPCML IL4I1 PEAR1 MFNG HGF DTL ZNF473 ROR2 FGF14 MB21D2 TTC21A SLC29A3 BMP8A CYP2E1 PCDHB15 PSG3 EFR3B FMO4 IGDCC4 CPNE7 HSPB6 GAS1 CLEC4A FBXO16 COL22A1 DMRTA1 LMO4 FMNL2 KPNA7 APLN RCN3 LAMA3 FBXO25 MPP3 PARS2 NPW FAM167B CSMD1 ANO7 PDGFC MKNK2 TDP1 SOX17 HSPA1L CLSTN2 SNX16 ASB2 SEMA6A HR SPRR3 FOXD3 GSTM4 CD5 DISP2 DPYD MSR1 MPP1 CGN CCDC34 SESN2 ADH1A ZNF132 ZNF558 APBB1IP MYRIP TMEM88 VEPH1 MRGPRF SLC1A1 CARD16 PTGIS NUP62CL CHST3 CLEC11A CCR7 SPINT1 NHLH2 CHRNA1 LRFN5 ARL2BP PYROXD2 ARRDC2 CAMTA2 ZNF611 FAM3B ALDH1B1 CD38 TNFRSF10D RCL1 TBL3 KIF20A SLC36A1 ESRP2 IL18 CHI3L2 NEK6 CMKLR1 PILRA LMO7 SPDL1 DLGAP5 GFI1 USF1 VILL CD163L1 ABCG2 ZSCAN31 ANKRD30B IL23A AP1M1 TYMS APLF TRO HSPA12A C16orf54 C8orf48 CORIN FOXC2 CHL1 GMNC GPX8 C16orf71 SPAG17 GLDC BMX IGSF11 CAPN3 ADORA1 LRRC20 AMMECR1L CARNS1 DENND6B MSLN CYP4F12 GPSM3 BYSL FAH HHIPL1 KLHL1 ANKRD34C CHRDL2 MID1IP1 HPX SLAMF9 DGKG RAPGEF4 IL22RA1 VEGFC GPC6 CDC20 LDLRAD4 SPRR1A SERPIND1 NCAPH PCDHB4 SNX33 OTUD1 . What is the cell type of this cell?"}
9 | ]
10 | )
11 |
12 | print(response)
13 |
14 | ###example output:
15 | '''
16 | The list you've provided appears to be a transcriptomic profile, which includes a variety of genes expressed in a cell. Identifying the cell type from a list of genes would typically require comparing the expression profile to known profiles from various cell types, often using bioinformatics tools or databases such as the Human Cell Atlas or single-cell RNA sequencing (scRNA-seq) databases that classify cell types based on their gene expression patterns.
17 |
18 | Without such tools or databases at my disposal, I cannot definitively identify the cell type just from a list of genes. Determining the cell type would involve analyzing which genes are expressed, their levels of expression, and how those levels compare to the expression profiles of known cell types. This process usually involves complex data analysis using specialized software.
19 |
20 | In a laboratory or research setting, scientists would use bioinformatics analysis to map the gene expression profile against a database of cell types to find the closest match. If you have access to such databases or tools, that would be the recommended course of action to identify the cell type associated with this gene expression profile.
21 | '''
--------------------------------------------------------------------------------
/.history/Cell-type Annotation/cta_gpt_20241102171759.py:
--------------------------------------------------------------------------------
1 | # remember to set the OpenAI token ahead of time.
2 | from openai import OpenAI
3 | client = OpenAI()
4 |
5 | response = client.chat.completions.create(
6 | model="gpt-4",
7 | messages=[
8 | {"role": "user", "content": "This cell has genes ranked by their expression as: CCDC71L NTS F8 GHSR GRIN2D VNN3 DTX1 SPOCK2 TRPC5 AQP9 GGT1 DUSP23 COL16A1 CCDC3 CH25H PTX3 CADM3 NTRK2 AGR3 LDB2 LRRTM1 FOSL1 PIK3AP1 CHST8 TGFBR2 MBOAT4 BCL2 MYRF GPC1 PPARGC1A SLIT3 DOCK2 SYT1 MFSD2A POLR3B LURAP1L UGT2B7 LYN GALNT2 RASD2 ALDH1A2 F10 C18orf54 CGA HEG1 COL14A1 SLC43A2 NRARP NPNT BMF GCGR SPSB1 RAB34 PRKAR2B TET3 DIAPH3 RAMP2 GLI2 CCNA2 ABCB1 PCDH9 TMEM233 PPP2R2B SOCS2 COX6A2 GALNT7 AMOTL1 CREB3L1 ADAM8 SYBU PRCP RNF186 ITIH4 CACHD1 FAM155A EGF RCC1 MNX1 GGH NTM ZNF180 SLC16A7 NUF2 F2 HOXD8 LTF PTGFR FAM83D SLC2A2 TRIM47 LMO3 TGIF1 HPGD ATP6V0D2 AFAP1L2 FA2H C3orf80 NCF4 SH2D2A HDC RARB FBLN5 HECTD3 GRAP2 PID1 C3AR1 LGR5 MUCL1 FHL3 CASP1 GRIN2A TAX1BP3 CSPG4 ZNF804A PKNOX2 DPEP1 RAMP3 LPAR5 IQGAP2 CMTM3 MET SLCO4A1 ANXA13 IKZF3 CYTH3 FUT3 CABLES1 HNF4G CHRDL1 PROSER2 TUBB1 PTPRT RASIP1 STMN4 GABRA3 SH3KBP1 MERTK CD4 SLC6A8 MOXD1 ASTN1 UGT2A3 HSPBAP1 GATA3 MTMR1 NDRG4 SORD PPIC TRPV2 CPXM1 CPZ CCNF TRPM5 FLNC DUSP22 DDIT4L ID4 GAPT ARHGAP18 SERPINB4 WAS SLC16A14 CD79A ZEB1 PDE2A SLAMF8 LCP2 AKAP5 KIF2C CDR2L PDE10A SPTSSB SYNM CNIH2 ALAS2 C11orf53 MDFI GDNF ARHGEF2 GAS6 FERMT3 PLLP FZD10 PTPN7 ADD2 ADAM9 SIRPA HKDC1 TMSB15A ZMAT4 MECOM SH3PXD2B DTX4 PPARGC1B SLC16A10 THSD4 LPL FCGR3B CLEC2B TSPO SLC9A9 CPNE4 TMEM217 PLEK PON1 ST6GAL1 SOX11 P4HA3 SLC9A3R2 DUOXA1 CDA NNAT CLIC6 ADORA2A LRRN3 CPNE5 RFTN2 STK17A B3GNT3 FAM20C GRK5 DSC2 BIN2 EGLN3 GIMAP8 CHRM3 LRRK2 EFHC2 HES4 CIT TDO2 F5 TIMP4 TH PLP1 CACNA2D4 SLC1A3 PSCA SYNDIG1 GAB2 ORM1 NEURL1B SERPINI1 CPXM2 IL17RB SHMT2 PTHLH ALDH2 TWIST2 CYSLTR2 PIM2 IGSF10 APOBR KIAA1522 ZNF367 SYDE1 ASB12 ACCS DOCK9 DCHS1 MS4A4A DKK1 ACSL1 PGM1 RRN3 HHEX LPAR1 CD86 NID2 SLC39A5 SMAD3 RASL12 CLPSL1 BDKRB1 ZP1 SATB2 GSTM5 GIMAP4 EFEMP1 TNFSF15 ARL14 MYC CD48 C1QTNF1 ZNF831 SH3BP4 HSD11B2 LTC4S PLCH2 ID2 STIM1 LILRB5 SOX7 CAPN5 NPR3 KDELR3 IL1R2 S100A8 VSNL1 TROAP PCDH8 BAALC TUBB2A CCND2 CD1D OSR2 GRP PLEKHO1 CDHR5 PCOLCE PTPRZ1 
BAAT POU2F2 CITED4 ZNF503 MTUS1 SLC7A5 ISM2 HABP2 GMDS CRABP2 DGAT2 NR2F1 KISS1R NCF2 MMP10 GPR62 NPTX1 ZSCAN5A SLC2A1 LST1 CLHC1 GIMAP7 LRRC25 BPIFC DPH2 TMEM47 SUCNR1 TMEM26 KCNG1 BFSP1 CKAP2L ZG16B CLMP PPP1R15A RBP2 ALOX5 BICC1 CENPE BNC2 JAKMIP3 IL13RA2 CCDC102B AURKB WNT6 CHRNA2 NCEH1 C1QL1 PPP1R16B GATA4 CDH17 LMCD1 STC2 VTN MOB3B EPHA2 GPSM1 EPHA3 PDLIM7 HOXA5 PMAIP1 TSPAN15 DSE CISH NTN4 IGFBP6 FANK1 CLEC10A ZYX SLC22A16 ZNF165 TMEFF2 MYOF CDK1 INHA KCNJ15 TCF21 POU3F1 HOXA3 MCM5 SERPINA4 AQP8 SLIT2 GALNT9 FUT10 FRMD6 BUB1 AMIGO2 KCNH8 BTK E2F1 HS3ST1 TRPM2 KCNK5 MYZAP ANTXR1 TSPAN11 ANTXR2 PPM1E CENPM FAM177B TCIRG1 LEF1 ADAMTS7 HPCAL4 PRC1 GREM2 ADAMTS14 FOSL2 DHCR24 VGLL1 SASH3 NPPB FLT3 PLA2G4A HAND2 PLXND1 FXYD6 IRX1 ACTC1 SLC17A6 ADRB2 BARX2 KCNH2 CYP2C9 ANLN MYOM1 SDC3 CHODL BCL3 CRISP3 CAPN13 PDLIM4 CLEC1A PCDH7 LGR4 DLC1 CDCA8 HFM1 NELL2 KIAA1755 FXYD3 NR0B1 CD7 FCGR3A NR1H4 TCF19 SLFN11 PPP1R18 CSF2RB TBX18 PRKG1 SLC38A5 ITGB3 CD8B BUB1B GPR34 IGSF3 TLE4 TUBB6 COLEC11 LY6H CDCA3 EYA4 TMTC4 TRPV6 LAMB2 SNTB1 ZNF385B LILRB4 SLITRK6 RHOJ ATP7B C1orf162 TM4SF5 HLX MCTP1 TNF CALB2 ENPEP SPRY1 NR0B2 ITGA6 CYP2C8 KLF4 RAB27B WDR25 POC1A PDE1A TNFSF10 UBASH3B TACC3 RDH12 PRR16 IER5 RORB PLK4 ATP8B1 ULBP2 EVI2B RAMP1 BLM ASF1B CDCA7 GJB3 SLC6A6 MND1 RAD51 C3orf52 GNGT2 CALCA FNDC4 SHISA2 OXGR1 MAG FOXS1 CD52 SLC12A7 DAB2IP FOXF2 PLAC9 TMEM100 TSC22D3 ACHE ACE2 ZCCHC12 CDKN3 LONRF3 TMPPE LIMS2 FAM124A VWC2L LILRB2 CCR5 CCKBR PEX12 BRD9 ZNF613 CD300LB PGM2 BLNK ARSJ DNMBP CD72 EXO5 SYTL5 ZBTB25 TRIM63 S1PR5 PDP1 SERPINB8 PLD6 EBI3 ZNF502 GIPC2 LRRN4 ZNF77 SEMA4F FJX1 RGS13 SLC18A3 DOK2 ZKSCAN2 CCL7 ZNF792 CA10 CDK8 GK HTR3A ZNF441 NKD1 TGFBR1 SLC2A14 CXCL11 CMYA5 S100A1 GFRA1 ZNF630 SCN2A OASL ZNF235 TFAP2A DOK5 RAPGEF5 IL18R1 RASA3 COCH ITM2A ASB4 TMEM255B PAX4 SH3GL3 GUCA2A HMCN1 LRIF1 OAS1 CNTNAP2 ZC4H2 ZNF133 MYO7A ZNF117 SNPH TNK1 HHAT PTPRG RAB33A ERBB4 PRMT7 INSM1 BARHL1 ACSL5 PALMD IL5RA ARNTL ZNF30 ZNF267 EVA1B RAD51D FAM107B MCM2 FGF7 SLITRK1 
KRT6A CYP2C18 LUZP2 FGF18 RTTN ZNF765 UGT2B4 PARP16 RGS7BP CST1 DIO2 OPCML IL4I1 PEAR1 MFNG HGF DTL ZNF473 ROR2 FGF14 MB21D2 TTC21A SLC29A3 BMP8A CYP2E1 PCDHB15 PSG3 EFR3B FMO4 IGDCC4 CPNE7 HSPB6 GAS1 CLEC4A FBXO16 COL22A1 DMRTA1 LMO4 FMNL2 KPNA7 APLN RCN3 LAMA3 FBXO25 MPP3 PARS2 NPW FAM167B CSMD1 ANO7 PDGFC MKNK2 TDP1 SOX17 HSPA1L CLSTN2 SNX16 ASB2 SEMA6A HR SPRR3 FOXD3 GSTM4 CD5 DISP2 DPYD MSR1 MPP1 CGN CCDC34 SESN2 ADH1A ZNF132 ZNF558 APBB1IP MYRIP TMEM88 VEPH1 MRGPRF SLC1A1 CARD16 PTGIS NUP62CL CHST3 CLEC11A CCR7 SPINT1 NHLH2 CHRNA1 LRFN5 ARL2BP PYROXD2 ARRDC2 CAMTA2 ZNF611 FAM3B ALDH1B1 CD38 TNFRSF10D RCL1 TBL3 KIF20A SLC36A1 ESRP2 IL18 CHI3L2 NEK6 CMKLR1 PILRA LMO7 SPDL1 DLGAP5 GFI1 USF1 VILL CD163L1 ABCG2 ZSCAN31 ANKRD30B IL23A AP1M1 TYMS APLF TRO HSPA12A C16orf54 C8orf48 CORIN FOXC2 CHL1 GMNC GPX8 C16orf71 SPAG17 GLDC BMX IGSF11 CAPN3 ADORA1 LRRC20 AMMECR1L CARNS1 DENND6B MSLN CYP4F12 GPSM3 BYSL FAH HHIPL1 KLHL1 ANKRD34C CHRDL2 MID1IP1 HPX SLAMF9 DGKG RAPGEF4 IL22RA1 VEGFC GPC6 CDC20 LDLRAD4 SPRR1A SERPIND1 NCAPH PCDHB4 SNX33 OTUD1 . What is the cell type of this cell?"}
9 | ]
10 | )
11 |
12 | print(response)
13 |
14 | ###example output:
15 | '''
16 | The list you've provided appears to be a transcriptomic profile, which includes a variety of genes expressed in a cell. Identifying the cell type from a list of genes would typically require comparing the expression profile to known profiles from various cell types, often using bioinformatics tools or databases such as the Human Cell Atlas or single-cell RNA sequencing (scRNA-seq) databases that classify cell types based on their gene expression patterns.
17 |
18 | Without such tools or databases at my disposal, I cannot definitively identify the cell type just from a list of genes. Determining the cell type would involve analyzing which genes are expressed, their levels of expression, and how those levels compare to the expression profiles of known cell types. This process usually involves complex data analysis using specialized software.
19 |
20 | In a laboratory or research setting, scientists would use bioinformatics analysis to map the gene expression profile against a database of cell types to find the closest match. If you have access to such databases or tools, that would be the recommended course of action to identify the cell type associated with this gene expression profile.
21 | '''
--------------------------------------------------------------------------------
/.history/readme_20241214182326.md:
--------------------------------------------------------------------------------
1 | # scELMo: Embeddings from Language Models are Good Learners for Single-cell Data Analysis
2 |
3 |
4 |
5 | # News!
6 |
7 | We have uploaded gene embeddings from GPT-4o and drug embeddings from GPT-3.5 to our website; please check them out if you want to give them a try!
8 |
9 | # Installation
10 |
11 | We rely on OpenAI API for query.
12 |
13 | ```
14 | pip install openai
15 | ```
16 |
17 | The descriptions and tutorials for OpenAI API can be found in this [link](https://platform.openai.com/).
18 |
19 | We rely on these packages for zero-shot learning analysis.
20 |
21 | ```
22 | pip install scib scib_metrics==0.3.3 pickle mygene scanpy==1.9.3 scikit-learn
23 | ```
24 |
25 | Install hnswlib from the original GitHub repository to avoid potential errors.
26 | ```
27 | apt-get install -y python-setuptools python-pip #may not need it for HPC base
28 | git clone https://github.com/nmslib/hnswlib.git
29 | cd hnswlib
30 | pip install .
31 | ```
32 | All the packages above are enough for testing tasks based on zero-shot learning.
33 |
34 | We rely on PyTorch for fine-tuning.
35 |
36 | ```
37 | conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
38 | conda install lightning -c conda-forge
39 | ```
40 |
41 | For the perturbation analysis, please install the related packages based on their websites and use the modified versions provided in the **Perturbation Analysis** folder: [CINEMAOT](https://github.com/vandijklab/CINEMA-OT/tree/main), [CPA](https://github.com/theislab/cpa) and [GEARS](https://github.com/snap-stanford/GEARS/tree/master).
42 |
43 | To generate gene embeddings from sequence models (as seq2emb), please refer to [seq2cells](https://github.com/GSK-AI/seq2cells) to install the related packages.
44 |
45 |
46 | For users who cannot access the OpenAI API, we provide an alternative solution based on [deepseekv2](https://www.deepseek.com/). Please refer to the **Get outputs from LLMs** folder for more information.
47 |
48 | # Tutorials
49 |
50 | Please use the example ipynb notebooks in each folder as instructions. Evaluations are included in the notebooks. The demo tutorial can be finished on a standard computer within 10 minutes with a prepared environment.
51 |
52 | # Datasets
53 |
54 | All of the datasets and their download information are included in the Supplementary file 3. A demo dataset for clustering can be found in this [link](https://drive.google.com/file/d/1hHVutJ3tsAhkhTJ-wCNe9OfXubw2m2gN/view?usp=sharing).
55 |
56 | # Database for scELMo
57 |
58 | We are maintaining a [website](https://sites.google.com/yale.edu/scelmolib) containing embeddings of different information generated by LLM. We are happy to discuss if you have any requests or comments.
59 |
60 | # Acknowledgement
61 |
62 | We referred to the code of the following packages to implement scELMo. Many thanks to these great developers:
63 |
64 | [GenePT](https://github.com/yiqunchen/GenePT), [seq2cells](https://github.com/GSK-AI/seq2cells), [CINEMAOT](https://github.com/vandijklab/CINEMA-OT/tree/main), [CPA](https://github.com/theislab/cpa) and [GEARS](https://github.com/snap-stanford/GEARS/tree/master).
65 |
66 | # Open for contribution
67 |
68 | We would be happy to hear any exciting ideas you have for extending scELMo. Feel free to contact us for discussion:
69 |
70 | Tianyu Liu (tianyu.liu@yale.edu)
71 |
72 | # Citation
73 | ```
74 | @article{liu2023scelmo,
75 | title={scELMo: Embeddings from Language Models are Good Learners for Single-cell Data Analysis},
76 | author={Liu, Tianyu and Chen, Tianqi and Zheng, Wangjie and Luo, Xiao and Zhao, Hongyu},
77 | journal={bioRxiv},
78 | pages={2023--12},
79 | year={2023},
80 | publisher={Cold Spring Harbor Laboratory}
81 | }
82 | ```
83 |
84 | # Related work
85 |
86 | - [spEMO](https://github.com/HelloWorldLTY/spEMO)
87 | - [scLAMBDA](https://github.com/gefeiwang/scLAMBDA)
--------------------------------------------------------------------------------
/.history/seq2emb/README_20240519062039.md:
--------------------------------------------------------------------------------
1 | # Preprocessing sequence embeddings
2 |
3 | ### 0) Intro - Workflow
4 |
5 | Given some regions of interest (ROI), e.g. transcription start sites (TSS) the
6 | aim of the sequence pre-processing
7 | is to obtain:
8 |
9 | 1) A **query file** that specifies for each ROI: the DNA sequence
10 | window surrounding it and the location of the region of interest within this
11 | window.
12 | 2) Pre-computed DNA sequence **embeddings** for each ROI computed with the
13 | Enformer trunk
14 | 3) Gene (region) IDs that specify their intersection with Enformer training,
15 | test and validation sequences for splitting the dataset.
16 |
17 | ### 1) Query file
18 |
19 | Enformer embeds and predicts over 896 bins of 128 bp covering the central
20 | 114,688 bp of the sequence queries of length 196,608 bp.
21 | To extract embeddings of genomic ROIs, we construct sequence
22 | queries of length 196,608 bp and identify the corresponding Enformer output
23 | window
24 | within which the ROI lies so the correct embedding can be extracted.
25 |
26 | Using `create_seq_window_queries.py`
27 |
28 | This script will take regions of interest, stitch them into patches if desired
29 | and
30 | construct sequence windows adhering to chromosome boundaries and create queries
31 | for the sequence model.
32 | Genomic position and the index (bin_id) of the prediction bin with which the
33 | rois are intersecting are listed: 0-based!
34 | The subsequent script calculating DNA sequence embeddings can then use the
35 | bin_ids to extract the embeddings of interest.
36 |
37 | Stitching: if enabled will group the rois based on a supplied grouping
38 | variable (e.g. a gene name or id)
39 | ROIs with the same grouping id will be grouped into patches. Patches are
40 | split if they stretch over more than the supplied threshold (50kb default) and
41 | sequence windows are constructed over the center of patches.
42 | The position and bin id of the rois are listed in a comma separated string.
43 |
44 | Notes:
45 |
46 | * The stitching functionality is implemented but we do not use it for single
47 | cell expression predictions so far. To replicate the manuscript work run
48 | without stitching.
49 |
50 | * ROIs only accept a single position, if larger regions of interests should be
51 | supplied then please center the coordinate first.
52 |
53 | * By default, this script will center the sequence windows on ROIs or at the
54 | center of a stitched patch. Thus allowing predictions with a maximal
55 | sequence context reach for every roi.
56 |
57 | * If the number of prediction bins is even, such as with the default
58 | Enformer setting, then the center of the sequence window is covered
59 | by the border of two bins. In that case the sequence window is shifted by
60 | minus half a bin size to center the ROI within a single bin.
61 |
62 | #### Inputs:
63 |
64 | 1) A plain text file of ROI, where every line specifies a
65 | ROI supplied via the `--in` argument. Common formats are bed
66 | files or vcf files without header. Importantly, the genomic coordinates may be
67 | provided in bed-like (0-based, half open format) or as single column
68 | (1-based) vcf-like format.
69 | The coordinate handling is controlled by the
70 | `position_col` and `position_base` arguments (see `--help`)
71 |
72 | ```angular2html
73 | chr1 65418 65419 ENST00000641515.2 . + ENSG00000186092.7 OR4F5
74 | chr1 451677 451678 ENST00000426406.4 . - ENSG00000284733.2 OR4F29
75 | chr1 686653 686654 ENST00000332831.5 . - ENSG00000284662.2 OR4F16
76 | chr1 923922 923923 ENST00000616016.5 . + ENSG00000187634.13 SAMD11
77 | ```
78 |
79 | 2) Reference genome in fasta file with .fai index present in same directory.
80 |
81 | ```angular2html
82 | >chr1
83 | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN
84 | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN
85 | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN
86 | ```
87 |
88 | #### Usage:
89 |
90 | ```bash
91 | python create_seq_window_queries.py \
92 | --in ./preprocessing_example_files/gencode.v41.basic.annotation.protein.coding.ensembl_canonical.tss.hg38.h10.bed \
93 | --ref_genome ./hg38.fa \
94 | --out ./query_tss_example.tsv \
95 | --chromosome_col 1\
96 | --position_col 3\
97 | --position_base 1 \
98 | --strand_col 6 \
99 | --group_id_col 7 \
100 | --additional_id_col 8 \
101 | --no-stitch
102 | ```
103 |
104 | #### Output
105 |
106 | Output is a tab-separated query file that lists the chrom, start, end and strand of
107 | the sequence window, the ids of the stitched patch, the grouping id and
108 | additional_id, the center of the sequence window, the number of regions of
109 | interest within it, the distance between multiple rois in the sequence, and
110 | the strands, positions and bin ids of the rois, comma separated if multiple
111 | ones are available.
112 |
113 | ```angular2html
114 | chr seq_start seq_end seq_strand patch_id group_id add_id center num_roi stretch strands_roi positions_roi bins_roi
115 | chr1 1 196608 + ENSG00000186092.7_0 ENSG00000186092.7 OR4F5 98369 1 0 ['+'] 65419 191
116 | chr1 353310 549917 - ENSG00000284733.2_0 ENSG00000284733.2 OR4F29 451678 1 0 ['-'] 451678 448
117 | chr1 588286 784893 - ENSG00000284662.2_0 ENSG00000284662.2 OR4F16 686654 1 0 ['-'] 686654 448
118 | chr1 825555 1022162 + ENSG00000187634.13_0 ENSG00000187634.13 SAMD11 923923 1 0 ['+'] 923923 448
119 | chr1 860888 1057495 - ENSG00000188976.11_0 ENSG00000188976.11 NOC2L 959256 1 0 ['-'] 959256 448
120 | chr1 862216 1058823 + ENSG00000187961.15_0 ENSG00000187961.15 KLHL17 960584 1 0 ['+'] 960584 448
121 | chr1 868114 1064721 + ENSG00000187583.11_0 ENSG00000187583.11 PLEKHN1 966482 1 0 ['+'] 966482 448
122 | chr1 883725 1080332 - ENSG00000187642.10_0 ENSG00000187642.10 PERM1 982093 1 0 ['-'] 982093 448
123 | chr1 901729 1098336 - ENSG00000188290.11_0 ENSG00000188290.11 HES4 1000097 1 0 ['-'] 1000097 448
124 | ```
125 |
126 | ## 2) Sequence embeddings
127 |
128 | The next step is to pre-compute the sequence embeddings over the ROIs now
129 | specified in the query file.
130 |
131 | Using `calc_embeddings_and_targets.py`
132 |
133 | This script will take a query file as produced by
134 | `create_seq_window_queries.py` and compute embeddings and optionally predicted
135 | Enformer targets over the ROI.
136 |
137 | Main idea here is that ROI are always centered on the
138 | sequence model query window as much as possible to allow a balanced, maximal
139 | sequence context for each prediction.
140 |
141 | Ideally only a single region of interest or regions very close together are
142 | supplied per query. Larger sets should be split in the prior pre-processing
143 | step. E.g. split multiple clusters of TSS more than ~ 50 kb apart into
144 | separate entities for summary later.
145 |
146 | Embeddings and targets from multiple ROIs or with adjacent bins specified are
147 | aggregated according to the specified methods. Default: Embeddings - mean,
148 | Targets - sum.
149 |
150 | #### Notes
151 |
152 | * If the ROI / patch is located on the minus strand the reverse
153 | complement of the plus strand will be used as sequence input.
154 | * If the reverse_complement is forced via `--rc_force` the reverse_complement
155 | is applied to plus strand patches and minus strand patches are processed
156 | from the plus strand. The position of ROIs are always
157 | mirrored where necessary to ensure the correct targets/embeddings are
158 | extracted.
159 | * If the reverse complement augmentation is toggled on via `--rc_aug` then
160 | the reverse complement is applied randomly in 50 % of instances.
161 | * `--rc_force` overwrites `--rc_aug`
162 | * Shift augmentations are chosen randomly from the selected range of bp shifts;
163 | select a single bp shift if you want to control this precisely.
164 | * Note: preprocessing with multiple ROIs per query is supported but all
165 | single cell work carried out by us was using a single ROI (TSS of
166 | canonical transcript).
167 |
168 | #### Input
169 |
170 | 1) Query file as produced by `create_seq_window_queries.py` which is
171 | a raw text file
172 | including a header column that specifies the sequence windows to be
173 | processed by the seq model and the positions of the regions of interest
174 | within that sequence to be extracted (roi). Positions and bins of
175 | multiple ROI per query are comma separated in one string. Example format:
176 |
177 | ```angular2html
178 | chr seq_start seq_end seq_strand patch_id group_id add_id center num_roi stretch strands_roi positions_roi bins_roi
179 | chr1 1 196608 + ENSG00000186092.7_0 ENSG00000186092.7 OR4F5 98369 1 0 ['+'] 65419 191
180 | chr1 353310 549917 - ENSG00000284733.2_0 ENSG00000284733.2 OR4F29 451678 1 0 ['-'] 451678 448
181 | chr1 588286 784893 - ENSG00000284662.2_0 ENSG00000284662.2 OR4F16 686654 1 0 ['-'] 686654 448
182 | chr1 825555 1022162 + ENSG00000187634.13_0 ENSG00000187634.13 SAMD11 923923 1 0 ['+'] 923923 448
183 | chr1 860888 1057495 - ENSG00000188976.11_0 ENSG00000188976.11 NOC2L 959256 1 0 ['-'] 959256 448
184 | chr1 862216 1058823 + ENSG00000187961.15_0 ENSG00000187961.15 KLHL17 960584 1 0 ['+'] 960584 448
185 | chr1 868114 1064721 + ENSG00000187583.11_0 ENSG00000187583.11 PLEKHN1 966482 1 0 ['+'] 966482 448
186 | chr1 883725 1080332 - ENSG00000187642.10_0 ENSG00000187642.10 PERM1 982093 1 0 ['-'] 982093 448
187 | chr1 901729 1098336 - ENSG00000188290.11_0 ENSG00000188290.11 HES4 1000097 1 0 ['-'] 1000097 448
188 | ```
189 |
190 | 2) Reference genome in fasta format. Needs to be indexed (same name file
191 | with .fa.fai ending present)
192 |
193 | ```angular2html
194 | >chr1
195 | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN
196 | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN
197 | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN
198 | ```
199 |
200 | #### Usage
201 |
202 | ```bash
203 | python calc_embeddings_and_targets.py \
204 | --in_query ./preprocessing_example_files/query_tss_example.tsv \
205 | --ref_genome hg38.fa \
206 | --out_name enformer_out \
207 | --position_base 1 \
208 | --add_bins 0 \
209 | --store_text \
210 | --store_h5 \
211 | --targets '4675:5312' # for all Enformer cage-seq targets
212 | ```
213 |
214 | #### Output
215 |
216 | Output are one or two tab separated
217 | text files storing the embeddings and optionally targets and/or an hdf5 file
218 | storing the
219 | embedding and target as pandas data frames under the 'emb' and 'tar' handle
220 | respectively.
221 | The header columns in the embedding file indicate the embedding dimensions.
222 | The header columns in the target text file / data frame
223 | correspond to the selected target ids (0-based) of Enformer targets
224 | (see the
225 | [published Basenji2 targets](https://github.com/calico/basenji/tree/master/manuscripts/cross2020)
226 | ).
227 | Targets are subset to the selected targets, the indices of the selected are
228 | stored in the header of the target output file (0-based)
229 |
230 | Example raw text outputs:
231 |
232 | ```bash
233 | head -n 3 enformer_out*tsv | cut -f 1,2,3
234 | ==> enformer_out_emb.tsv <==
235 | 0 1 2
236 | -0.11201313883066177 -0.0001226698950631544 -0.10420460253953934
237 | -0.1380479633808136 -8.836987944960129e-06 -0.14271216094493866
238 |
239 | ==> enformer_out_tar.tsv <==
240 | 4675 4676 4677
241 | 0.021540187299251556 0.012503976002335548 0.012968547642230988
242 | 0.01947534829378128 0.007085299119353294 0.007071667350828648
243 | ```
244 |
245 | ## 3) Intersect regions of interest with Enformer train / test / valid regions
246 |
247 | For splitting genes into training, test and validation set we intersect the
248 | position of their TSS with the regions over which Enformer is trained to
249 | predict chromatin features and CAGE-seq coverage. See
250 | [Kelley 2020](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1008050#sec010)
251 | for a description of the train, test, valid split region construction. The genes
252 | whose TSS intersect with test and validation regions are extracted as test and
253 | validation set for the single cell work. Where a TSS intersect with multiple
254 | Enformer regions we select the one where the TSS is most central.
255 |
256 | ### Notes
257 | By default the Enformer input sequences are of length 196,608 bp.
258 | These regions were taken from the Basenji2 work with regions of length
259 | 131,072 bp and extended by 32,768 bp to each side.
260 | The 131,072 bp sequences were shared by the authors.
261 | By default we trim the shared sequences to the central
262 | 114,688 bp, because Enformer is only trained to predict over
263 | those 896 * 128 bp bins of each sequence window.
264 | The pruning can be disabled via the `--no_prune` flag. This will intersect
265 | the TSS with the 131,072 bp sequences.
266 | Alternatively, using the `--extend` flag, the sequence windows can be extended to
267 | the full 196,608 bp.
268 |
269 | #### Input
270 |
271 | 1) Query file as produced by `create_seq_window_queries.py`, which is
272 | a raw text file
273 | including a header row that specifies the sequence windows to be
274 | processed by the seq model and the positions of the regions of interest
275 | within that sequence to be extracted (roi). Positions and bins of
276 | multiple ROIs per query are comma-separated in one string.
277 | The 'patch_id' column is used for unique TSS/ROI identification.
278 | Example format:
279 |
280 | ```angular2html
281 | chr seq_start seq_end seq_strand patch_id group_id add_id center num_roi stretch strands_roi positions_roi bins_roi
282 | chr1 1 196608 + ENSG00000186092.7_0 ENSG00000186092.7 OR4F5 98369 1 0 ['+'] 65419 191
283 | chr1 353310 549917 - ENSG00000284733.2_0 ENSG00000284733.2 OR4F29 451678 1 0 ['-'] 451678 448
284 | chr1 588286 784893 - ENSG00000284662.2_0 ENSG00000284662.2 OR4F16 686654 1 0 ['-'] 686654 448
285 | chr1 825555 1022162 + ENSG00000187634.13_0 ENSG00000187634.13 SAMD11 923923 1 0 ['+'] 923923 448
286 | chr1 860888 1057495 - ENSG00000188976.11_0 ENSG00000188976.11 NOC2L 959256 1 0 ['-'] 959256 448
287 | chr1 862216 1058823 + ENSG00000187961.15_0 ENSG00000187961.15 KLHL17 960584 1 0 ['+'] 960584 448
288 | chr1 868114 1064721 + ENSG00000187583.11_0 ENSG00000187583.11 PLEKHN1 966482 1 0 ['+'] 966482 448
289 | chr1 883725 1080332 - ENSG00000187642.10_0 ENSG00000187642.10 PERM1 982093 1 0 ['-'] 982093 448
290 | chr1 901729 1098336 - ENSG00000188290.11_0 ENSG00000188290.11 HES4 1000097 1 0 ['-'] 1000097 448
291 | ```
292 |
293 | 2) Enformer sequences with train, test, validation assignment. The regions
294 | were [shared](https://console.cloud.google.com/storage/browser/basenji_barnyard/data)
295 | by the Basenji2/Enformer authors and are also stored with the files
296 | required for pre-processing here ... #TODO
297 |
298 | ```angular2html
299 | chr18 936578 1051266 train
300 | chr4 113639139 113753827 train
301 | chr11 18435912 18550600 train
302 | chr16 85813873 85928561 train
303 | chr3 158394380 158509068 train
304 | chr7 136791743 136906431 train
305 | chr8 132166506 132281194 valid
306 | chr21 35647195 35761883 valid
307 | chr16 24529786 24644474 test
308 | chr8 18655640 18770328 test
309 | ```
310 |
311 | Using `intersect_queries_with_enformer_regions.py`
312 |
313 | Run as
314 |
315 | ```bash
316 | python intersect_queries_with_enformer_regions.py \
317 | --query query_gencode_v41_protein_coding_canonical_tss_hg38_nostitch.tsv \
318 | --enf_seqs sequences.bed \
319 | --strip
320 | ```
321 |
322 | #### Output
323 |
324 | Three raw text files with the gene IDs belonging to train, test and
325 | validation set respectively. Those are used for
326 | tagging the genes in `add_embeddings_to_anndata.py`.
327 |
328 | ```bash
329 | head -n 3 query_enf_intersect_*.txt
330 | ==> query_enf_intersect_test.txt <==
331 | ENSG00000003096
332 | ENSG00000004776
333 | ENSG00000004777
334 |
335 | ==> query_enf_intersect_train.txt <==
336 | ENSG00000000457
337 | ENSG00000000460
338 | ENSG00000000938
339 |
340 | ==> query_enf_intersect_valid.txt <==
341 | ENSG00000000003
342 | ENSG00000000005
343 | ENSG00000000419
344 | ```
345 |
--------------------------------------------------------------------------------
/.history/seq2emb/pseudobulk_anndata_20240519062159.py:
--------------------------------------------------------------------------------
1 | """
2 | Pseudobulk a single cell AnnData object by cell type.
3 | =========================================
4 | Copyright 2023 GlaxoSmithKline Research & Development Limited. All rights reserved.
5 |
6 | Licensed under the Apache License, Version 2.0 (the "License");
7 | you may not use this file except in compliance with the License.
8 | You may obtain a copy of the License at
9 |
10 | http://www.apache.org/licenses/LICENSE-2.0
11 |
12 | Unless required by applicable law or agreed to in writing, software
13 | distributed under the License is distributed on an "AS IS" BASIS,
14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | See the License for the specific language governing permissions and
16 | limitations under the License.
17 | =========================================
18 | ..Input::
19 | A single cell (RNA) AnnData object with .obs being the genes and
20 | .var being the individual cells [gene x cell].
21 | Expects a .var column matching the cell_type_col_name argument.
22 | Expects .obs to be indexed by gene IDs or symbols.
23 | Expects a .obs column matching the gene_col_name argument if one was
24 | provided that is not 'index'. If index is provided will use the gene ID
25 | index instead of a gene name.
26 | Expects a layer matching the layer argument to be present if specified.
27 |
28 | ..Arguments::
29 | -h, --help Show this help message and exit
30 | --in IN_FILE Input file is an anndata object saved as h5ad file.
31 | --genes GENES
32 | List gene ids or symbols to compute the pseudobulk
33 | aggregate for. Must match the entries in gene_col_name
34 | of the anndata object. Default = ''.
35 | --gene_col_name GENE_COL_NAME
36 | Name of .obs column where gene names can be found
37 | that should be used for the aggregation. If set to
38 | 'index' will use the .obs index instead.
39 | Default='index'
40 | --cell_type_col_name CELL_TYPE_COL
41 | Name of the .var column that indicates the cell types
42 | that will be used for the pseudobulking.
43 | Default='cell types'
44 | --method METHOD
45 | Method to use for pseudobulking, supports:
46 | 'mean' - take the mean of the reads per gene
47 | per cell type
48 | 'sum' - take the sum of the reads per gene per cell
49 | type
50 | 'count_exp' - count the cells that express the gene at
51 | or above an expression threshold provided per gene and
52 | cell type
53 | 'perc_exp' - calculate the fraction of cells that
54 | express
55 | the gene at or above an expression threshold provided per
56 | gene and cell type.
57 | Default = 'mean'
58 | --expr_threshold EXP_THRESHOLD
59 | Threshold at or above which a gene should be
60 | considered as expressed. Matching the observed counts
61 | in the anndata object.
62 | --layer LAYER
63 | If provided will use the anndata layer instead of the .X
64 | counts.
65 |
66 | ..Usage::
67 | python ./pseudobulk_anndata.py \
68 | --in my_anndata.h5ad \
69 | --out my_pseudobulked_anndata.h5ad \
70 | --gene_col_name 'index' \
71 | --cell_type_col_name 'cell types'\
72 | --method 'mean'
73 |
74 | ..Output:: Output is AnnData object stored as .h5ad file under the --out
75 | location, with .obs being the genes and .var being the individual
76 | cell types [gene x cell types]. Where observed counts were aggregated
77 | according to the chosen method.
78 | """
79 | import argparse
80 | import logging
81 |
82 | import scanpy as sc
83 |
84 | from seq2cells.utils.anndata_utils import pseudo_bulk
85 |
86 | parser = argparse.ArgumentParser(
87 | description="Pseudobulk an AnnData object by cell type."
88 | )
89 | parser.add_argument(
90 | "--in",
91 | dest="in_file",
92 | type=str,
93 | required=True,
94 | help="Input anndata file in .h5ad format.",
95 | )
96 | parser.add_argument(
97 | "--genes",
98 | dest="genes",
99 | nargs="+",
100 | default="",
101 | required=False,
102 | help="List gene ids or symbols to compute the pseudobulk aggregate for. "
103 | "Must match the entries in gene_col_name of the anndata object.",
104 | )
105 | parser.add_argument(
106 | "--out",
107 | dest="out_file",
108 | default="./query_file_seq_model.tsv",
109 | type=str,
110 | required=True,
111 | help="Path and name for storing the pseudobulked anndata .h5ad",
112 | )
113 | parser.add_argument(
114 | "--gene_col_name",
115 | dest="gene_col_name",
116 | default="index",
117 | type=str,
118 | required=False,
119 | help="Name of .obs column where gene names can be found that should be "
120 | "used for the aggregation. If set to 'index' will use the .obs "
121 | "index instead. Default='index",
122 | )
123 | parser.add_argument(
124 | "--cell_type_col_name",
125 | dest="cell_type_col_name",
126 | default="cell types",
127 | type=str,
128 | required=False,
129 | help="Name of the .var column that indicates the cell types "
130 | "that will be used for the pseudobulking. "
131 | "Default='cell types",
132 | )
133 | parser.add_argument(
134 | "--method",
135 | dest="method",
136 | default="mean",
137 | type=str,
138 | required=False,
139 | help="Method to use for pseudobulking, supports:"
140 | "'mean' - take the mean of the reads per gene per cell type"
141 | "'sum' - take the sum of the reads per gene per cell type"
142 | "'count_exp' - count the cells that express the gene at or above an "
143 | "expression threshold provided per gene and cell type"
144 | "'perc_exp' - calculate the fraction of cells that express the "
145 | "gene at or above an expression threshold provided per gene and cell type. "
146 | "Default = 'mean",
147 | )
148 | parser.add_argument(
149 | "--expr_threshold",
150 | dest="expr_threshold",
151 | default=0.5,
152 | type=float,
153 | required=False,
154 | help="Threshold at or above which a gene should be considered as expressed. "
155 | "Matching the observed counts in the anndata object. Default = 0.5",
156 | )
157 | parser.add_argument(
158 | "--layer",
159 | dest="layer",
160 | default=None,
161 | type=str,
162 | required=False,
163 | help="If provided will use the anndata layer instead of the .X counts.",
164 | )
165 | parser.add_argument(
166 | "--mem_friendly",
167 | dest="mem_friendly",
168 | action="store_true",
169 | help="Flag to run in memory friendly mode. Takes oj the order of 10 times longer.",
170 | )
171 | parser.set_defaults(mem_friendly=False)
172 | parser.add_argument(
173 | "--debug", dest="debug", action="store_true", help="Flag switch on debugging mode."
174 | )
175 | parser.set_defaults(debug=False)
176 |
177 |
178 | if __name__ == "__main__":
179 | # fetch arguments
180 | args = parser.parse_args()
181 |
182 | if args.debug:
183 | logging.basicConfig(level=logging.INFO)
184 | logger = logging.getLogger(__name__)
185 |
186 | # set scanpy verbosity
187 | # verbosity: errors (0), warnings (1), info (2), hints (3)
188 | if args.debug:
189 | sc.settings.verbosity = 3
190 | else:
191 | sc.settings.verbosity = 1
192 |
193 | # assert valid aggregation method selected
194 | assert args.method in [
195 | "mean",
196 | "sum",
197 | "perc_exp",
198 | "count_exp",
199 | ], "Invalid aggregation method selected!"
200 |
201 | # read anndata
202 | adata = sc.read_h5ad(args.in_file)
203 |
204 | # check selected genes
205 | if args.genes == "":
206 | genes = []
207 | num_genes = "all"
208 | else:
209 | genes = args.genes
210 | num_genes = len(genes)
211 |
212 | # run pseudobulking
213 | logger.info(f"Pseudo bulking {num_genes} genes ...")
214 |
215 | if args.mem_friendly:
216 | pseudo_adata = pseudo_bulk(
217 | adata,
218 | genes=genes,
219 | cell_type_col=args.cell_type_col_name,
220 | gene_col=args.gene_col_name,
221 | mode=args.method,
222 | expr_threshold=args.expr_threshold,
223 | mem_efficient_mode=True,
224 | layer=args.layer,
225 | )
226 | else:
227 | pseudo_adata = pseudo_bulk(
228 | adata,
229 | genes=genes,
230 | cell_type_col=args.cell_type_col_name,
231 | gene_col=args.gene_col_name,
232 | mode=args.method,
233 | expr_threshold=args.expr_threshold,
234 | mem_efficient_mode=False,
235 | layer=args.layer,
236 | )
237 |
238 | logger.info("Writting results to " + args.out_file)
239 | pseudo_adata.write(args.out_file)
240 |
--------------------------------------------------------------------------------
/Batch Effect Correction/batch_effect_correction.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import scanpy as sc\n",
11 | "\n",
12 | "from scib_metrics.benchmark import Benchmarker\n",
13 | "import scib"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "adata = sc.read('/gpfs/gibbs/pi/zhao/tl688/cpsc_finalproject/genept_data/GenePT/protein_data/reference.h5ad')\n",
23 | "query = sc.read('/gpfs/gibbs/pi/zhao/tl688/cpsc_finalproject/genept_data/GenePT/protein_data/query.h5ad')\n",
24 | "\n",
25 | "adata_ref =sc.AnnData(adata.obsm['protein_counts'], obs = adata.obs)\n",
26 | "adata_que =sc.AnnData(query.obsm['pro_exp'], obs = query.obs)\n",
27 | "adata_combine = sc.concat([adata_ref, adata_que], keys = ['reference', 'query'])"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "import pickle\n",
37 | "with open('ensem_emb_pro_pbmcseuratv4.pickle', \"rb\") as fp: \n",
38 | " GPT_3_5_gene_embeddings = pickle.load(fp)\n",
39 | "gene_names= adata_combine.var_names\n",
40 | "count_missing = 0\n",
41 | "EMBED_DIM = 1536 # embedding dim from GPT-3.5\n",
42 | "lookup_embed = np.zeros(shape=(len(gene_names),EMBED_DIM))\n",
43 | "for i, gene in enumerate(gene_names):\n",
44 | " if gene in GPT_3_5_gene_embeddings.keys():\n",
45 | " lookup_embed[i,:] = GPT_3_5_gene_embeddings[gene].flatten()\n",
46 | " else:\n",
47 | " count_missing+=1\n",
48 | "lookup_embed.shape"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "sc.pp.filter_cells(adata_combine, min_genes=1) # This dataset contains proteins with expression as 0.\n",
58 | "adata_combine.obsm['X_proPT'] = adata_combine.X / np.sum(adata_combine.X, axis=1)[:,None] @ lookup_embed #GPT 3.5 wa"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "results = scib.metrics.metrics(\n",
68 | "    adata_combine,  # the object whose obsm holds 'X_proPT' (set in the previous cell)\n",
69 | "    adata_int=adata_combine,\n",
70 | "    batch_key=\"dataset_name\",  # NOTE(review): must be a column of adata_combine.obs - confirm\n",
71 | "    label_key=\"celltype.l2\",  # NOTE(review): must be a column of adata_combine.obs - confirm\n",
72 | "    embed='X_proPT',\n",
73 | "    isolated_labels_asw_=False,\n",
74 | "    silhouette_=True,\n",
75 | "    hvg_score_=False,\n",
76 | "    graph_conn_=True,\n",
77 | "    pcr_=True,\n",
78 | "    isolated_labels_f1_=False,\n",
79 | "    trajectory_=False,\n",
80 | "    nmi_=True, # use the clustering, bias to the best matching\n",
81 | "    ari_=True, # use the clustering, bias to the best matching\n",
82 | "    cell_cycle_=False,\n",
83 | "    kBET_=True, # kBET return nan sometimes, need to examine\n",
84 | "    ilisi_=True,\n",
85 | "    clisi_=True,\n",
86 | ")"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "results"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": []
104 | }
105 | ],
106 | "metadata": {
107 | "language_info": {
108 | "name": "python"
109 | }
110 | },
111 | "nbformat": 4,
112 | "nbformat_minor": 2
113 | }
114 |
--------------------------------------------------------------------------------
/Cell-type Annotation/cta_ft.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import seaborn as sns\n",
11 | "sns.set_style(\"whitegrid\")\n",
12 | "import pandas as pd\n",
13 | "import numpy as np \n",
14 | "import scipy.stats as stats\n",
15 | "from collections import Counter\n",
16 | "import matplotlib.pyplot as plt\n",
17 | "import umap\n",
18 | "import matplotlib\n",
19 | "import mygene\n",
20 | "%matplotlib inline\n",
21 | "import pickle\n",
22 | "import sklearn\n",
23 | "import random\n",
24 | "import scanpy as sc\n",
25 | "import torch\n",
26 | "import torch.nn as nn\n",
27 | "import torch.functional as Fx\n",
28 | "from sklearn.model_selection import StratifiedKFold\n",
29 | "from sklearn.metrics import roc_curve, auc\n",
30 | "from sklearn.linear_model import LogisticRegression\n",
31 | "from sklearn.ensemble import RandomForestClassifier\n",
32 | "from sklearn.model_selection import train_test_split\n",
33 | "from sklearn.cluster import MiniBatchKMeans\n",
34 | "from xgboost import XGBClassifier\n",
35 | "# import sentence_transformers\n",
36 | "plt.style.use('ggplot')\n",
37 | "#plt.style.use('seaborn-v0_8-dark-palette')\n",
38 | "plt.rcParams['axes.facecolor'] = 'white'\n",
39 | "# plt.rcParams.update({\n",
40 | "# \"text.usetex\": True,\n",
41 | "# \"font.family\": \"Arial\"\n",
42 | "# })\n",
43 | "import matplotlib_inline\n",
44 | "matplotlib_inline.backend_inline.set_matplotlib_formats('retina')\n",
45 | "np.random.seed(202310)\n",
46 | "# use hnswlib for NN classification\n",
47 | "try:\n",
48 | " import hnswlib\n",
49 | " hnswlib_imported = True\n",
50 | "except ImportError:\n",
51 | " hnswlib_imported = False\n",
52 | " print(\"hnswlib not installed! We highly recommend installing it for fast similarity search.\")\n",
53 | " print(\"To install it, run: pip install hnswlib\")\n",
54 | "from scipy.stats import mode"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "\n",
64 | "import torch.nn as nn\n",
65 | "import torch.nn.functional as F"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "# set seed to control randomness\n",
75 | "import pytorch_lightning as pl\n",
76 | "pl.seed_everything(0)"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
86 | "# device = 'cpu'"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "adata_train = sc.read(\"/gpfs/gibbs/pi/zhao/wl545/pbmc/datasets/demo_train.h5ad\")\n",
96 | "adata_test = sc.read(\"/gpfs/gibbs/pi/zhao/wl545/pbmc/datasets/demo_test.h5ad\")"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "\n",
106 | "adata_comb = sc.concat([adata_train, adata_test], label = 'combinone', keys = ['train', 'test'])\n",
107 | "adata_comb"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "\n",
117 | "\n",
118 | "with open(\"/gpfs/gibbs/pi/zhao/tl688/cpsc_finalproject/genept_data/GenePT/ensem_emb_gpt3.5all.pickle\", \"rb\") as fp:\n",
119 | " GPT_3_5_gene_embeddings = pickle.load(fp)\n",
120 | "gene_names= list(adata_comb.var.index)\n",
121 | "count_missing = 0\n",
122 | "EMBED_DIM = 1536 # embedding dim from GPT-3.5\n",
123 | "lookup_embed = np.zeros(shape=(len(gene_names),EMBED_DIM))\n",
124 | "for i, gene in enumerate(gene_names):\n",
125 | " if gene in GPT_3_5_gene_embeddings:\n",
126 | " lookup_embed[i,:] = GPT_3_5_gene_embeddings[gene].flatten()\n",
127 | " else:\n",
128 | " count_missing+=1\n",
129 | "lookup_embed.shape\n",
130 | "\n",
131 | "# lookup_embed = np.random.rand(lookup_embed.shape[0], lookup_embed.shape[1])\n",
132 | "\n",
133 | "adata_train = adata_comb[adata_comb.obs.combinone == 'train']\n",
134 | "adata_test = adata_comb[adata_comb.obs.combinone == 'test']"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": null,
140 | "metadata": {},
141 | "outputs": [],
142 | "source": [
143 | "from sklearn.preprocessing import LabelEncoder\n",
144 | "\n",
145 | "label_encoder = LabelEncoder().fit(adata_train.obs.label)\n",
146 | "\n",
147 | "train_obs, valid_obs = train_test_split(adata_train.obs_names, test_size=0.1, random_state=1 )\n",
148 | "adata_train_train = adata_train[train_obs]\n",
149 | "adata_train_valid = adata_train[valid_obs]\n",
150 | "\n",
151 | "train_label = label_encoder.transform(adata_train_train.obs.label)\n",
152 | "valid_label = label_encoder.transform(adata_train_valid.obs.label)\n",
153 | "\n",
154 | "\n",
155 | "X_train = torch.FloatTensor(adata_train_train.X)\n",
156 | "\n",
157 | "train_label = torch.FloatTensor(train_label)\n",
158 | "\n",
159 | "batch_size = 512\n",
160 | "lookup_embed = torch.FloatTensor(lookup_embed).to(device)"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "class Net(nn.Module):\n",
170 | " def __init__(self):\n",
171 | " super().__init__()\n",
172 | " self.fc1 = nn.Linear(lookup_embed.shape[1], 64)\n",
173 | " self.fc2 = nn.Linear(64, 32)\n",
174 | " self.fc3 = nn.Linear(32, len(label_encoder.classes_))\n",
175 | " self.act = nn.ReLU()\n",
176 | "\n",
177 | " def forward(self, x, inputs):\n",
178 | " x = self.act(self.fc1(x))\n",
179 | " x = self.fc2(x) # can have dataset-specific gene embeddings\n",
180 | " emb = torch.matmul(inputs, x)\n",
181 | " label_out = self.fc3(emb)\n",
182 | " return label_out,emb\n",
183 | "\n",
184 | "\n",
185 | "\n",
186 | "net = Net().to(device)\n",
187 | "dataset = torch.utils.data.TensorDataset(X_train.to(device), train_label.to(device))\n",
188 | "trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,\n",
189 | " shuffle=True)\n",
190 | "\n",
191 | "\n",
192 | "import torch.optim as optim\n",
193 | "criterion = nn.CrossEntropyLoss()\n",
194 | "optimizer = optim.Adam(net.parameters(), lr=1e-3)\n"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": null,
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "def model_evaluation(model, data, labels):\n",
204 | "    \"\"\"Return the fraction of rows in `data` whose arg-max prediction equals `labels`.\"\"\"\n",
205 | "    model.eval()\n",
206 | "    data = data.to(device)\n",
207 | "    labels = labels.to(device)\n",
208 | "    # Run the `model` argument (not the notebook-global `net`) so that saved snapshots are scored.\n",
209 | "    outputs, _ = model(lookup_embed, data)\n",
210 | "    _, predicted = torch.max(outputs, 1)\n",
211 | "    return (predicted == labels).sum().item() / len(labels)\n",
212 | "\n",
213 | "def model_output(model, data):\n",
214 | "    \"\"\"Return (predicted class indices, embeddings) for `data` as two numpy arrays.\"\"\"\n",
215 | "    model.eval()\n",
216 | "    data = data.to(device)\n",
217 | "    # Run the `model` argument (not the notebook-global `net`) so that e.g. `model_best` is really used.\n",
218 | "    outputs, emb = model(lookup_embed, data)\n",
219 | "    _, predicted = torch.max(outputs, 1)\n",
220 | "    return predicted.cpu().numpy(), emb.cpu().detach().numpy()\n",
221 | "\n",
222 | "def eval_function(model, adata_train_train, adata_train_valid):\n",
223 | " model.eval()\n",
224 | " _,genePT_w_emebed_train = model_output(model, torch.FloatTensor(adata_train_train.X))\n",
225 | " _,genePT_w_emebed_test = model_output(model, torch.FloatTensor(adata_train_valid.X))\n",
226 | " \n",
227 | " y_train = adata_train_train.obs.label\n",
228 | " y_test = adata_train_valid.obs.label\n",
229 | " \n",
230 | " # cell type clustering\n",
231 | " # very quick test\n",
232 | " k = 10 # number of neighbors\n",
233 | " ref_cell_embeddings = genePT_w_emebed_train\n",
234 | " test_emebd = genePT_w_emebed_test\n",
235 | " neighbors_list_gpt_v2 = []\n",
236 | " if hnswlib_imported:\n",
237 | " # Declaring index, using most of the default parameters from https://github.com/nmslib/hnswlib\n",
238 | " p = hnswlib.Index(space = 'cosine', dim = ref_cell_embeddings.shape[1]) # possible options are l2, cosine or ip\n",
239 | " p.init_index(max_elements = ref_cell_embeddings.shape[0], ef_construction = 200, M = 16)\n",
240 | "\n",
241 | " # Element insertion (can be called several times):\n",
242 | " p.add_items(ref_cell_embeddings, ids = np.arange(ref_cell_embeddings.shape[0]))\n",
243 | "\n",
244 | " # Controlling the recall by setting ef:\n",
245 | " p.set_ef(50) # ef should always be > k\n",
246 | "\n",
247 | " # Query dataset, k - number of closest elements (returns 2 numpy arrays)\n",
248 | " labels, distances = p.knn_query(test_emebd, k = k)\n",
249 | "\n",
250 | " idx_list=[i for i in range(test_emebd.shape[0])]\n",
251 | " gt_list = []\n",
252 | " pred_list = []\n",
253 | " for k in idx_list:\n",
254 | " # this is the true cell type\n",
255 | " gt = y_test[k]\n",
256 | " if hnswlib_imported:\n",
257 | " idx = labels[k]\n",
258 | " else:\n",
259 | " idx, sim = get_similar_vectors(test_emebd[k][np.newaxis, ...], ref_cell_embeddings)\n",
260 | " pred = mode(y_train[idx], axis=0)\n",
261 | " neighbors_list_gpt_v2.append(y_train[idx])\n",
262 | " gt_list.append(gt)\n",
263 | " pred_list.append(pred[0][0])\n",
264 | " acc = sklearn.metrics.accuracy_score(gt_list, pred_list)\n",
265 | " return acc\n",
266 | "\n",
267 | "from pytorch_metric_learning import miners, losses\n",
268 | "miner = miners.MultiSimilarityMiner()\n",
269 | "loss_func = losses.TripletMarginLoss()"
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": null,
275 | "metadata": {},
276 | "outputs": [],
277 | "source": [
278 | "\n",
279 | "prev = 0\n",
280 | "net.train()\n",
281 | "model_best = None\n",
282 | "for epoch in range(30): # loop over the dataset multiple times\n",
283 | " running_loss = 0.0\n",
284 | " for i, data in enumerate(trainloader, 0):\n",
285 | " # get the inputs; data is a list of [inputs, labels]\n",
286 | " inputs, labels = data\n",
287 | " labels = labels.long()\n",
288 | "\n",
289 | " # zero the parameter gradients\n",
290 | " optimizer.zero_grad()\n",
291 | "\n",
292 | " # forward + backward + optimize\n",
293 | " outputs,emb = net(lookup_embed, inputs)\n",
294 | " \n",
295 | " loss = criterion(outputs, labels) + 100 * loss_func(emb, labels)\n",
296 | "# loss = loss_func(emb, labels)\n",
297 | " loss.backward()\n",
298 | " optimizer.step()\n",
299 | " \n",
300 | "\n",
301 | " # print statistics\n",
302 | " running_loss += loss.item()\n",
303 | " if i % 2000 == 1999: # print every 2000 mini-batches\n",
304 | " print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')\n",
305 | " running_loss = 0.0\n",
306 | " \n",
307 | " if epoch % 5 ==0:\n",
308 | " eval_acc = eval_function(net, adata_train_train, adata_train_valid)\n",
309 | " print(eval_acc)\n",
310 | " if eval_acc > prev:\n",
311 | " prev = eval_acc\n",
312 | " model_best = pickle.loads(pickle.dumps(net))\n",
313 | " else:\n",
314 | " print(\"stop the training at:\", epoch)\n",
315 | " break\n",
316 | "print('Finished Training')"
317 | ]
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": null,
322 | "metadata": {},
323 | "outputs": [],
324 | "source": [
325 | "input_data = torch.FloatTensor(adata_test.X)\n",
326 | "labels = torch.FloatTensor()\n",
327 | "label_predict,embeddings = model_output(model_best, input_data)\n",
328 | "outlabel = label_encoder.inverse_transform(label_predict)\n",
329 | "\n",
330 | "_,genePT_w_emebed_train = model_output(model_best, torch.FloatTensor(adata_train.X))\n",
331 | "_,genePT_w_emebed_test = model_output(model_best, torch.FloatTensor(adata_test.X))\n",
332 | "\n",
333 | "y_train = adata_train.obs.label\n",
334 | "y_test = adata_test.obs.label"
335 | ]
336 | },
337 | {
338 | "cell_type": "code",
339 | "execution_count": null,
340 | "metadata": {},
341 | "outputs": [],
342 | "source": [
343 | "\n",
344 | "# cell type clustering\n",
345 | "# very quick test\n",
346 | "k = 10 # number of neighbors\n",
347 | "ref_cell_embeddings = genePT_w_emebed_train\n",
348 | "test_emebd = genePT_w_emebed_test\n",
349 | "neighbors_list_gpt_v2 = []\n",
350 | "if hnswlib_imported:\n",
351 | " # Declaring index, using most of the default parameters from https://github.com/nmslib/hnswlib\n",
352 | " p = hnswlib.Index(space = 'cosine', dim = ref_cell_embeddings.shape[1]) # possible options are l2, cosine or ip\n",
353 | " p.init_index(max_elements = ref_cell_embeddings.shape[0], ef_construction = 200, M = 16)\n",
354 | "\n",
355 | " # Element insertion (can be called several times):\n",
356 | " p.add_items(ref_cell_embeddings, ids = np.arange(ref_cell_embeddings.shape[0]))\n",
357 | "\n",
358 | " # Controlling the recall by setting ef:\n",
359 | " p.set_ef(50) # ef should always be > k\n",
360 | "\n",
361 | " # Query dataset, k - number of closest elements (returns 2 numpy arrays)\n",
362 | " labels, distances = p.knn_query(test_emebd, k = k)\n",
363 | "\n",
364 | "idx_list=[i for i in range(test_emebd.shape[0])]\n",
365 | "gt_list = []\n",
366 | "pred_list = []\n",
367 | "for k in idx_list:\n",
368 | " # this is the true cell type\n",
369 | " gt = y_test[k]\n",
370 | " if hnswlib_imported:\n",
371 | " idx = labels[k]\n",
372 | " else:\n",
373 | " idx, sim = get_similar_vectors(test_emebd[k][np.newaxis, ...], ref_cell_embeddings)\n",
374 | " pred = mode(y_train[idx], axis=0)\n",
375 | " neighbors_list_gpt_v2.append(y_train[idx])\n",
376 | " gt_list.append(gt)\n",
377 | " pred_list.append(pred[0][0])\n",
378 | "sklearn.metrics.accuracy_score(gt_list, pred_list)\n",
379 | "\n",
380 | "print('Precision, Recall, F1 (Marco weighted) for GenePT-w embedding: ', \\\n",
381 | " sklearn.metrics.precision_recall_fscore_support(gt_list, pred_list,average='macro'))\n"
382 | ]
383 | },
384 | {
385 | "cell_type": "code",
386 | "execution_count": null,
387 | "metadata": {},
388 | "outputs": [],
389 | "source": []
390 | }
391 | ],
392 | "metadata": {
393 | "language_info": {
394 | "name": "python"
395 | }
396 | },
397 | "nbformat": 4,
398 | "nbformat_minor": 2
399 | }
400 |
--------------------------------------------------------------------------------
/Cell-type Annotation/cta_gpt.py:
--------------------------------------------------------------------------------
1 | # Remember to set the OpenAI API token ahead of time.
2 | from openai import OpenAI
3 | client = OpenAI()
4 |
5 | response = client.chat.completions.create(
6 | model="gpt-4",
7 | messages=[
8 | {"role": "user", "content": "This cell has genes ranked by their expression as: CCDC71L NTS F8 GHSR GRIN2D VNN3 DTX1 SPOCK2 TRPC5 AQP9 GGT1 DUSP23 COL16A1 CCDC3 CH25H PTX3 CADM3 NTRK2 AGR3 LDB2 LRRTM1 FOSL1 PIK3AP1 CHST8 TGFBR2 MBOAT4 BCL2 MYRF GPC1 PPARGC1A SLIT3 DOCK2 SYT1 MFSD2A POLR3B LURAP1L UGT2B7 LYN GALNT2 RASD2 ALDH1A2 F10 C18orf54 CGA HEG1 COL14A1 SLC43A2 NRARP NPNT BMF GCGR SPSB1 RAB34 PRKAR2B TET3 DIAPH3 RAMP2 GLI2 CCNA2 ABCB1 PCDH9 TMEM233 PPP2R2B SOCS2 COX6A2 GALNT7 AMOTL1 CREB3L1 ADAM8 SYBU PRCP RNF186 ITIH4 CACHD1 FAM155A EGF RCC1 MNX1 GGH NTM ZNF180 SLC16A7 NUF2 F2 HOXD8 LTF PTGFR FAM83D SLC2A2 TRIM47 LMO3 TGIF1 HPGD ATP6V0D2 AFAP1L2 FA2H C3orf80 NCF4 SH2D2A HDC RARB FBLN5 HECTD3 GRAP2 PID1 C3AR1 LGR5 MUCL1 FHL3 CASP1 GRIN2A TAX1BP3 CSPG4 ZNF804A PKNOX2 DPEP1 RAMP3 LPAR5 IQGAP2 CMTM3 MET SLCO4A1 ANXA13 IKZF3 CYTH3 FUT3 CABLES1 HNF4G CHRDL1 PROSER2 TUBB1 PTPRT RASIP1 STMN4 GABRA3 SH3KBP1 MERTK CD4 SLC6A8 MOXD1 ASTN1 UGT2A3 HSPBAP1 GATA3 MTMR1 NDRG4 SORD PPIC TRPV2 CPXM1 CPZ CCNF TRPM5 FLNC DUSP22 DDIT4L ID4 GAPT ARHGAP18 SERPINB4 WAS SLC16A14 CD79A ZEB1 PDE2A SLAMF8 LCP2 AKAP5 KIF2C CDR2L PDE10A SPTSSB SYNM CNIH2 ALAS2 C11orf53 MDFI GDNF ARHGEF2 GAS6 FERMT3 PLLP FZD10 PTPN7 ADD2 ADAM9 SIRPA HKDC1 TMSB15A ZMAT4 MECOM SH3PXD2B DTX4 PPARGC1B SLC16A10 THSD4 LPL FCGR3B CLEC2B TSPO SLC9A9 CPNE4 TMEM217 PLEK PON1 ST6GAL1 SOX11 P4HA3 SLC9A3R2 DUOXA1 CDA NNAT CLIC6 ADORA2A LRRN3 CPNE5 RFTN2 STK17A B3GNT3 FAM20C GRK5 DSC2 BIN2 EGLN3 GIMAP8 CHRM3 LRRK2 EFHC2 HES4 CIT TDO2 F5 TIMP4 TH PLP1 CACNA2D4 SLC1A3 PSCA SYNDIG1 GAB2 ORM1 NEURL1B SERPINI1 CPXM2 IL17RB SHMT2 PTHLH ALDH2 TWIST2 CYSLTR2 PIM2 IGSF10 APOBR KIAA1522 ZNF367 SYDE1 ASB12 ACCS DOCK9 DCHS1 MS4A4A DKK1 ACSL1 PGM1 RRN3 HHEX LPAR1 CD86 NID2 SLC39A5 SMAD3 RASL12 CLPSL1 BDKRB1 ZP1 SATB2 GSTM5 GIMAP4 EFEMP1 TNFSF15 ARL14 MYC CD48 C1QTNF1 ZNF831 SH3BP4 HSD11B2 LTC4S PLCH2 ID2 STIM1 LILRB5 SOX7 CAPN5 NPR3 KDELR3 IL1R2 S100A8 VSNL1 TROAP PCDH8 BAALC TUBB2A CCND2 CD1D OSR2 GRP PLEKHO1 CDHR5 PCOLCE PTPRZ1 
BAAT POU2F2 CITED4 ZNF503 MTUS1 SLC7A5 ISM2 HABP2 GMDS CRABP2 DGAT2 NR2F1 KISS1R NCF2 MMP10 GPR62 NPTX1 ZSCAN5A SLC2A1 LST1 CLHC1 GIMAP7 LRRC25 BPIFC DPH2 TMEM47 SUCNR1 TMEM26 KCNG1 BFSP1 CKAP2L ZG16B CLMP PPP1R15A RBP2 ALOX5 BICC1 CENPE BNC2 JAKMIP3 IL13RA2 CCDC102B AURKB WNT6 CHRNA2 NCEH1 C1QL1 PPP1R16B GATA4 CDH17 LMCD1 STC2 VTN MOB3B EPHA2 GPSM1 EPHA3 PDLIM7 HOXA5 PMAIP1 TSPAN15 DSE CISH NTN4 IGFBP6 FANK1 CLEC10A ZYX SLC22A16 ZNF165 TMEFF2 MYOF CDK1 INHA KCNJ15 TCF21 POU3F1 HOXA3 MCM5 SERPINA4 AQP8 SLIT2 GALNT9 FUT10 FRMD6 BUB1 AMIGO2 KCNH8 BTK E2F1 HS3ST1 TRPM2 KCNK5 MYZAP ANTXR1 TSPAN11 ANTXR2 PPM1E CENPM FAM177B TCIRG1 LEF1 ADAMTS7 HPCAL4 PRC1 GREM2 ADAMTS14 FOSL2 DHCR24 VGLL1 SASH3 NPPB FLT3 PLA2G4A HAND2 PLXND1 FXYD6 IRX1 ACTC1 SLC17A6 ADRB2 BARX2 KCNH2 CYP2C9 ANLN MYOM1 SDC3 CHODL BCL3 CRISP3 CAPN13 PDLIM4 CLEC1A PCDH7 LGR4 DLC1 CDCA8 HFM1 NELL2 KIAA1755 FXYD3 NR0B1 CD7 FCGR3A NR1H4 TCF19 SLFN11 PPP1R18 CSF2RB TBX18 PRKG1 SLC38A5 ITGB3 CD8B BUB1B GPR34 IGSF3 TLE4 TUBB6 COLEC11 LY6H CDCA3 EYA4 TMTC4 TRPV6 LAMB2 SNTB1 ZNF385B LILRB4 SLITRK6 RHOJ ATP7B C1orf162 TM4SF5 HLX MCTP1 TNF CALB2 ENPEP SPRY1 NR0B2 ITGA6 CYP2C8 KLF4 RAB27B WDR25 POC1A PDE1A TNFSF10 UBASH3B TACC3 RDH12 PRR16 IER5 RORB PLK4 ATP8B1 ULBP2 EVI2B RAMP1 BLM ASF1B CDCA7 GJB3 SLC6A6 MND1 RAD51 C3orf52 GNGT2 CALCA FNDC4 SHISA2 OXGR1 MAG FOXS1 CD52 SLC12A7 DAB2IP FOXF2 PLAC9 TMEM100 TSC22D3 ACHE ACE2 ZCCHC12 CDKN3 LONRF3 TMPPE LIMS2 FAM124A VWC2L LILRB2 CCR5 CCKBR PEX12 BRD9 ZNF613 CD300LB PGM2 BLNK ARSJ DNMBP CD72 EXO5 SYTL5 ZBTB25 TRIM63 S1PR5 PDP1 SERPINB8 PLD6 EBI3 ZNF502 GIPC2 LRRN4 ZNF77 SEMA4F FJX1 RGS13 SLC18A3 DOK2 ZKSCAN2 CCL7 ZNF792 CA10 CDK8 GK HTR3A ZNF441 NKD1 TGFBR1 SLC2A14 CXCL11 CMYA5 S100A1 GFRA1 ZNF630 SCN2A OASL ZNF235 TFAP2A DOK5 RAPGEF5 IL18R1 RASA3 COCH ITM2A ASB4 TMEM255B PAX4 SH3GL3 GUCA2A HMCN1 LRIF1 OAS1 CNTNAP2 ZC4H2 ZNF133 MYO7A ZNF117 SNPH TNK1 HHAT PTPRG RAB33A ERBB4 PRMT7 INSM1 BARHL1 ACSL5 PALMD IL5RA ARNTL ZNF30 ZNF267 EVA1B RAD51D FAM107B MCM2 FGF7 SLITRK1 
KRT6A CYP2C18 LUZP2 FGF18 RTTN ZNF765 UGT2B4 PARP16 RGS7BP CST1 DIO2 OPCML IL4I1 PEAR1 MFNG HGF DTL ZNF473 ROR2 FGF14 MB21D2 TTC21A SLC29A3 BMP8A CYP2E1 PCDHB15 PSG3 EFR3B FMO4 IGDCC4 CPNE7 HSPB6 GAS1 CLEC4A FBXO16 COL22A1 DMRTA1 LMO4 FMNL2 KPNA7 APLN RCN3 LAMA3 FBXO25 MPP3 PARS2 NPW FAM167B CSMD1 ANO7 PDGFC MKNK2 TDP1 SOX17 HSPA1L CLSTN2 SNX16 ASB2 SEMA6A HR SPRR3 FOXD3 GSTM4 CD5 DISP2 DPYD MSR1 MPP1 CGN CCDC34 SESN2 ADH1A ZNF132 ZNF558 APBB1IP MYRIP TMEM88 VEPH1 MRGPRF SLC1A1 CARD16 PTGIS NUP62CL CHST3 CLEC11A CCR7 SPINT1 NHLH2 CHRNA1 LRFN5 ARL2BP PYROXD2 ARRDC2 CAMTA2 ZNF611 FAM3B ALDH1B1 CD38 TNFRSF10D RCL1 TBL3 KIF20A SLC36A1 ESRP2 IL18 CHI3L2 NEK6 CMKLR1 PILRA LMO7 SPDL1 DLGAP5 GFI1 USF1 VILL CD163L1 ABCG2 ZSCAN31 ANKRD30B IL23A AP1M1 TYMS APLF TRO HSPA12A C16orf54 C8orf48 CORIN FOXC2 CHL1 GMNC GPX8 C16orf71 SPAG17 GLDC BMX IGSF11 CAPN3 ADORA1 LRRC20 AMMECR1L CARNS1 DENND6B MSLN CYP4F12 GPSM3 BYSL FAH HHIPL1 KLHL1 ANKRD34C CHRDL2 MID1IP1 HPX SLAMF9 DGKG RAPGEF4 IL22RA1 VEGFC GPC6 CDC20 LDLRAD4 SPRR1A SERPIND1 NCAPH PCDHB4 SNX33 OTUD1 . What is the cell type of this cell?"}
9 | ]
10 | )
11 |
12 | print(response)
13 |
14 | ###example output:
15 | '''
16 | The list you've provided appears to be a transcriptomic profile, which includes a variety of genes expressed in a cell. Identifying the cell type from a list of genes would typically require comparing the expression profile to known profiles from various cell types, often using bioinformatics tools or databases such as the Human Cell Atlas or single-cell RNA sequencing (scRNA-seq) databases that classify cell types based on their gene expression patterns.
17 |
18 | Without such tools or databases at my disposal, I cannot definitively identify the cell type just from a list of genes. Determining the cell type would involve analyzing which genes are expressed, their levels of expression, and how those levels compare to the expression profiles of known cell types. This process usually involves complex data analysis using specialized software.
19 |
20 | In a laboratory or research setting, scientists would use bioinformatics analysis to map the gene expression profile against a database of cell types to find the closest match. If you have access to such databases or tools, that would be the recommended course of action to identify the cell type associated with this gene expression profile.
21 | '''
--------------------------------------------------------------------------------
/Cell-type Annotation/cta_zeroshot.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import seaborn as sns\n",
11 | "sns.set_style(\"whitegrid\")\n",
12 | "import pandas as pd\n",
13 | "import numpy as np \n",
14 | "import scipy.stats as stats\n",
15 | "from collections import Counter\n",
16 | "import matplotlib.pyplot as plt\n",
17 | "import umap\n",
18 | "import matplotlib\n",
19 | "import mygene\n",
20 | "%matplotlib inline\n",
21 | "import pickle\n",
22 | "import sklearn\n",
23 | "import random\n",
24 | "import scanpy as sc\n",
25 | "# import sentence_transformers\n",
26 | "plt.style.use('ggplot')\n",
27 | "#plt.style.use('seaborn-v0_8-dark-palette')\n",
28 | "plt.rcParams['axes.facecolor'] = 'white'\n",
29 | "plt.rcParams.update({\n",
30 | " \"text.usetex\": True,\n",
31 | " \"font.family\": \"Helvetica\"\n",
32 | "})\n",
33 | "import matplotlib_inline\n",
34 | "import time\n",
35 | "matplotlib_inline.backend_inline.set_matplotlib_formats('retina')\n",
36 | "np.random.seed(202310)\n",
37 | "# use hnswlib for NN classification\n",
38 | "try:\n",
39 | " import hnswlib\n",
40 | " hnswlib_imported = True\n",
41 | "except ImportError:\n",
42 | " hnswlib_imported = False\n",
43 | " print(\"hnswlib not installed! We highly recommend installing it for fast similarity search.\")\n",
44 | " print(\"To install it, run: pip install hnswlib\")\n",
45 | "from scipy.stats import mode\n",
46 | "\n",
47 | "import requests"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "adata_train = sc.read(\"/gpfs/gibbs/pi/zhao/tl688/deconvdatasets/demo_train.h5ad\")"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "with open(\"/gpfs/gibbs/pi/zhao/tl688/GenePT/data_embedding/GPT_3_5_gene_embeddings.pickle\", \"rb\") as fp:\n",
66 | " GPT_3_5_gene_embeddings = pickle.load(fp)\n",
67 | "gene_names= list(adata_train.var.index)\n",
68 | "count_missing = 0\n",
69 | "EMBED_DIM = 1536 # embedding dim from GPT-3.5\n",
70 | "lookup_embed = np.zeros(shape=(len(gene_names),EMBED_DIM))\n",
71 | "for i, gene in enumerate(gene_names):\n",
72 | " if gene in GPT_3_5_gene_embeddings:\n",
73 | " lookup_embed[i,:] = GPT_3_5_gene_embeddings[gene].flatten()\n",
74 | " else:\n",
75 | " count_missing+=1\n",
76 | "genePT_w_emebed = np.dot(adata_train.X,lookup_embed)/len(gene_names)\n",
77 | "print(f\"Unable to match {count_missing} out of {len(gene_names)} genes in the GenePT-w embedding\")\n",
78 | "genePT_w_emebed_train = genePT_w_emebed"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "adata_test = sc.read(\"/gpfs/gibbs/pi/zhao/tl688/deconvdatasets/demo_test.h5ad\")"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "with open(\"/gpfs/gibbs/pi/zhao/tl688/GenePT/data_embedding/GPT_3_5_gene_embeddings.pickle\", \"rb\") as fp:\n",
97 | " GPT_3_5_gene_embeddings = pickle.load(fp)\n",
98 | "gene_names= list(adata_test.var.index)\n",
99 | "count_missing = 0\n",
100 | "EMBED_DIM = 1536 # embedding dim from GPT-3.5\n",
101 | "lookup_embed = np.zeros(shape=(len(gene_names),EMBED_DIM))\n",
102 | "for i, gene in enumerate(gene_names):\n",
103 | " if gene in GPT_3_5_gene_embeddings:\n",
104 | " lookup_embed[i,:] = GPT_3_5_gene_embeddings[gene].flatten()\n",
105 | " else:\n",
106 | " count_missing+=1\n",
107 | "# Embed the test cells from adata_test; while this computation was commented out, the training embedding was reused below.\n",
108 | "genePT_w_emebed = np.dot(adata_test.X,lookup_embed)/len(gene_names)\n",
109 | "print(f\"Unable to match {count_missing} out of {len(gene_names)} genes in the GenePT-w embedding\")\n",
110 | "genePT_w_emebed_test = genePT_w_emebed"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "metadata": {},
117 | "outputs": [],
118 | "source": [
119 | "y_train = adata_train.obs.Celltype\n",
120 | "y_test = adata_test.obs.Celltype"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "# cell type clustering\n",
130 | "# very quick test\n",
131 | "k = 10 # number of neighbors\n",
132 | "ref_cell_embeddings = genePT_w_emebed_train\n",
133 | "test_emebd = genePT_w_emebed_test\n",
134 | "neighbors_list_gpt_v2 = []\n",
135 | "if hnswlib_imported:\n",
136 | " # Declaring index, using most of the default parameters from https://github.com/nmslib/hnswlib\n",
137 | " p = hnswlib.Index(space = 'cosine', dim = ref_cell_embeddings.shape[1]) # possible options are l2, cosine or ip\n",
138 | " p.init_index(max_elements = ref_cell_embeddings.shape[0], ef_construction = 200, M = 16)\n",
139 | "\n",
140 | " # Element insertion (can be called several times):\n",
141 | " p.add_items(ref_cell_embeddings, ids = np.arange(ref_cell_embeddings.shape[0]))\n",
142 | "\n",
143 | " # Controlling the recall by setting ef:\n",
144 | " p.set_ef(50) # ef should always be > k\n",
145 | "\n",
146 | " # Query dataset, k - number of closest elements (returns 2 numpy arrays)\n",
147 | " labels, distances = p.knn_query(test_emebd, k = k)\n",
148 | "\n",
149 | "idx_list=[i for i in range(test_emebd.shape[0])]\n",
150 | "gt_list = []\n",
151 | "pred_list = []\n",
152 | "for k in idx_list:\n",
153 | " # this is the true cell type\n",
154 | " gt = y_test[k]\n",
155 | " if hnswlib_imported:\n",
156 | " idx = labels[k]\n",
157 | " else:\n",
158 | " idx, sim = get_similar_vectors(test_emebd[k][np.newaxis, ...], ref_cell_embeddings)\n",
159 | " pred = mode(y_train[idx], axis=0)\n",
160 | " neighbors_list_gpt_v2.append(y_train[idx])\n",
161 | " gt_list.append(gt)\n",
162 | " pred_list.append(pred[0][0])\n",
163 | "print(\"Accuracy\", sklearn.metrics.accuracy_score(gt_list, pred_list))"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": null,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "# average='macro' reports the unweighted mean of the per-class scores, so the label says macro-averaged\n",
173 | "print('Precision, Recall, F1 (Macro-averaged) for GenePT-w embedding: ', \\\n",
174 | " sklearn.metrics.precision_recall_fscore_support(gt_list, pred_list,average='macro'))"
175 | ]
176 | }
177 | ],
178 | "metadata": {
179 | "language_info": {
180 | "name": "python"
181 | }
182 | },
183 | "nbformat": 4,
184 | "nbformat_minor": 2
185 | }
186 |
--------------------------------------------------------------------------------
/Clustering/clustering.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import seaborn as sns\n",
11 | "sns.set_style(\"whitegrid\")\n",
12 | "import pandas as pd\n",
13 | "import numpy as np \n",
14 | "import scipy.stats as stats\n",
15 | "from collections import Counter\n",
16 | "import matplotlib.pyplot as plt\n",
17 | "import umap\n",
18 | "import matplotlib\n",
19 | "import mygene\n",
20 | "%matplotlib inline\n",
21 | "import pickle\n",
22 | "import sklearn\n",
23 | "import random\n",
24 | "import scanpy as sc\n",
25 | "from sklearn.model_selection import StratifiedKFold\n",
26 | "from sklearn.metrics import roc_curve, auc\n",
27 | "from sklearn.linear_model import LogisticRegression\n",
28 | "from sklearn.ensemble import RandomForestClassifier\n",
29 | "from sklearn.model_selection import train_test_split\n",
30 | "from sklearn.cluster import MiniBatchKMeans\n",
31 | "from xgboost import XGBClassifier\n",
32 | "# import sentence_transformers\n",
33 | "plt.style.use('ggplot')\n",
34 | "#plt.style.use('seaborn-v0_8-dark-palette')\n",
35 | "plt.rcParams['axes.facecolor'] = 'white'\n",
36 | "# plt.rcParams.update({\n",
37 | "# \"text.usetex\": False,\n",
38 | "# \"font.family\": \"Helvetica\"\n",
39 | "# })\n",
40 | "import matplotlib_inline\n",
41 | "import scib_metrics\n",
42 | "matplotlib_inline.backend_inline.set_matplotlib_formats('retina')\n",
43 | "import openai\n",
44 | "# remember to set your open AI API key!\n",
45 | "openai.api_key = '' # replace it with your own API key\n",
46 | "np.random.seed(202310)\n",
47 | "# use hnswlib for NN classification\n",
48 | "try:\n",
49 | " import hnswlib\n",
50 | " hnswlib_imported = True\n",
51 | "except ImportError:\n",
52 | " hnswlib_imported = False\n",
53 | " print(\"hnswlib not installed! We highly recommend installing it for fast similarity search.\")\n",
54 | " print(\"To install it, run: pip install hnswlib\")\n",
55 | "from scipy.stats import mode"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "# Here we consider steps to read different datasets.\n",
65 | "\n",
66 | "# adata = sc.read(\"/gpfs/gibbs/pi/zhao/tl688/deconvdatasets/demo_train.h5ad\")\n",
67 | "\n",
68 | "# adata = sc.read(\"/gpfs/gibbs/pi/zhao/tl688/deconvdatasets/demo_test.h5ad\")\n",
69 | "\n",
70 | "# sampled_adata = sc.read_h5ad(\"../sample_aorta_data_updated.h5ad\")\n",
71 | "# sampled_adata = sampled_adata[np.where(sampled_adata.obs.celltype!='Unknown')[0]]\n",
72 | "# adata = sampled_adata.copy()\n",
73 | "# adata.obs['Celltype'] = adata.obs['celltype'].copy()\n",
74 | "\n",
75 | "adata = sc.read(\"/gpfs/gibbs/pi/zhao/wl545/pbmc/datasets/3k_test.h5ad\")\n",
76 | "adata.obs['Celltype'] = adata.obs['label'].copy()"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "\n",
86 | "def evaluate_nmi_ari(adata, label = 'scClassify', key='X_gpt3.5'):\n",
87 | " labels = np.array(list(adata.obs[label]))\n",
88 | " result1 = scib_metrics.nmi_ari_cluster_labels_leiden(adata.obsp['connectivities'], labels = labels, n_jobs = -1)\n",
89 | " result2 = scib_metrics.silhouette_label(adata.obsm[key], labels = labels, rescale=True, chunk_size=256)\n",
90 | " print(result1)\n",
91 | " print(result2)\n",
92 | " return result1, result2"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "\n",
102 | "with open(\"/gpfs/gibbs/pi/zhao/tl688/cpsc_finalproject/genept_data/GenePT/ensem_emb_gpt3.5all.pickle\", \"rb\") as fp:\n",
103 | " GPT_3_5_gene_embeddings = pickle.load(fp)\n",
104 | "gene_names= list(adata.var.index)\n",
105 | "count_missing = 0\n",
106 | "EMBED_DIM = 1536 # embedding dim from GPT-3.5\n",
107 | "lookup_embed_genept = np.zeros(shape=(len(gene_names),EMBED_DIM))\n",
108 | "for i, gene in enumerate(gene_names):\n",
109 | " if gene in GPT_3_5_gene_embeddings:\n",
110 | " lookup_embed_genept[i,:] = GPT_3_5_gene_embeddings[gene].flatten()\n",
111 | " else:\n",
112 | " count_missing+=1\n",
113 | " \n",
114 | "with open(\"/gpfs/gibbs/pi/zhao/tl688/GenePT/data_embedding/GPT_3_5_gene_embeddings.pickle\", \"rb\") as fp:\n",
115 | " GPT_3_5_gene_embeddings = pickle.load(fp)\n",
116 | "gene_names= list(adata.var.index)\n",
117 | "count_missing = 0\n",
118 | "EMBED_DIM = 1536 # embedding dim from GPT-3.5\n",
119 | "lookup_embed_gpt35 = np.zeros(shape=(len(gene_names),EMBED_DIM))\n",
120 | "for i, gene in enumerate(gene_names):\n",
121 | " if gene in GPT_3_5_gene_embeddings:\n",
122 | " lookup_embed_gpt35[i,:] = GPT_3_5_gene_embeddings[gene].flatten()\n",
123 | " else:\n",
124 | " count_missing+=1 "
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "# we have different settings for embeddings\n",
134 | "\n",
135 | "genePT_w_emebed = (adata.X @ lookup_embed_genept) /len(gene_names) # GenePTWW\n",
136 | "\n",
137 | "genePT_w_emebed = (adata.X @ lookup_embed_gpt35) /len(gene_names) # GPT 3.5 aa\n",
138 | "\n",
139 | "genePT_w_emebed = adata.X / np.sum(adata.X, axis=1) [:,None] @ lookup_embed_gpt35 # GPT 3.5 wa\n",
140 | "\n",
141 | "\n",
142 | "lookup_embed = lookup_embed_genept + lookup_embed_gpt35 \n",
143 | "genePT_w_emebed = adata.X / np.sum(adata.X, axis=1) [:,None] @ lookup_embed # GenePT + GPT 3.5 wa\n",
144 | "\n",
145 | "lookup_embed = np.concatenate([lookup_embed_genept, lookup_embed_gpt35], axis=1)\n",
146 | "genePT_w_emebed = adata.X / np.sum(adata.X, axis=1) [:,None] @ lookup_embed # GenePT || GPT 3.5 wa\n",
147 | "\n",
148 | "print(f\"Unable to match {count_missing} out of {len(gene_names)} genes in the GenePT-w embedding\")\n",
149 | "genePT_w_emebed_test = genePT_w_emebed"
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": null,
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "# For cell types: aggregate the cell-type embeddings into the cell embeddings\n",
159 | "with open('ensem_emb_celltype.pickle', 'rb') as f:\n",
160 | " ct_name_getembedding = pickle.load(f)\n",
161 | "\n",
162 | "lookup_embed_ct = np.zeros(shape=(len(adata),EMBED_DIM))\n",
163 | "for i, gene in enumerate(adata.obs_names):\n",
164 | " lookup_embed_ct[i,:] = ct_name_getembedding[adata[gene].obs.Celltype.values[0]]\n",
165 | "\n",
166 | "genePT_w_emebed_test += lookup_embed_ct"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "# For PCA, we follow the pipeline from scanpy\n",
176 | "sc.pp.scale(adata)\n",
177 | "sc.tl.pca(adata)\n",
178 | "genePT_w_emebed_test = adata.obsm['X_pca']"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": null,
184 | "metadata": {},
185 | "outputs": [],
186 | "source": [
187 | "adata.obsm['X_gpt3.5'] = genePT_w_emebed_test \n",
188 | "sc.pp.neighbors(adata, use_rep='X_gpt3.5')\n",
189 | "evaluate_nmi_ari(adata, label = 'Celltype')"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": null,
195 | "metadata": {},
196 | "outputs": [],
197 | "source": [
198 | "sc.tl.umap(adata)\n",
199 | "sc.pl.umap(adata, color='Celltype')"
200 | ]
201 | }
202 | ],
203 | "metadata": {
204 | "language_info": {
205 | "name": "python"
206 | }
207 | },
208 | "nbformat": 4,
209 | "nbformat_minor": 2
210 | }
211 |
--------------------------------------------------------------------------------
/Get outputs from LLMs/query_35.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "\n",
11 | "import openai\n",
12 | "import time\n",
13 | "delay_sec = 5\n",
14 | "# remember to set your open AI API key!\n",
15 | "openai.api_key = '' # replace it with your own API key\n",
16 | "\n",
17 | "import numpy as np\n",
18 | "import pickle\n",
19 | "\n",
20 | "EMBED_DIM = 1536 # embedding dim from GPT-3.5\n",
21 | "lookup_embed = np.zeros(shape=(len(gene_all),EMBED_DIM))\n",
22 | "\n",
23 | "def get_gpt_embedding(text, model=\"text-embedding-ada-002\"):\n",
24 | " text = text.replace(\"\\n\", \" \")\n",
25 | " return np.array(openai.Embedding.create(input = [text], model=model)['data'][0]['embedding'])"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "gene_name_to_GPT_response = {}\n",
35 | "gene_name_getembedding = {}"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "df = pd.read_csv(\"gpt_ncbi_allgene.csv\", index_col = 0) # load the gene names; change this path as needed. If the names are Ensembl IDs, use mygene to convert them.\n",
45 | "gene_all = list(df['gene'].values)"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "gene_completion_test = gene_all\n",
55 | "for gene in gene_completion_test:\n",
56 | " print(gene)\n",
57 | " try:\n",
58 | " completion = openai.ChatCompletion.create(model=\"gpt-3.5-turbo-1106\", \n",
59 | " messages=[{\"role\": \"user\", \n",
60 | " \"content\": f\"Please summarize the major function of gene: {gene}. Use academic language in one paragraph and include pathway information.\"}]) # replace the prompt for different metadata.\n",
61 | " gene_name_to_GPT_response[gene] = completion.choices[0].message.content\n",
62 | " gene_name_getembedding[gene] = get_gpt_embedding(gene_name_to_GPT_response[gene])\n",
63 | " time.sleep(1)\n",
64 | " except (openai.APIError, \n",
65 | " openai.error.APIError, \n",
66 | " openai.error.APIConnectionError, \n",
67 | " openai.error.RateLimitError, \n",
68 | " openai.error.ServiceUnavailableError, \n",
69 | " openai.error.Timeout) as e:\n",
70 | " #Handle API error here, e.g. retry or log\n",
71 | " print(f\"OpenAI API returned an API Error: {e}\")\n",
72 | " pass"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "\n",
82 | "with open('ensem_describe.pickle', 'wb') as handle:\n",
83 | " pickle.dump(gene_name_to_GPT_response, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
84 | " \n",
85 | "with open('ensem_emb_35.pickle', 'wb') as handle:\n",
86 | " pickle.dump(gene_name_getembedding, handle, protocol=pickle.HIGHEST_PROTOCOL)"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": []
95 | }
96 | ],
97 | "metadata": {
98 | "language_info": {
99 | "name": "python"
100 | }
101 | },
102 | "nbformat": 4,
103 | "nbformat_minor": 2
104 | }
105 |
--------------------------------------------------------------------------------
/In silico treatment/in-silico treatment.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import seaborn as sns\n",
11 | "sns.set_style(\"whitegrid\")\n",
12 | "import pandas as pd\n",
13 | "import numpy as np \n",
14 | "import scipy.stats as stats\n",
15 | "from collections import Counter\n",
16 | "import matplotlib.pyplot as plt\n",
17 | "import umap\n",
18 | "import matplotlib\n",
19 | "import mygene\n",
20 | "%matplotlib inline\n",
21 | "import pickle\n",
22 | "import sklearn\n",
23 | "import random\n",
24 | "import scanpy as sc\n",
25 | "import torch\n",
26 | "import torch.nn as nn\n",
27 | "import torch.functional as F\n",
28 | "# import sentence_transformers\n",
29 | "plt.style.use('ggplot')\n",
30 | "#plt.style.use('seaborn-v0_8-dark-palette')\n",
31 | "plt.rcParams['axes.facecolor'] = 'white'\n",
32 | "import matplotlib_inline\n",
33 | "matplotlib_inline.backend_inline.set_matplotlib_formats('retina')\n",
34 | "np.random.seed(202310)\n",
35 | "# use hnswlib for NN classification\n",
36 | "try:\n",
37 | " import hnswlib\n",
38 | " hnswlib_imported = True\n",
39 | "except ImportError:\n",
40 | " hnswlib_imported = False\n",
41 | " print(\"hnswlib not installed! We highly recommend installing it for fast similarity search.\")\n",
42 | " print(\"To install it, run: pip install hnswlib\")\n",
43 | "from scipy.stats import mode"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
53 | "# device = 'cpu'\n",
54 | "sampled_adata = sc.read(\"/gpfs/gibbs/pi/zhao/tl688/board_heartcell/SCP1303/anndata/Cardiomyocyte_data_subsample0.1.h5ad\")"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "sc.pp.normalize_per_cell(sampled_adata)\n",
64 | "sc.pp.log1p(sampled_adata)\n",
65 | "\n",
66 | "sampled_adata.uns['log1p']['base'] = None\n",
67 | "sc.pp.highly_variable_genes(sampled_adata, n_top_genes=2000)\n",
68 | "sampled_adata = sampled_adata[:,sampled_adata.var.highly_variable]"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "adata_comb = sampled_adata"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "with open(\"/gpfs/gibbs/pi/zhao/tl688/cpsc_finalproject/genept_data/GenePT/ensem_emb_gpt3.5all.pickle\", \"rb\") as fp:\n",
87 | " GPT_3_5_gene_embeddings = pickle.load(fp)\n",
88 | "gene_names= list(adata_comb.var.index)\n",
89 | "count_missing = 0\n",
90 | "EMBED_DIM = 1536 # embedding dim from GPT-3.5\n",
91 | "lookup_embed = np.zeros(shape=(len(gene_names),EMBED_DIM))\n",
92 | "for i, gene in enumerate(gene_names):\n",
93 | " if gene in GPT_3_5_gene_embeddings:\n",
94 | " lookup_embed[i,:] = GPT_3_5_gene_embeddings[gene].flatten()\n",
95 | " else:\n",
96 | " count_missing+=1\n",
97 | "lookup_embed.shape"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "from sklearn.model_selection import train_test_split # the import cell of this notebook only has `import sklearn`\n",
107 | "train_obs,test_obs,train_label,test_label = train_test_split(sampled_adata.obs_names, \n",
108 | " sampled_adata.obs.disease,\n",
109 | " test_size=0.20, random_state=2023)\n",
110 | "adata_train = sampled_adata[train_obs]\n",
111 | "adata_test = sampled_adata[test_obs]"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "\n",
121 | "lookup_embed = torch.FloatTensor(lookup_embed).to(device)"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "from sklearn.preprocessing import LabelEncoder # the import cell of this notebook only has `import sklearn`\n",
131 | "label_encoder = LabelEncoder().fit(adata_train.obs.disease)\n",
132 | "adata_train_train = adata_train\n",
133 | "\n",
134 | "train_label = label_encoder.transform(adata_train_train.obs.disease)\n",
135 | "\n",
136 | "\n",
137 | "X_train = torch.FloatTensor(adata_train_train.X.toarray())\n",
138 | "\n",
139 | "train_label = torch.FloatTensor(train_label)\n",
140 | "\n",
141 | "dataset = torch.utils.data.TensorDataset(X_train, train_label)\n",
142 | "\n",
143 | "batch_size = 512\n",
144 | "\n",
145 | "trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,\n",
146 | " shuffle=True, num_workers=2)"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "metadata": {},
153 | "outputs": [],
154 | "source": [
155 | "import torch.nn as nn\n",
156 | "import torch.nn.functional as F\n",
157 | "\n",
158 | "\n",
159 | "class Net(nn.Module):\n",
160 | " def __init__(self):\n",
161 | " super().__init__()\n",
162 | " self.fc1 = nn.Linear(lookup_embed.shape[1], 64)\n",
163 | " self.fc2 = nn.Linear(64, 32)\n",
164 | " self.fc3 = nn.Linear(32, len(label_encoder.classes_))\n",
165 | " self.act = nn.ReLU()\n",
166 | "\n",
167 | " def forward(self, x, inputs):\n",
168 | " x = self.act(self.fc1(x))\n",
169 | " x = self.fc2(x)\n",
170 | " emb = torch.matmul(inputs, x)\n",
171 | " label_out = self.fc3(emb)\n",
172 | " return label_out,emb\n",
173 | "\n",
174 | "\n",
175 | "\n",
176 | "net = Net().to(device)\n",
177 | "dataset = torch.utils.data.TensorDataset(X_train.to(device), train_label.to(device))\n",
178 | "trainloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,\n",
179 | " shuffle=True)\n",
180 | "import torch.optim as optim\n",
181 | "\n",
182 | "criterion = nn.CrossEntropyLoss()\n",
183 | "optimizer = optim.Adam(net.parameters(), lr=1e-3)\n",
184 | "\n",
185 | "def model_evaluation(model, data, labels):\n",
186 | " model.eval()\n",
187 | " data = data.to(device)\n",
188 | " labels = labels.to(device)\n",
189 | " outputs,_ = model(lookup_embed, data) # call the model that was passed in, not the global net\n",
190 | " # predicted class = index of the largest output along dim 1\n",
191 | " _, predicted = torch.max(outputs, 1)\n",
192 | " # fraction of predictions that equal the given labels\n",
193 | " return (predicted == labels).sum().item() / len(labels)\n",
194 | "\n",
195 | "def model_output(model, data):\n",
196 | " model.eval()\n",
197 | " data = data.to(device)\n",
198 | " outputs,emb = model(lookup_embed, data) # call the model that was passed in, not the global net\n",
199 | " # predicted class = index of the largest output along dim 1\n",
200 | " _, predicted = torch.max(outputs, 1)\n",
201 | " # return both the predictions and the embedding as NumPy arrays on the CPU\n",
202 | " return predicted.cpu().numpy(), emb.cpu().detach().numpy()"
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": null,
208 | "metadata": {},
209 | "outputs": [],
210 | "source": [
211 | "from pytorch_metric_learning import miners, losses\n",
212 | "miner = miners.MultiSimilarityMiner()\n",
213 | "loss_func = losses.TripletMarginLoss()\n",
214 | "\n",
215 | "prev = 0\n",
216 | "# validation is not very important here, since the focus is on generating distinguishable embeddings.\n",
217 | "for epoch in range(40): # loop over the dataset multiple times\n",
218 | " running_loss = 0.0\n",
219 | " for i, data in enumerate(trainloader, 0):\n",
220 | " # get the inputs; data is a list of [inputs, labels]\n",
221 | " inputs, labels = data\n",
222 | " labels = labels.long()\n",
223 | "\n",
224 | " # zero the parameter gradients\n",
225 | " optimizer.zero_grad()\n",
226 | "\n",
227 | " # forward + backward + optimize\n",
228 | " outputs,emb = net(lookup_embed, inputs)\n",
229 | " \n",
230 | " loss = criterion(outputs, labels) + 100 * loss_func(emb, labels)\n",
231 | "# loss = loss_func(emb, labels)\n",
232 | " loss.backward()\n",
233 | " optimizer.step()\n",
234 | " \n",
235 | "\n",
236 | " # print statistics\n",
237 | " running_loss += loss.item()\n",
238 | " if i % 2000 == 1999: # print every 2000 mini-batches\n",
239 | " print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')\n",
240 | " running_loss = 0.0\n",
241 | "print('Finished Training')"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": null,
247 | "metadata": {},
248 | "outputs": [],
249 | "source": [
250 | "_,genePT_w_emebed_train = model_output(net, torch.FloatTensor(adata_train.X.toarray()))\n",
251 | "_,genePT_w_emebed_test = model_output(net, torch.FloatTensor(adata_test.X.toarray()))"
252 | ]
253 | },
254 | {
255 | "cell_type": "code",
256 | "execution_count": null,
257 | "metadata": {},
258 | "outputs": [],
259 | "source": [
260 | "adata_test.obsm['X_pca'] = genePT_w_emebed_test"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": null,
266 | "metadata": {},
267 | "outputs": [],
268 | "source": [
269 | "adata_test.obsm['X_genept'] = genePT_w_emebed_test\n",
270 | "\n",
271 | "meanv = np.mean(adata_test[adata_test.obs['disease'] == 'NF'].obsm['X_genept'],axis=0)\n",
272 | "\n",
273 | "meanv_ascend = np.mean(adata_test[adata_test.obs['disease'] == 'DCM'].obsm['X_genept'],axis=0)\n",
274 | "\n",
275 | "import scipy\n",
276 | "\n",
277 | "raw_cs = 1 - scipy.spatial.distance.cosine(meanv, meanv_ascend)"
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": null,
283 | "metadata": {},
284 | "outputs": [],
285 | "source": [
286 | "np.random.seed(202310)\n",
287 | "sc.tl.rank_genes_groups(sampled_adata, groupby='disease')"
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": null,
293 | "metadata": {},
294 | "outputs": [],
295 | "source": [
296 | "disease = 'DCM'\n",
297 | "control = 'NF'"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": null,
303 | "metadata": {},
304 | "outputs": [],
305 | "source": [
306 | "for i in sampled_adata.uns['rank_genes_groups']['names'][disease][0:10]:\n",
307 | " adata_test_new = adata_test.copy()\n",
308 | " adata_test_new[:,i].X = 0\n",
309 | " \n",
310 | " _,genePT_w_emebed_test = model_output(net, torch.FloatTensor(adata_test_new.X.toarray()))\n",
311 | " adata_test_new.obsm['X_genept'] = genePT_w_emebed_test\n",
312 | " meanv = np.mean(adata_test_new[adata_test_new.obs['disease'] == control].obsm['X_genept'],axis=0)\n",
313 | " meanv_ascend = np.mean(adata_test_new[adata_test_new.obs['disease'] == disease].obsm['X_genept'],axis=0)\n",
314 | "# print(i)\n",
315 | " print(1 - scipy.spatial.distance.cosine(meanv, meanv_ascend) - raw_cs)\n",
316 | " "
317 | ]
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": null,
322 | "metadata": {},
323 | "outputs": [],
324 | "source": []
325 | }
326 | ],
327 | "metadata": {
328 | "language_info": {
329 | "name": "python"
330 | }
331 | },
332 | "nbformat": 4,
333 | "nbformat_minor": 2
334 | }
335 |
--------------------------------------------------------------------------------
/Perturbation Analysis/CINEMAOT/__init__.py:
--------------------------------------------------------------------------------
1 | """CINEMA-OT - Causal Independent Effect Module Attribution + Optimal Transport, for single-cell level treatment effect identification"""
2 | __version__ = "0.0.3"
3 | from . import cinemaot
--------------------------------------------------------------------------------
/Perturbation Analysis/CINEMAOT/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/CINEMAOT/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/CINEMAOT/__pycache__/cinemaot.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/CINEMAOT/__pycache__/cinemaot.cpython-38.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/CINEMAOT/__pycache__/sinkhorn_knopp.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/CINEMAOT/__pycache__/sinkhorn_knopp.cpython-38.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/CINEMAOT/benchmark.py:
--------------------------------------------------------------------------------
1 | import scib
2 | import numpy as np
3 | import pandas as pd
4 | import scanpy as sc
5 | from sklearn.neighbors import NearestNeighbors
6 | from scipy.sparse import csr_matrix
7 |
8 | # In this newer version we use the Python implementation of xicor
9 | # import rpy2.robjects as ro
10 | # import rpy2.robjects.numpy2ri
11 | # import rpy2.robjects.pandas2ri
12 | # from rpy2.robjects.packages import importr
13 | # rpy2.robjects.numpy2ri.activate()
14 | # rpy2.robjects.pandas2ri.activate()
15 |
16 | from scipy.stats.stats import pearsonr
17 | from sklearn.decomposition import FastICA
18 | from sklearn.metrics import roc_curve
19 | from sklearn.metrics import auc
20 | from sklearn.metrics import pairwise_distances
21 | from . import sinkhorn_knopp as skp
22 |
23 | from sklearn.preprocessing import OneHotEncoder
24 | from scipy.stats import ttest_1samp
25 | import harmonypy as hm
26 |
def mixscape(adata, obs_label, ref_label, expr_label, nn=20, return_te=True):
    """Match reference cells to their nearest perturbed cells in PCA space.

    A k-nearest-neighbour index is fitted on the ``X_pca`` rows of the cells
    whose ``adata.obs[obs_label]`` equals ``expr_label``; every cell labelled
    ``ref_label`` is then replaced (in a copy of ``X_pca``) by the average of
    its ``nn`` nearest perturbed cells.

    Parameters
    ----------
    adata : object exposing ``obs``, ``obsm['X_pca']`` and ``X``
    obs_label : column of ``adata.obs`` holding the condition labels
    ref_label, expr_label : values of that column for the two groups
    nn : number of neighbours used for the matching
    return_te : also return the per-cell treatment-effect matrix

    Returns
    -------
    ``(mixscape_pca, mixscapematrix)`` or, when ``return_te`` is true,
    ``(mixscape_pca, mixscapematrix, te2)`` where ``te2`` is the reference
    expression minus the neighbour-averaged perturbed expression.
    """
    pca = adata.obsm['X_pca']
    is_expr = (adata.obs[obs_label] == expr_label).to_numpy()
    is_ref = (adata.obs[obs_label] == ref_label).to_numpy()

    nbrs = NearestNeighbors(n_neighbors=nn, algorithm='ball_tree').fit(pca[is_expr, :])
    mixscape_pca = pca.copy()
    # Binary (ref x expr) adjacency: each row flags the nn nearest perturbed cells.
    mixscapematrix = nbrs.kneighbors_graph(pca[is_ref, :]).toarray()
    # Average over the neighbours actually requested; the previous hard-coded
    # divisor of 20 was only correct for the default nn.
    mixscape_pca[is_ref, :] = np.dot(mixscapematrix, mixscape_pca[is_expr, :]) / nn
    if return_te:
        # NOTE(review): assumes adata.X is a dense array -- confirm for sparse input.
        weights = mixscapematrix / np.sum(mixscapematrix, axis=1)[:, None]
        te2 = adata.X[is_ref, :] - weights @ adata.X[is_expr, :]
        return mixscape_pca, mixscapematrix, te2
    else:
        return mixscape_pca, mixscapematrix
39 |
def harmony_mixscape(adata, obs_label, ref_label, expr_label, nn=20, return_te=True):
    """Nearest-neighbour matching performed on Harmony-corrected PCA coordinates.

    ``hm.run_harmony`` is run on ``adata.obsm['X_pca']`` with ``obs_label`` as
    the only integration variable; the transposed ``Z_corr`` output is then
    used to find, for each ``ref_label`` cell, its ``nn`` nearest
    ``expr_label`` cells.

    Returns ``(hmdata, hmmatrix)`` or ``(hmdata, hmmatrix, te2)`` when
    ``return_te`` is true, where ``te2`` is the reference expression minus the
    row-normalised neighbour average of the perturbed expression.
    """
    harmony_out = hm.run_harmony(adata.obsm['X_pca'], adata.obs, [obs_label])
    hmdata = harmony_out.Z_corr.T

    expr_mask = adata.obs[obs_label] == expr_label
    ref_mask = adata.obs[obs_label] == ref_label

    knn = NearestNeighbors(n_neighbors=nn, algorithm='ball_tree')
    knn = knn.fit(hmdata[expr_mask, :])
    hmmatrix = knn.kneighbors_graph(hmdata[ref_mask, :]).toarray()

    if not return_te:
        return hmdata, hmmatrix

    # Normalise each row so the neighbours contribute an average, not a sum.
    weights = hmmatrix / np.sum(hmmatrix, axis=1)[:, None]
    te2 = adata.X[ref_mask, :] - np.matmul(weights, adata.X[expr_mask, :])
    return hmdata, hmmatrix, te2
55 |
def OT(adata, obs_label, ref_label, expr_label, thres=0.01, return_te=True):
    """Entropic optimal-transport matching between perturbed and reference cells.

    The first 20 columns of ``adata.obsm['X_pca']`` are used to build a
    Gaussian affinity between ``expr_label`` and ``ref_label`` cells, which is
    scaled by Sinkhorn-Knopp to uniform marginals.  The transposed plan
    (reference x perturbed) is used to replace the reference rows of a PCA
    copy with a plan-weighted average of the perturbed rows.

    Returns ``(OT_pca, ot)`` or ``(OT_pca, ot, te2)`` when ``return_te`` is
    true, ``te2`` being reference expression minus the plan-weighted
    perturbed expression.
    """
    pca = adata.obsm['X_pca']
    expr_mask = adata.obs[obs_label] == expr_label
    ref_mask = adata.obs[obs_label] == ref_label

    cf1 = pca[expr_mask, 0:20]
    cf2 = pca[ref_mask, 0:20]

    # Uniform marginals, stored as column vectors.
    r = np.full([cf1.shape[0], 1], 1 / cf1.shape[0])
    c = np.full([cf2.shape[0], 1], 1 / cf2.shape[0])
    sk = skp.SinkhornKnopp(setr=r, setc=c, epsilon=1e-2)

    dis = pairwise_distances(cf1, cf2)
    # The kernel bandwidth scales with the full PCA dimensionality.
    e = thres * pca.shape[1]
    af = np.exp(-dis * dis / e)
    ot = sk.fit(af).T

    weights = ot / np.sum(ot, axis=1)[:, None]
    OT_pca = pca.copy()
    OT_pca[ref_mask, :] = np.matmul(weights, OT_pca[expr_mask, :])

    if not return_te:
        return OT_pca, ot
    te2 = adata.X[ref_mask, :] - np.matmul(weights, adata.X[expr_mask, :])
    return OT_pca, ot, te2
75 |
76 |
def evaluate_cinema(matrix, ite, gt, gite):
    """Summarise matching and treatment-effect accuracy with three medians.

    Parameters
    ----------
    matrix : 2-D array of matching scores, one row per row of ``gt``
    ite : 2-D array of estimated individual treatment effects
    gt : 2-D array of binary ground-truth matches
    gite : 2-D array of ground-truth treatment effects, same shape as ``ite``

    Returns
    -------
    tuple ``(median row-wise ROC AUC of matrix vs gt,
    median row-wise Pearson r of ite vs gite,
    median row-wise Spearman rho of ite vs gite)``
    """
    # spearmanr was used without ever being imported at file level; import
    # both correlation functions here so the block is self-sufficient.
    from scipy.stats import pearsonr, spearmanr

    n_matched = gt.shape[0]
    n_effects = ite.shape[0]
    aucdata = np.zeros(n_matched)
    # Size the correlation buffers by the array they are filled from, so a
    # mismatch with gt neither overflows nor pads the median with zeros.
    corr_ = np.zeros(n_effects)
    scorr_ = np.zeros(n_effects)

    for i in range(n_matched):
        fpr, tpr, _ = roc_curve(gt[i, :], matrix[i, :])
        aucdata[i] = auc(fpr, tpr)

    # The earlier correlations restricted to columns 1000+ were overwritten
    # immediately by the full-row values, so only the latter are computed.
    for i in range(n_effects):
        corr_[i], _ = pearsonr(ite[i, :], gite[i, :])
        scorr_[i], _ = spearmanr(ite[i, :], gite[i, :])

    return np.median(aucdata), np.median(corr_), np.median(scorr_)
92 |
def evaluate_batch(sig, adata, obs_label, label, continuity, asw=True, silhouette=True, graph_conn=True, pcr=True, nmi=True, ari=True, diff_coefs=False):
    """Score a signal matrix with scib metrics and optional xi coefficients.

    Parameters
    ----------
    sig : matrix wrapped into a new AnnData that shares ``adata.obs``
    adata : reference AnnData passed to ``scib.metrics.metrics``
    obs_label : batch key forwarded to scib
    label : list of ``adata.obs`` column names; ``label[0]`` is the label key
        forwarded to scib
    continuity : list of booleans parallel to ``label``; true means the
        column is treated as a continuous variable
    asw, silhouette, graph_conn, pcr, nmi, ari : toggles forwarded to scib
    diff_coefs : when true, add one row per entry of ``label`` holding a
        Chatterjee-xi based dependence score against the diffusion components

    Returns
    -------
    The frame returned by ``scib.metrics.metrics``, possibly with extra rows.
    """
    newsig = sc.AnnData(X=sig, obs=adata.obs)
    n_comps = min(15, newsig.X.shape[1] - 1)
    sc.pp.pca(newsig, n_comps=n_comps)
    k0 = 15
    sc.pp.neighbors(newsig, n_neighbors=k0)
    sc.tl.diffmap(newsig, n_comps=n_comps)
    eigen = newsig.obsm['X_diffmap']
    newsig_metrics = scib.metrics.metrics(adata, newsig, obs_label, label[0],
                                          isolated_labels_asw_=asw,
                                          graph_conn_=graph_conn,
                                          silhouette_=silhouette,
                                          nmi_=nmi,
                                          ari_=ari,
                                          pcr_=pcr)
    if not diff_coefs:
        return newsig_metrics

    for i in range(len(label)):
        steps = adata.obs[label[i]].values
        if continuity[i]:
            # np.float was removed in NumPy 1.24; the builtin is equivalent.
            steps_float = steps.astype(float)
            xi = np.zeros(eigen.shape[1])
            for j, source_row in enumerate(eigen.T):
                xi[j] = Xi(source_row, steps_float).correlation
            # Strongest dependence of the variable on any diffusion component.
            newsig_metrics.loc[label[i]] = np.max(xi)
        else:
            try:
                encoder = OneHotEncoder(sparse_output=False)
            except TypeError:
                # Older scikit-learn releases only accept the previous keyword.
                encoder = OneHotEncoder(sparse=False)
            onehot = encoder.fit_transform(np.array(adata.obs[label[i]].values.tolist()).reshape(-1, 1))
            yi = np.zeros([onehot.shape[1], eigen.shape[1]])
            # The last category is skipped, so its row of yi stays zero.
            for k, indicator in enumerate(onehot.T[:-1]):
                for j, source_row in enumerate(eigen.T):
                    yi[k, j] = Xi(source_row, indicator * 1).correlation
            newsig_metrics.loc[label[i]] = np.mean(np.max(yi, axis=1))

    return newsig_metrics
150 |
151 |
class Xi:
    """Chatterjee's xi correlation between two data vectors ``x`` and ``y``.

    Ties in ``x`` are broken at random with ``np.random``, so every access to
    a property that depends on the x ordering may use a different tie order.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    @property
    def sample_size(self):
        """Number of observations, taken from ``len(x)``."""
        return len(self.x)

    @property
    def x_ordered_rank(self):
        """Ordinal ranks of ``x`` with ties broken at random (list)."""
        # scipy.stats is not imported at file level under the name `ss`.
        import scipy.stats as ss

        # Shuffle, rank with method "ordinal" (pandas 'first'), then undo the
        # shuffle: tied values end up in a random relative order.
        # Approach taken from https://stackoverflow.com/a/47430384/1628971
        len_x = len(self.x)
        randomized_indices = np.random.choice(np.arange(len_x), len_x, replace=False)
        randomized = [self.x[idx] for idx in randomized_indices]
        rankdata = ss.rankdata(randomized, method="ordinal")
        unrandomized = [
            rankdata[j] for i, j in sorted(zip(randomized_indices, range(len_x)))
        ]
        return unrandomized

    @property
    def y_rank_max(self):
        """f[i]: number of j with y[j] <= y[i], divided by n."""
        import scipy.stats as ss

        return ss.rankdata(self.y, method="max") / self.sample_size

    @property
    def g(self):
        """g[i]: number of j with y[j] >= y[i], divided by n."""
        import scipy.stats as ss

        return ss.rankdata([-i for i in self.y], method="max") / self.sample_size

    @property
    def x_ordered(self):
        """Indices that sort the observations by ``x`` (ties broken at random)."""
        return np.argsort(self.x_ordered_rank)

    @property
    def x_rank_max_ordered(self):
        """``y_rank_max`` rearranged into increasing-``x`` order (list)."""
        x_ordered_result = self.x_ordered
        y_rank_max_result = self.y_rank_max
        return [y_rank_max_result[i] for i in x_ordered_result]

    @property
    def mean_absolute(self):
        """Scaled mean absolute difference of consecutive ordered y-ranks."""
        n = self.sample_size
        # Read the ordering once: each access reshuffles ties, and taking the
        # two shifted slices from different shuffles would compare ranks that
        # are not actually adjacent.
        ordered = self.x_rank_max_ordered
        x1 = ordered[0:(n - 1)]
        x2 = ordered[1:n]
        return np.mean(np.abs([a - b for a, b in zip(x1, x2)])) * (n - 1) / (2 * n)

    @property
    def inverse_g_mean(self):
        """Mean of ``g * (1 - g)``, the denominator of xi."""
        gvalue = self.g
        return np.mean(gvalue * (1 - gvalue))

    @property
    def correlation(self):
        """xi correlation"""
        return 1 - self.mean_absolute / self.inverse_g_mean

    @classmethod
    def xi(cls, x, y):
        """Alternate constructor; equivalent to ``Xi(x, y)``."""
        return cls(x, y)

    def pval_asymptotic(self, ties=False, nperm=1000):
        """
        Returns p values of the correlation
        Args:
            ties: boolean
                If ties is true, the algorithm assumes that the data has ties
                and employs the more elaborated theory for calculating
                the P-value. Otherwise, it uses the simpler theory. There is
                no harm in setting ties True, even if there are no ties.
            nperm: int
                The number of permutations for the permutation test, if needed.
                default 1000 (currently unused by this method)
        Returns:
            p value
        """
        import scipy.stats as ss

        n = self.sample_size
        correlation = self.correlation

        # No ties: use the simple theoretical variance 2/5.  The condition
        # used to be inverted, contradicting the documented meaning of `ties`.
        if not ties:
            return 1 - ss.norm.cdf(np.sqrt(n) * correlation / np.sqrt(2 / 5))

        # With ties: estimate the asymptotic variance from the sorted y-ranks.
        sorted_ordered_x_rank = sorted(self.x_rank_max_ordered)

        ind = [i + 1 for i in range(n)]
        ind2 = [2 * n - 2 * i + 1 for i in ind]

        a = np.mean([i * j * j for i, j in zip(ind2, sorted_ordered_x_rank)]) / n
        c = np.mean([i * j for i, j in zip(ind2, sorted_ordered_x_rank)]) / n

        cq = np.cumsum(sorted_ordered_x_rank)
        m = [
            (i + (n - j) * k) / n
            for i, j, k in zip(cq, ind, sorted_ordered_x_rank)
        ]

        b = np.mean([np.square(i) for i in m])
        v = (a - 2 * b + np.square(c)) / np.square(self.inverse_g_mean)

        return 1 - ss.norm.cdf(np.sqrt(n) * correlation / np.sqrt(v))
--------------------------------------------------------------------------------
/Perturbation Analysis/CINEMAOT/sinkhorn_knopp.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 |
5 |
class SinkhornKnopp:
    """
    Sinkhorn Knopp Algorithm

    Takes a non-negative matrix P, where P =/= 0, and alternately rescales
    its rows and columns until their sums are within ``epsilon`` of the
    requested targets.  Guaranteed convergence if P has total support.

    For reference see original paper:
        http://msp.org/pjm/1967/21-2/pjm-v21-n2-p14-s.pdf

    Parameters
    ----------
    max_iter : int, default=1000
        The maximum number of iterations.

    setr : scalar or array-like, default=0
        Target row sums.  When its absolute sum is 0 the target for every
        row is ``P.shape[1]``.

    setc : scalar or array-like, default=0
        Target column sums.  When its absolute sum is 0 the target for every
        column is ``P.shape[0]``.

    epsilon : float, default=1e-3
        Metric used to compute the stopping condition, which occurs when all
        row and column sums are within epsilon of their targets.
        Epsilon must be between 0 and 1.

    Attributes
    ----------
    _max_iter : int, default=1000
        User defined parameter. See above.

    _epsilon : float, default=1e-3
        User defined parameter. See above.

    _stopping_condition: string
        Either "max_iter", "epsilon", or None, which is a
        description of why the last fit stopped iterating.

    _iterations : int
        The number of iterations used by the last call to ``fit``.

    _D1 : 2d-array
        Diagonal matrix of row scalings from the last fit, so that
        _D1.dot(P).dot(_D2) is the returned matrix.

    _D2 : 2d-array
        Diagonal matrix of column scalings from the last fit.

    Example
    -------

    .. code-block:: python
        >>> import numpy as np
        >>> ones = np.ones((2, 1))
        >>> sk = SinkhornKnopp(setr=ones, setc=ones)
        >>> P_ds = sk.fit([[.011, .15], [1.71, .1]])
        >>> bool(np.allclose(P_ds.sum(axis=0), 1, atol=1e-3))
        True
        >>> bool(np.allclose(P_ds.sum(axis=1), 1, atol=1e-3))
        True

    """

    def __init__(self, max_iter=1000, setr=0, setc=0, epsilon=1e-3):
        assert isinstance(max_iter, (int, float)),\
            "max_iter is not of type int or float: %r" % max_iter
        assert max_iter > 0,\
            "max_iter must be greater than 0: %r" % max_iter
        self._max_iter = int(max_iter)

        assert isinstance(epsilon, (int, float)),\
            "epsilon is not of type float or int: %r" % epsilon
        assert epsilon > 0 and epsilon < 1,\
            "epsilon must be between 0 and 1 exclusive: %r" % epsilon
        self._epsilon = epsilon
        self._setr = setr
        self._setc = setc
        self._stopping_condition = None
        self._iterations = 0
        self._D1 = np.ones(1)
        self._D2 = np.ones(1)

    @staticmethod
    def _target_sums(requested, default):
        """Return the target sums as a float column vector (k, 1), k in {1, n}."""
        requested = np.asarray(requested, dtype=float)
        if np.sum(np.abs(requested)) == 0:
            return np.array([[float(default)]])
        return requested.reshape(-1, 1)

    def fit(self, P):
        """Fit the diagonal scalings in Sinkhorn Knopp's algorithm

        Parameters
        ----------
        P : 2d array-like
            Non-negative 2d array-like object that is convertible to a numpy
            array.  The matrix must not be equal to 0 and it must have total
            support for the algorithm to converge.

        Returns
        -------
        The rescaled matrix ``diag(r) @ P @ diag(c)``.

        """
        P = np.asarray(P)
        assert np.all(P >= 0)
        assert P.ndim == 2

        # Start every fit from a clean state; otherwise a second call would
        # inherit the iteration count and stopping reason of the first.
        self._iterations = 0
        self._stopping_condition = None

        n_rows, n_cols = P.shape
        rsum = self._target_sums(self._setr, n_cols)
        csum = self._target_sums(self._setc, n_rows)
        # Flat copies for the convergence test: comparing (n,) sums against
        # (n, 1) targets would broadcast to an n x n all-pairs comparison.
        row_target = rsum.ravel()
        col_target = csum.ravel()
        eps = self._epsilon

        def outside_tolerance(M):
            row_sums = np.sum(M, axis=1)
            col_sums = np.sum(M, axis=0)
            return (np.any(row_sums < row_target - eps)
                    or np.any(row_sums > row_target + eps)
                    or np.any(col_sums < col_target - eps)
                    or np.any(col_sums > col_target + eps))

        # Initialize r and c, the diagonals of D1 and D2
        # and warn if the matrix does not have support.
        total_support_warning_str = (
            "Matrix P must have total support. "
            "See documentation"
        )
        r = np.ones((n_rows, 1))
        pdotr = P.T.dot(r)
        if not np.all(pdotr != 0):
            warnings.warn(total_support_warning_str, UserWarning)

        c = 1 / pdotr
        pdotc = P.dot(c)
        if not np.all(pdotc != 0):
            warnings.warn(total_support_warning_str, UserWarning)

        r = 1 / pdotc
        del pdotr, pdotc

        P_eps = np.copy(P)
        while outside_tolerance(P_eps):
            c = csum / P.T.dot(r)
            r = rsum / P.dot(c)

            # Scale by broadcasting instead of materialising diagonal
            # matrices on every iteration.
            P_eps = r * P * c.T

            self._iterations += 1
            if self._iterations >= self._max_iter:
                self._stopping_condition = "max_iter"
                break

        if not self._stopping_condition:
            self._stopping_condition = "epsilon"

        self._D1 = np.diag(r.ravel())
        self._D2 = np.diag(c.ravel())
        P_eps = r * P * c.T

        return P_eps
172 |
--------------------------------------------------------------------------------
/Perturbation Analysis/CINEMAOT/utils.py:
--------------------------------------------------------------------------------
1 | import gseapy as gp
2 | import pandas as pd
3 | from scipy.stats import wilcoxon
4 | import numpy as np
5 | import scanpy as sc
6 | #import scib
7 | from sklearn.linear_model import LogisticRegression
8 | from sklearn.preprocessing import OneHotEncoder
9 | from scipy.stats import kstest
10 | import plotly.graph_objects as go
11 | import plotly.express as px
12 |
13 | # import rpy2.robjects as ro
14 | # import rpy2.robjects.numpy2ri
15 | # import rpy2.robjects.pandas2ri
16 | # from rpy2.robjects.packages import importr
17 | # rpy2.robjects.numpy2ri.activate()
18 | # rpy2.robjects.pandas2ri.activate()
19 |
20 |
def dominantcluster(adata, ctobs, clobs):
    """Name every cluster after its most frequent cell type.

    Clusters are visited in the sorted order of the values of
    ``adata.obs[clobs]``.  Each name is the dominant value of
    ``adata.obs[ctobs]`` inside the cluster followed by a running counter
    (``0`` for the first cluster dominated by that type, ``1`` for the
    second, ...).

    Returns
    -------
    list of str, one name per cluster, in sorted cluster order.
    """
    clustername = []
    # Counter keyed by the label itself.  Indexing it by the position inside
    # a per-cluster value_counts() is only stable for categorical columns,
    # where every category is always listed.
    times_seen = {}
    for cluster in adata.obs[clobs].value_counts().sort_index().index.values:
        counts = adata.obs[ctobs][adata.obs[clobs] == cluster].value_counts().sort_index()
        dominant = counts.index.values[np.argmax(counts.values)]
        occurrence = times_seen.get(dominant, 0)
        clustername.append(str(dominant) + str(occurrence))
        times_seen[dominant] = occurrence + 1
    return clustername
30 |
def assignleiden(adata, ctobs, clobs, label):
    """Store a dominant-cell-type name for every cell in ``adata.obs[label]``.

    The cluster id of each cell (``adata.obs[clobs]``) is converted with
    ``int`` and used as a position into the list returned by
    ``dominantcluster``.  The function modifies ``adata`` in place and
    returns ``None``.
    """
    names = dominantcluster(adata, ctobs, clobs)
    cluster_ids = adata.obs[clobs].values.tolist()
    adata.obs[label] = [names[int(cluster_id)] for cluster_id in cluster_ids]
    return
38 |
def _synergy_signed_genes(adata, clobs, cluster, thres, fthres):
    """Return (up genes, down genes) for one cluster of a treatment-effect AnnData.

    A gene is selected when its one-sample Wilcoxon p-value is below ``thres``
    and the absolute median effect exceeds ``fthres``; the sign of the median
    decides between up and down.
    """
    tmpte = adata.X[adata.obs[clobs].values == cluster, :]
    pv = np.zeros(tmpte.shape[1])
    for k in range(tmpte.shape[1]):
        _, pv[k] = wilcoxon(tmpte[:, k], zero_method='zsplit')
    median_effect = np.median(tmpte, axis=0)
    # NOTE(review): the threshold expressions in the file were truncated
    # (text between '<' and '>' was lost, leaving unbalanced parentheses).
    # They are reconstructed here from the surviving fragments -- confirm
    # against the upstream CINEMA-OT source.
    significant = (pv < thres) & (np.abs(median_effect) > fthres)
    genenames = adata.var_names.values
    upgenes = genenames[significant & (median_effect > 0)]
    downgenes = genenames[significant & (median_effect < 0)]
    return upgenes, downgenes


def clustertest_synergy(adata1, adata2, clobs, thres, fthres, path, genesetpath, organism):
    """Write per-cluster gene sets that respond in only one of two conditions.

    In this simplified function only gene sets are produced; it is designed
    for synergy computation.  For every value of ``adata1.obs[clobs]`` the up
    and down genes of ``adata1`` and ``adata2`` are computed, the genes
    specific to each side are submitted to ``gp.enrichr``, and the results
    are written as CSV files under ``path``:

    * ``Up1/Up2/Down1/Down2<cluster>.csv`` -- enrichment terms with adjusted
      p-value below 1e-2 (only when enrichr returned results),
    * ``Upnames1/Upnames2/Downnames1/Downnames2<cluster>.csv`` -- gene names,
    * ``names<cluster>.csv`` -- union of the four gene lists.

    Returns ``None``.
    """
    names_prefix = {'Up1': 'Upnames1', 'Up2': 'Upnames2',
                    'Down1': 'Downnames1', 'Down2': 'Downnames2'}
    for i in list(set(adata1.obs[clobs].values.tolist())):
        clustername = i
        upgenes1, downgenes1 = _synergy_signed_genes(adata1, clobs, i, thres, fthres)
        upgenes2, downgenes2 = _synergy_signed_genes(adata2, clobs, i, thres, fthres)

        up1, up2 = set(upgenes1.tolist()), set(upgenes2.tolist())
        down1, down2 = set(downgenes1.tolist()), set(downgenes2.tolist())
        # Genes that respond in one condition but not in the other.
        gene_lists = {
            'Up1': list(up1 - up2),
            'Up2': list(up2 - up1),
            'Down1': list(down1 - down2),
            'Down2': list(down2 - down1),
        }
        allgenes = list(set().union(*gene_lists.values()))

        enrichments = {
            tag: gp.enrichr(gene_list=genes, gene_sets=genesetpath,
                            no_plot=True, organism=organism,
                            outdir=path, format='png')
            for tag, genes in gene_lists.items()
        }
        for tag, enr in enrichments.items():
            if not enr.results.empty:
                keep = enr.results['Adjusted P-value'].values < 1e-2
                enr.results.iloc[keep, :].to_csv(path + '/' + tag + clustername + '.csv')

        for tag, genes in gene_lists.items():
            pd.DataFrame(index=genes).to_csv(path + '/' + names_prefix[tag] + clustername + '.csv')
        pd.DataFrame(index=allgenes).to_csv(path + '/names' + clustername + '.csv')

    return
109 |
110 |
def clustertest(adata, clobs, thres, fthres, label, path, genesetpath, organism, onlyup=False):
    """Per-cluster Wilcoxon test on treatment effects followed by enrichr.

    Cluster ids in ``adata.obs[clobs]`` are read as numbers and clusters
    ``0 .. max`` are processed in order.  For each cluster, genes with a
    one-sample Wilcoxon p-value below ``thres`` and an absolute median effect
    above ``fthres`` are selected, split by the sign of the median, submitted
    to ``gp.enrichr`` and written to CSV files under ``path`` (enrichment
    terms with adjusted p-value below 1e-3, plus the gene-name lists).

    Parameters
    ----------
    label : ``adata.obs`` column providing the display name of each cluster
    onlyup : collect the up-gene enrichment instead of the all-gene one in
        the returned frame

    Returns
    -------
    genenum : array with the number of selected genes per cluster
    df : frame with one row per cluster that had enrichment results and one
        column per term (empty frame when no cluster had any)
    mk : list of all selected genes across clusters (unordered, unique)
    """
    # np.asfarray was removed in NumPy 2.0; asarray with a float dtype is equivalent.
    cluster_ids = np.asarray(adata.obs[clobs].values, dtype=float)
    clusternum = int(np.max(cluster_ids))
    genenum = np.zeros([clusternum + 1])
    genenames = adata.var_names.values
    marker_set = set()
    frames = []
    for i in range(clusternum + 1):
        clusterindex = (cluster_ids == i)
        tmpte = adata.X[clusterindex, :]
        clustername = adata.obs[label][clusterindex].iloc[0]
        pv = np.zeros(tmpte.shape[1])
        for k in range(tmpte.shape[1]):
            _, pv[k] = wilcoxon(tmpte[:, k], zero_method='zsplit')
        median_effect = np.median(tmpte, axis=0)
        # NOTE(review): the threshold expressions in the file were truncated
        # (text between '<' and '>' was lost, leaving unbalanced parentheses).
        # They are reconstructed here from the surviving fragments -- confirm
        # against the upstream CINEMA-OT source.
        allindex = (pv < thres) & (np.abs(median_effect) > fthres)
        upindex = allindex & (median_effect > 0)
        downindex = allindex & (median_effect < 0)
        upgenes = genenames[upindex]
        downgenes = genenames[downindex]
        allgenes = genenames[allindex]
        marker_set.update(allgenes.tolist())
        genenum[i] = np.sum(allindex)

        enr_up = gp.enrichr(gene_list=upgenes.tolist(), gene_sets=genesetpath,
                            no_plot=True, organism=organism,
                            outdir=path, format='png')
        enr_down = gp.enrichr(gene_list=downgenes.tolist(), gene_sets=genesetpath,
                              no_plot=True, organism=organism,
                              outdir=path, format='png')
        enr = gp.enrichr(gene_list=allgenes.tolist(), gene_sets=genesetpath,
                         no_plot=True, organism=organism,
                         outdir=path, format='png')
        if not enr_up.results.empty:
            enr_up.results.iloc[enr_up.results['Adjusted P-value'].values < 1e-3, :].to_csv(path + '/Up' + clustername + '.csv')
        if not enr_down.results.empty:
            enr_down.results.iloc[enr_down.results['Adjusted P-value'].values < 1e-3, :].to_csv(path + '/Down' + clustername + '.csv')
        if not enr.results.empty:
            enr.results.iloc[enr.results['Adjusted P-value'].values < 1e-3, :].to_csv(path + '/' + clustername + '.csv')
        pd.DataFrame(index=upgenes).to_csv(path + '/Upnames' + clustername + '.csv')
        pd.DataFrame(index=downgenes).to_csv(path + '/Downnames' + clustername + '.csv')
        pd.DataFrame(index=allgenes).to_csv(path + '/names' + clustername + '.csv')
        if onlyup:
            enr = enr_up

        if not enr.results.empty:
            # Row 4 of the transposed results is column 4 of the enrichr table.
            # NOTE(review): presumably 'Adjusted P-value' -- confirm column order.
            row = enr.results.transpose().iloc[4:5, :]
            row.columns = enr.results['Term'][:]
            row.index = [clustername]
            frames.append(row)

    # Collecting rows first avoids an unbound `df` when cluster 0 (or every
    # cluster) has no enrichment results.
    df = pd.concat(frames) if frames else pd.DataFrame()
    mk = list(marker_set)
    return genenum, df, mk
171 |
172 |
def concordance_map(confounder,response,obs_label, cl_label, condition):
    """Pairwise separability of response clusters in confounder space (deprecated).

    Cells of ``confounder`` whose ``obs[obs_label]`` equals ``condition`` are
    annotated with the cluster labels ``response.obs[cl_label]``.  For every
    ordered pair of distinct clusters an L1-penalised logistic regression is
    fitted on the confounder matrix, and the predicted probabilities of the
    two clusters are compared with a two-sample KS test.

    Returns
    -------
    aswmatrix : DataFrame of ``-log10(p + 1e-20)`` per cluster pair
        (0 on the diagonal)
    indnummatrix : DataFrame holding, per pair, the comma-separated indices
        of features with non-zero regression coefficients (left empty when
        all coefficients are zero)
    """
    #deprecated
    cf = confounder[confounder.obs[obs_label] == condition,:]
    # NOTE(review): assumes `response` has exactly one row per selected cell,
    # in the same order -- confirm against callers.
    cf.obs['res_cl'] = response.obs[cl_label].values
    # The set of cluster labels is rebuilt several times below; the code
    # relies on every rebuild iterating in the same order.
    aswmatrix = np.zeros([len(list(set(cf.obs['res_cl'].values.tolist()))),len(list(set(cf.obs['res_cl'].values.tolist())))])
    indnummatrix = pd.DataFrame(None,list(set(cf.obs['res_cl'].values.tolist())),list(set(cf.obs['res_cl'].values.tolist())))
    k = 0
    #return aswmatrix
    for i in list(set(cf.obs['res_cl'].values.tolist())):
        l = 0
        for j in list(set(cf.obs['res_cl'].values.tolist())):
            if i != j:
                # Restrict to the two clusters being compared.
                tmpcf = cf[cf.obs['res_cl'].isin([i,j]),:].copy()
                sc.pp.pca(tmpcf)
                # NOTE(review): the `sparse` keyword was renamed in newer
                # scikit-learn releases -- confirm the pinned version.
                encoder = OneHotEncoder(sparse=False)
                onehot = encoder.fit_transform(np.array(tmpcf.obs['res_cl'].values.tolist()).reshape(-1, 1))
                # Binary target: membership of the first one-hot column.
                label = onehot[:,0]
                lc = LogisticRegression(penalty='l1',solver='liblinear',C=1)
                lc.fit(tmpcf.X, label)
                prob = lc.predict_proba(tmpcf.X)
                prob1 = prob[label==1,0]
                prob2 = prob[label==0,0]
                st, pv = kstest(prob1,prob2)
                #yi = np.zeros([onehot.shape[1],eigen.shape[1]])
                # The small offset keeps the logarithm finite when pv is 0.
                aswmatrix[k,l] = -np.log10(pv+1e-20)
                if np.sum(lc.coef_!=0)>0:
                    # Store the selected feature indices as "a, b, c".
                    indnummatrix.iloc[k,l] = str(np.argwhere(lc.coef_[0] !=0)[:,0].tolist())[1:-1]
            else:
                aswmatrix[k,l] = 0
            l = l + 1
        k = k + 1
    aswmatrix = pd.DataFrame(aswmatrix,list(set(cf.obs['res_cl'].values.tolist())),list(set(cf.obs['res_cl'].values.tolist())))
    return aswmatrix, indnummatrix
206 |
207 |
def coarse_matching(de, de_label, ref, ref_label, ot, scaling=1e6, mode='mean'):
    """Aggregate a cell-by-cell matching matrix to group-by-group level.

    Parameters
    ----------
    de, ref : objects exposing ``obs``; rows of ``ot`` follow ``de`` and
        columns follow ``ref``
    de_label, ref_label : ``obs`` columns holding the group of each cell
    ot : 2-D array of matching weights
    scaling : factor applied to every aggregated entry
    mode : ``'mean'`` averages each block of ``ot``; any other value sums it

    Returns
    -------
    DataFrame indexed by the sorted ``de`` groups with the sorted ``ref``
    groups as columns.
    """
    de_groups_of_cells = de.obs[de_label]
    ref_groups_of_cells = ref.obs[ref_label]
    de_groups = sorted(set(de_groups_of_cells.values.tolist()))
    ref_groups = sorted(set(ref_groups_of_cells.values.tolist()))
    coarse_ot = pd.DataFrame(index=de_groups, columns=ref_groups, dtype=float)

    reduce_block = np.mean if mode == 'mean' else np.sum
    for i in de_groups:
        # The row slice only depends on i, so take it once per de group.
        rows = ot[np.asarray(de_groups_of_cells == i), :]
        for j in ref_groups:
            block = rows[:, np.asarray(ref_groups_of_cells == j)]
            # Single-step .loc assignment: chained `frame[j][i] = ...` may
            # write to a temporary copy under pandas copy-on-write.
            coarse_ot.loc[i, j] = reduce_block(block) * scaling
    return coarse_ot
218 |
def sankey_plot(coarse_ot, thres1=0.1, thres2=0.1, title_text="Sankey Diagram", width=600, height=400):
    """Draw a three-layer Sankey diagram from a coarse matching table.

    An entry ``coarse_ot[i, j]`` becomes a link when it is positive and
    exceeds both ``thres1`` times its row sum and ``thres2`` times its column
    sum.  Nodes are, from left to right: column names with their last
    character removed, the column names themselves, and the row names
    prefixed with ``'Response: '``.

    Parameters
    ----------
    coarse_ot : DataFrame with string row and column labels
    thres1, thres2 : relative row / column thresholds for keeping a link
    title_text : figure title
    width, height : figure size passed to the layout

    The figure is shown with ``fig.show()``; nothing is returned.
    """
    table = coarse_ot.values
    links = []
    for i in range(coarse_ot.shape[0]):
        for j in range(coarse_ot.shape[1]):
            thres_ = max(thres1 * np.sum(table[i, :]), thres2 * np.sum(table[:, j]))
            flow = table[i, j]
            if flow > thres_ and flow > 0:
                links.append([coarse_ot.columns[j], 'Response: ' + coarse_ot.index[i], flow])
    # Build the frame from typed rows instead of writing strings into a
    # float-initialised frame, which newer pandas releases reject.
    new_coarse_ot = pd.DataFrame(links, columns=[0, 1, 2])
    n_links = new_coarse_ot.shape[0]

    a = list(set(new_coarse_ot[0].values.tolist()))
    b = list(set(new_coarse_ot[1].values.tolist()))
    # Coarsest layer: the source names without their trailing character.
    a0 = list(set([name[:-1] for name in a]))

    a0_pos = {name: idx for idx, name in enumerate(a0)}
    a_pos = {name: idx for idx, name in enumerate(a)}
    b_pos = {name: idx for idx, name in enumerate(b)}
    sources = new_coarse_ot[0].values
    responses = new_coarse_ot[1].values

    # Node numbering follows the label list a0 + a + b.  The first n_links
    # links connect a0 -> a, the remaining n_links connect a -> b.
    source = np.array(
        [a0_pos[s[:-1]] for s in sources]
        + [len(a0) + a_pos[s] for s in sources],
        dtype=int,
    )
    target = np.array(
        [len(a0) + a_pos[s] for s in sources]
        + [len(a0) + len(a) + b_pos[t] for t in responses],
        dtype=int,
    )

    values = np.zeros(2 * n_links)
    link_table = new_coarse_ot.values
    for i in range(n_links):
        same_source = link_table[:, 0] == link_table[i, 0]
        # Mean flow leaving this source node.
        values[i] = np.sum(link_table[:, 2][same_source]) / np.sum(same_source)
    values[n_links:] = link_table[:, 2]

    colorlist = px.colors.qualitative.Plotly
    colors = np.array(a0 + a + b)
    colors[0:len(a0)] = colorlist[0:len(a0)]
    # Every finer node inherits the colour of its coarse group.
    for i in range(len(a0), len(a0) + len(a)):
        colors[i] = colors[a0_pos[a[i - len(a0)][:-1]]]
    for i in range(len(a0) + len(a), len(a0) + len(a) + len(b)):
        # [10:-1] drops the 'Response: ' prefix and the trailing character.
        colors[i] = colors[a0_pos[b[i - len(a0) - len(a)][10:-1]]]

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            label=a0 + a + b,
            color=colors
        ),
        link=dict(
            source=source,  # indices correspond to labels, eg A1, A2, A1, B1, ...
            target=target,
            value=values
        ))])

    # Honour the caller's title instead of a hard-coded one.
    fig.update_layout(title_text=title_text, font_family="Arial", font_size=10, width=width, height=height)
    fig.show()
    return
284 |
285 |
286 |
287 |
--------------------------------------------------------------------------------
/Perturbation Analysis/CPA/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | warnings.simplefilter('ignore')
4 |
5 | from ._model import CPA
6 | from ._module import CPAModule
7 | from . import _plotting as pl
8 | from ._api import ComPertAPI
9 |
from importlib.metadata import PackageNotFoundError, version

package_name = "cpa-tools"
try:
    __version__ = version(package_name)
except PackageNotFoundError:
    # This package can be imported straight from the source tree (vendored
    # copy) without the "cpa-tools" distribution being installed; in that
    # case there is no metadata to query, so fall back to a placeholder
    # instead of making the whole package un-importable.
    __version__ = "unknown"
14 |
15 | __all__ = [
16 | "CPA",
17 | "CPAModule",
18 | "ComPertAPI",
19 | "pl",
20 | ]
21 |
--------------------------------------------------------------------------------
/Perturbation Analysis/CPA/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/CPA/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/CPA/__pycache__/_api.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/CPA/__pycache__/_api.cpython-39.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/CPA/__pycache__/_data.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/CPA/__pycache__/_data.cpython-39.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/CPA/__pycache__/_metrics.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/CPA/__pycache__/_metrics.cpython-39.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/CPA/__pycache__/_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/CPA/__pycache__/_model.cpython-39.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/CPA/__pycache__/_module.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/CPA/__pycache__/_module.cpython-39.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/CPA/__pycache__/_plotting.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/CPA/__pycache__/_plotting.cpython-39.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/CPA/__pycache__/_task.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/CPA/__pycache__/_task.cpython-39.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/CPA/__pycache__/_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/CPA/__pycache__/_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/CPA/_data.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from scvi import settings
4 | from scvi.data import AnnDataManager
5 | from scvi.dataloaders import DataSplitter, AnnDataLoader
6 | from scvi.model._utils import parse_use_gpu_arg
7 |
8 |
class AnnDataSplitter(DataSplitter):
    """Data splitter that serves pre-computed train/validation/test indices.

    The index collections are stored as given; every loader is an
    ``AnnDataLoader`` over ``self.adata_manager`` built with ``shuffle=True``
    and the extra keyword arguments passed to the constructor.

    Parameters
    ----------
    adata_manager
        Manager forwarded to the parent ``DataSplitter`` constructor.
    train_indices, valid_indices, test_indices
        Index collections for the three splits (anything supporting ``len``).
    use_gpu
        Forwarded to ``parse_use_gpu_arg`` in :meth:`setup`.
    **kwargs
        Extra keyword arguments passed to every ``AnnDataLoader``.
    """

    def __init__(
        self,
        adata_manager: AnnDataManager,
        train_indices,
        valid_indices,
        test_indices,
        use_gpu: bool = False,
        **kwargs,
    ):
        super().__init__(adata_manager)
        self.data_loader_kwargs = kwargs
        self.use_gpu = use_gpu
        self.train_idx = train_indices
        self.val_idx = valid_indices
        self.test_idx = test_indices

    def setup(self, stage: Optional[str] = None):
        """Resolve the device and decide whether loaders pin memory.

        ``stage`` is accepted for interface compatibility and is not used.
        """
        accelerator, _, self.device = parse_use_gpu_arg(
            self.use_gpu, return_device=True
        )
        # Pin memory only when the global setting asks for it AND the resolved
        # accelerator is a GPU.
        self.pin_memory = bool(
            settings.dl_pin_memory_gpu_training and accelerator == "gpu"
        )

    def _loader_for(self, split_indices):
        """Return a shuffled loader over ``split_indices``, or None if empty."""
        if len(split_indices) == 0:
            return None
        return AnnDataLoader(
            self.adata_manager,
            indices=split_indices,
            shuffle=True,
            pin_memory=self.pin_memory,
            **self.data_loader_kwargs,
        )

    def train_dataloader(self):
        """Loader over the training indices (None when there are none)."""
        return self._loader_for(self.train_idx)

    def val_dataloader(self):
        """Loader over the validation indices (None when there are none)."""
        return self._loader_for(self.val_idx)

    def test_dataloader(self):
        """Loader over the test indices (None when there are none)."""
        return self._loader_for(self.test_idx)
76 |
--------------------------------------------------------------------------------
/Perturbation Analysis/CPA/_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.stats import entropy
3 | from sklearn.neighbors import NearestNeighbors
4 | from sklearn.preprocessing import LabelEncoder
5 |
6 |
def knn_purity(data, labels: np.ndarray, n_neighbors=30):
    """Compute the KNN purity of ``data`` with respect to ``labels``.

    For every sample, the fraction of its ``n_neighbors`` nearest neighbours
    sharing its label is computed; these fractions are averaged within each
    label, and the per-label averages are averaged again.

    Parameters
    ----------
    data
        Numpy ndarray of data.
    labels
        Numpy ndarray of labels (flattened with ``ravel`` before use).
    n_neighbors: int
        Number of nearest neighbors.

    Returns
    -------
    score: float
        KNN purity score. A float between 0 and 1.
    """
    encoded = LabelEncoder().fit_transform(labels.ravel())

    # Query one extra neighbour and drop the first column, which is expected
    # to be the sample itself.
    knn = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(data)
    neighbor_idx = knn.kneighbors(data, return_distance=False)[:, 1:]
    neighbor_labels = encoded[neighbor_idx]

    # Per-cell purity: share of neighbours carrying the cell's own label.
    per_cell = (neighbor_labels == encoded[:, None]).mean(axis=1)

    # Per-label purity, so that large groups do not dominate the final score.
    per_label = [per_cell[encoded == lab].mean() for lab in np.unique(encoded)]

    return np.mean(per_label)
35 |
36 |
def entropy_batch_mixing(data, labels,
                         n_neighbors=50, n_pools=50, n_samples_per_pool=100):
    """Compute the Entropy of Batch Mixing (EBM) of ``data`` given ``labels``.

    For every sample the entropy (in base ``n_cat``, the number of distinct
    labels) of the label distribution among its nearest neighbours is
    computed. The final score averages these entropies, either directly
    (``n_pools == 1``) or over ``n_pools`` random sub-samples drawn with
    ``np.random.choice`` (so the result is random unless NumPy's global seed
    is fixed).

    Parameters
    ----------
    data
        Numpy ndarray of data
    labels
        Numpy ndarray of labels (flattened before use)
    n_neighbors: int
        Number of nearest neighbors.
    n_pools: int
        Number of EBM computation which will be averaged.
    n_samples_per_pool: int
        Number of samples to be used in each pool of execution.

    Returns
    -------
    score: float
        EBM score. A float between zero and one. ``0.0`` when ``labels``
        contains a single category.
    """
    labels = np.asarray(labels).ravel()
    n_cat = len(np.unique(labels))

    # With a single category the entropy base would be 1, i.e. a division by
    # log(1) == 0 that yields NaN. Nothing can be mixed in that case.
    if n_cat == 1:
        return 0.0

    def _entropy_of_row(row_labels):
        counts = np.unique(row_labels, return_counts=True)[1]
        return entropy(counts, base=n_cat)

    # Query one extra neighbour and drop the first column, which is expected
    # to be the sample itself.
    neighbors = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(data)
    indices = neighbors.kneighbors(data, return_distance=False)[:, 1:]
    batch_indices = labels[indices]

    entropies = np.apply_along_axis(_entropy_of_row, axis=1, arr=batch_indices)

    # average n_pools entropy results where each result is an average of
    # n_samples_per_pool random samples.
    if n_pools == 1:
        return np.mean(entropies)

    return np.mean([
        np.mean(entropies[np.random.choice(len(entropies), size=n_samples_per_pool)])
        for _ in range(n_pools)
    ])
80 |
--------------------------------------------------------------------------------
/Perturbation Analysis/CPA/_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from scvi.distributions import NegativeBinomial
7 |
8 | from scvi.nn import FCLayers
9 | from torch.distributions import Normal
10 | from typing import Optional
11 |
12 |
class _REGISTRY_KEYS:
    """Container of registry key names and related constants for CPA.

    All values are class-level attributes. Several default to ``None``;
    presumably they are assigned elsewhere once the dataset is registered
    (TODO confirm against the model setup code).
    """
    X_KEY: str = "X"
    X_CTRL_KEY: Optional[str] = None
    BATCH_KEY: Optional[str] = None
    CATEGORY_KEY: str = "cpa_category"
    PERTURBATION_KEY: Optional[str] = None
    PERTURBATION_DOSAGE_KEY: Optional[str] = None
    PERTURBATIONS: str = "perts"
    PERTURBATIONS_DOSAGES: str = "perts_doses"
    SIZE_FACTOR_KEY: str = "size_factor"
    # NOTE(review): mutable list defined on the class, so it is shared by
    # every instance; in-place mutation is visible everywhere.
    CAT_COV_KEYS: List[str] = []
    MAX_COMB_LENGTH: int = 2
    CONTROL_KEY: Optional[str] = None
    DEG_MASK: Optional[str] = None
    DEG_MASK_R2: Optional[str] = None
    # Index used as ``padding_idx`` of the perturbation embedding and masked
    # out when perturbation embeddings are summed (see PerturbationNetwork).
    PADDING_IDX: int = 0


# Module-level instance through which the rest of this file reads the keys.
CPA_REGISTRY_KEYS = _REGISTRY_KEYS()
32 |
33 |
class VanillaEncoder(nn.Module):
    """Fully connected encoder: an ``FCLayers`` stack followed by a linear head.

    Parameters
    ----------
    n_input
        Input dimensionality passed to ``FCLayers`` as ``n_in``.
    n_output
        Output dimensionality of the final linear layer.
    n_hidden
        Width of the hidden layers; also the output width of the stack.
    n_layers
        Number of layers in the ``FCLayers`` stack.
    n_cat_list
        Forwarded to ``FCLayers`` as ``n_cat_list``.
    use_layer_norm, use_batch_norm, dropout_rate, activation_fn
        Forwarded unchanged to ``FCLayers``.
    output_activation
        ``'linear'`` (no activation) or ``'relu'``; any other value makes
        :meth:`forward` raise ``ValueError``.
    """

    def __init__(
        self,
        n_input,
        n_output,
        n_hidden,
        n_layers,
        n_cat_list,
        use_layer_norm=True,
        use_batch_norm=False,
        output_activation: str = 'linear',
        dropout_rate: float = 0.1,
        activation_fn=nn.ReLU,
    ):
        super().__init__()
        self.n_output = n_output
        self.output_activation = output_activation

        fc_kwargs = dict(
            n_in=n_input,
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            use_layer_norm=use_layer_norm,
            use_batch_norm=use_batch_norm,
            dropout_rate=dropout_rate,
            activation_fn=activation_fn,
        )
        self.network = FCLayers(**fc_kwargs)
        self.z = nn.Linear(n_hidden, n_output)

    def forward(self, inputs, *cat_list):
        """Encode ``inputs`` (plus optional categorical covariates)."""
        activation = self.output_activation
        # Validate before running the network, so an unsupported activation
        # fails without doing any computation.
        if activation not in ('linear', 'relu'):
            raise ValueError(f'Unknown output activation: {self.output_activation}')

        hidden = self.network(inputs, *cat_list)
        projected = self.z(hidden)
        if activation == 'relu':
            return F.relu(projected)
        return projected
73 |
74 |
class GeneralizedSigmoid(nn.Module):
    """
    Sigmoid, log-sigmoid or identity dose-response encoding for drug
    perturbations, with one learnable ``(beta, bias)`` pair per drug.
    """

    def __init__(self, n_drugs, non_linearity='sigmoid'):
        """Sigmoid modeling of continuous variable.

        Params
        ------
        n_drugs : int
            Number of drugs; ``beta`` and ``bias`` have shape (1, n_drugs).
        non_linearity : str (default: 'sigmoid')
            ``'logsigm'`` or ``'sigm'`` select the corresponding response
            curve. Any other value, including the default ``'sigmoid'``,
            leaves the dose unchanged.
        """
        super().__init__()
        self.non_linearity = non_linearity
        self.n_drugs = n_drugs

        # Slopes start at one and offsets at zero.
        self.beta = torch.nn.Parameter(torch.ones(1, n_drugs), requires_grad=True)
        self.bias = torch.nn.Parameter(torch.zeros(1, n_drugs), requires_grad=True)

        self.vmap = None

    def _dose_response(self, dose, slope, offset):
        """Apply the configured curve; callers handle the identity case."""
        if self.non_linearity == 'logsigm':
            dose = torch.log1p(dose)
        # Subtracting sigmoid(offset) makes a zero dose map to zero.
        baseline = offset.sigmoid()
        return (dose * slope + offset).sigmoid() - baseline

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: (batch_size, max_comb_len) dosages
        y: (batch_size, max_comb_len) drug indices, cast to long
        """
        y = y.long()
        if self.non_linearity not in ('logsigm', 'sigm'):
            return x
        return self._dose_response(x, self.beta[0][y], self.bias[0][y])

    def one_drug(self, x, i):
        """Dose-response of dosage ``x`` for the single drug at index ``i``."""
        if self.non_linearity not in ('logsigm', 'sigm'):
            return x
        return self._dose_response(x, self.beta[0][i], self.bias[0][i])
133 |
134 |
class PerturbationNetwork(nn.Module):
    """Embed (combinations of) perturbations, scaled by their dosage.

    Each perturbation index is looked up in an embedding table, multiplied by
    a dose-response value produced by the configured doser, and the
    embeddings of all non-padding perturbations of a sample are summed.

    Parameters
    ----------
    n_perts
        Number of perturbation indices (size of the embedding table and
        number of dosers).
    n_latent
        Dimensionality of the returned perturbation representation.
    doser_type
        ``'mlp'`` builds one ``FCLayers`` network per perturbation; any other
        value is passed to ``GeneralizedSigmoid`` as ``non_linearity``.
    n_hidden, n_layers, dropout_rate
        Configuration of the per-perturbation MLPs (``'mlp'`` only).
    drug_embeddings
        Optional embedding module exposing ``embedding_dim``; when given it
        replaces the learned table and its output is linearly projected to
        ``n_latent``.
    """

    def __init__(self,
                 n_perts,
                 n_latent,
                 doser_type='logsigm',
                 n_hidden=None,
                 n_layers=None,
                 dropout_rate: float = 0.0,
                 drug_embeddings=None,):
        super().__init__()
        self.n_latent = n_latent

        if drug_embeddings is not None:
            self.pert_embedding = drug_embeddings
            self.pert_transformation = nn.Linear(drug_embeddings.embedding_dim, n_latent)
            self.use_rdkit = True
        else:
            self.use_rdkit = False
            self.pert_embedding = nn.Embedding(n_perts, n_latent, padding_idx=CPA_REGISTRY_KEYS.PADDING_IDX)

        self.doser_type = doser_type
        if self.doser_type == 'mlp':
            self.dosers = nn.ModuleList([
                FCLayers(
                    n_in=1,
                    n_out=1,
                    n_hidden=n_hidden,
                    n_layers=n_layers,
                    use_batch_norm=False,
                    use_layer_norm=True,
                    dropout_rate=dropout_rate
                )
                for _ in range(n_perts)
            ])
        else:
            self.dosers = GeneralizedSigmoid(n_perts, non_linearity=self.doser_type)

    def _scale_dosages(self, perts, dosages):
        """Return dose-response values with the same shape as ``dosages``."""
        if self.doser_type != 'mlp':
            return self.dosers(dosages, perts)

        # ``self.dosers`` is an nn.ModuleList here, which cannot be called as
        # a whole: route every entry through the MLP of its own perturbation.
        # NOTE(review): assumes ``dosages`` is a floating point tensor.
        scaled = dosages.new_zeros(dosages.shape)
        for pert_idx in torch.unique(perts).tolist():
            selected = perts == pert_idx
            response = self.dosers[pert_idx](dosages[selected].unsqueeze(-1))
            scaled[selected] = response.squeeze(-1).to(scaled.dtype)
        return scaled

    def forward(self, perts, dosages):
        """
        perts: (batch_size, max_comb_len)
        dosages: (batch_size, max_comb_len)

        Returns a tensor of shape (batch_size, n_latent).
        """
        bs, max_comb_len = perts.shape
        perts = perts.long()
        scaled_dosages = self._scale_dosages(perts, dosages)  # (batch_size, max_comb_len)

        drug_embeddings = self.pert_embedding(perts)  # (batch_size, max_comb_len, n_drug_emb_dim)

        if self.use_rdkit:
            drug_embeddings = self.pert_transformation(drug_embeddings.view(bs * max_comb_len, -1)).view(bs, max_comb_len, -1)

        # Scale every embedding by its dose response.
        z_drugs = torch.einsum('bm,bme->bme', scaled_dosages, drug_embeddings)  # (batch_size, max_comb_len, n_latent)

        # Zero out padding slots, then sum over the combination axis.
        not_padding = (perts != CPA_REGISTRY_KEYS.PADDING_IDX).int()
        z_drugs = torch.einsum('bmn,bm->bmn', z_drugs, not_padding).sum(dim=1)

        return z_drugs  # (batch_size, n_latent)
192 |
class FocalLoss(nn.Module):
    """ Inspired by https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py

    Focal Loss, as described in https://arxiv.org/abs/1708.02002.
    It down-weights well-classified samples of a (optionally class-weighted)
    cross entropy, which helps with strongly imbalanced classes.
    y_pred is expected to contain raw, unnormalized scores for each class.
    y_true is expected to contain class labels.
    Shape:
        - y_pred: (batch_size, C); softmax is taken over the last dimension.
        - y_true: (batch_size,)
    """

    def __init__(self,
                 alpha: Optional[torch.Tensor] = None,
                 gamma: float = 2.,
                 reduction: str = 'mean',
                 ):
        """
        Args:
            alpha (Tensor, optional): Weights for each class. Defaults to None.
            gamma (float, optional): Focusing exponent, as described in the
                paper. Defaults to 2.
            reduction (str, optional): 'mean', 'sum' or 'none'.
                Defaults to 'mean'.

        Raises:
            ValueError: if ``reduction`` is not one of the three options.
        """
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                'Reduction must be one of: "mean", "sum", "none".')

        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

        # Per-sample weighted NLL; the reduction is applied in forward().
        self.nll_loss = nn.NLLLoss(weight=alpha, reduction='none')

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Return the focal loss of logits ``y_pred`` against labels ``y_true``."""
        # An empty batch contributes a constant zero.
        if len(y_true) == 0:
            return torch.tensor(0.)

        log_probs = F.log_softmax(y_pred, dim=-1)

        # weighted cross entropy term: -alpha * log(pt)
        # (alpha is applied inside self.nll_loss)
        cross_entropy = self.nll_loss(log_probs, y_true)

        # log-probability of the true class of every row
        row_ids = torch.arange(len(y_pred))
        true_class_prob = log_probs[row_ids, y_true].exp()

        # full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = (1 - true_class_prob) ** self.gamma * cross_entropy

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
--------------------------------------------------------------------------------
/Perturbation Analysis/cinemaot_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import scanpy as sc\n",
11 | "import pickle\n",
12 | "from scib_metrics.benchmark import Benchmarker\n",
13 | "import scib"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "import CINEMAOT as cnm"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "adata = sc.read_h5ad('/gpfs/gibbs/pi/zhao/tl688/cpsc_finalproject/genept_data/GenePT/cinemaot_data/Integrated_subset.h5ad')\n",
32 | "adata_raw = sc.AnnData(adata.raw.X, obs = adata.obs, var = adata.raw.var)"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "pert_cond = 'IFNg' # modify it for different perturbation cases."
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "# adata_raw = adata_raw[:, adata.var_names]\n",
51 | "sc.pp.highly_variable_genes(adata_raw, n_top_genes=500) # users can modify the number of genes here\n",
52 | "adata_raw = adata_raw[:, adata_raw.var.highly_variable]\n",
53 | "\n",
54 | "adata_ = adata_raw[adata_raw.obs['perturbation'].isin(['No stimulation', pert_cond])]\n",
55 | "\n",
56 | "with open(\"/gpfs/gibbs/pi/zhao/tl688/GenePT/data_embedding/GPT_3_5_gene_embeddings.pickle\", \"rb\") as fp:\n",
57 | " GPT_3_5_gene_embeddings = pickle.load(fp)\n",
58 | "gene_names= list(adata_.var.index)\n",
59 | "count_missing = 0\n",
60 | "EMBED_DIM = 1536 # embedding dim from GPT-3.5\n",
61 | "lookup_embed = np.zeros(shape=(len(gene_names),EMBED_DIM))\n",
62 | "for i, gene in enumerate(gene_names):\n",
63 | " if gene in GPT_3_5_gene_embeddings:\n",
64 | " lookup_embed[i,:] = GPT_3_5_gene_embeddings[gene].flatten()\n",
65 | " else:\n",
66 | " count_missing+=1\n",
67 | "# lookup_embed = np.random.rand(lookup_embed.shape[0], lookup_embed.shape[1])\n",
68 | "# genePT_w_emebed = np.dot(adata_.X,lookup_embed)/len(gene_names)\n",
69 | "genePT_w_emebed = adata_.X @ lookup_embed/len(gene_names)\n",
70 | "print(f\"Unable to match {count_missing} out of {len(gene_names)} genes in the GenePT-w embedding\")\n"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "adata_.obsm['X_pca'] = genePT_w_emebed # replace the PCs using gpt 3.5 embeddings\n"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "cf, ot, de = cnm.cinemaot.cinemaot_unweighted(adata_,obs_label='perturbation', ref_label=pert_cond, expr_label='No stimulation',mode='parametric',thres=0.5,smoothness=1e-5,eps=1e-3,preweight_label='cell_type0528')\n",
89 | "\n",
90 | "adata_.obsm['cf'] = cf.copy()\n",
91 | "adata_.obsm['cf'][adata_.obs['perturbation']==pert_cond,:] = np.matmul(ot/np.sum(ot,axis=1)[:,None],cf[adata_.obs['perturbation']=='No stimulation',:])\n",
92 | "sc.pp.neighbors(adata_,use_rep='cf')\n",
93 | "\n",
94 | "sc.tl.umap(adata_,random_state=1)\n",
95 | "sc.pl.umap(adata_,color=['perturbation','cell_type0528'],wspace=0.5, save = f'cinemaot_pbmc_cf_{pert_cond}_genept.pdf', palette='tab20c')"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "results = scib.metrics.metrics(\n",
105 | " adata_,\n",
106 | " adata_int=adata_,\n",
107 | " batch_key=\"perturbation\",\n",
108 | " label_key=\"cell_type0528\",\n",
109 | " embed=\"cf\",\n",
110 | " isolated_labels_asw_=True,\n",
111 | " silhouette_=True,\n",
112 | " hvg_score_=False,\n",
113 | " graph_conn_=True,\n",
114 | " pcr_=True,\n",
115 | " isolated_labels_f1_=False,\n",
116 | " trajectory_=False,\n",
117 | " nmi_=True, # use the clustering, bias to the best matching\n",
118 | " ari_=True, # use the clustering, bias to the best matching\n",
119 | " cell_cycle_=False,\n",
120 | " kBET_=False, # kBET return nan sometimes, need to examine\n",
121 | " ilisi_=True,\n",
122 | " clisi_=True,\n",
123 | ")"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "results"
133 | ]
134 | }
135 | ],
136 | "metadata": {
137 | "language_info": {
138 | "name": "python"
139 | }
140 | },
141 | "nbformat": 4,
142 | "nbformat_minor": 2
143 | }
144 |
--------------------------------------------------------------------------------
/Perturbation Analysis/cpa_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import sys\n",
10 | "\n",
11 | "from sklearn.metrics import r2_score\n",
12 | "import numpy as np\n",
13 | "\n",
14 | "import os\n",
15 | "# os.chdir('/home/mohsen/projects/cpa/')\n",
16 | "# os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
17 | "\n",
18 | "import cpa\n",
19 | "import scanpy as sc"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "sc.settings.set_figure_params(dpi=100)\n",
29 | "\n",
30 | "data_path = './combo_sciplex_prep_hvg_filtered.h5ad'"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "try:\n",
40 | " adata = sc.read(data_path)\n",
41 | "except:\n",
42 | " import gdown\n",
43 | " gdown.download('https://drive.google.com/uc?export=download&id=1RRV0_qYKGTvD3oCklKfoZQFYqKJy4l6t')\n",
44 | " data_path = 'combo_sciplex_prep_hvg_filtered.h5ad'\n",
45 | " adata = sc.read(data_path)\n",
46 | "\n",
47 | "adata"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "adata.obs['split_1ct_MEC'].value_counts()\n",
57 | "\n",
58 | "adata.X = adata.layers['counts'].copy()\n",
59 | "\n",
60 | "cpa.CPA.setup_anndata(adata,\n",
61 | " perturbation_key='condition_ID',\n",
62 | " dosage_key='log_dose',\n",
63 | " control_group='CHEMBL504',\n",
64 | " batch_key=None,\n",
65 | " is_count_data=True,\n",
66 | " categorical_covariate_keys=['cell_type'],\n",
67 | " deg_uns_key='rank_genes_groups_cov',\n",
68 | " deg_uns_cat_key='cov_drug_dose',\n",
69 | " max_comb_len=2,\n",
70 | " )"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": null,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "ae_hparams = {\n",
80 | " \"n_latent\": 1536,\n",
81 | " \"recon_loss\": \"nb\",\n",
82 | " \"doser_type\": \"logsigm\",\n",
83 | " \"n_hidden_encoder\": 512,\n",
84 | " \"n_layers_encoder\": 3,\n",
85 | " \"n_hidden_decoder\": 512,\n",
86 | " \"n_layers_decoder\": 3,\n",
87 | " \"use_batch_norm_encoder\": True,\n",
88 | " \"use_layer_norm_encoder\": False,\n",
89 | " \"use_batch_norm_decoder\": True,\n",
90 | " \"use_layer_norm_decoder\": False,\n",
91 | " \"dropout_rate_encoder\": 0.1,\n",
92 | " \"dropout_rate_decoder\": 0.1,\n",
93 | " \"variational\": False,\n",
94 | " \"seed\": 434,\n",
95 | "}\n",
96 | "\n",
97 | "trainer_params = {\n",
98 | " \"n_epochs_kl_warmup\": None,\n",
99 | " \"n_epochs_pretrain_ae\": 30,\n",
100 | " \"n_epochs_adv_warmup\": 50,\n",
101 | " \"n_epochs_mixup_warmup\": 3,\n",
102 | " \"mixup_alpha\": 0.1,\n",
103 | " \"adv_steps\": 2,\n",
104 | " \"n_hidden_adv\": 64,\n",
105 | " \"n_layers_adv\": 2,\n",
106 | " \"use_batch_norm_adv\": True,\n",
107 | " \"use_layer_norm_adv\": False,\n",
108 | " \"dropout_rate_adv\": 0.3,\n",
109 | " \"reg_adv\": 20.0,\n",
110 | " \"pen_adv\": 20.0,\n",
111 | " \"lr\": 0.0003,\n",
112 | " \"wd\": 4e-07,\n",
113 | " \"adv_lr\": 0.0003,\n",
114 | " \"adv_wd\": 4e-07,\n",
115 | " \"adv_loss\": \"cce\",\n",
116 | " \"doser_lr\": 0.0003,\n",
117 | " \"doser_wd\": 4e-07,\n",
118 | " \"do_clip_grad\": False,\n",
119 | " \"gradient_clip_value\": 1.0,\n",
120 | " \"step_size_lr\": 45,\n",
121 | "}"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "adata.var_names = adata.var['symbol-0'].values # important, change the ensemble id to gene name.\n",
131 | "\n",
132 | "import pickle\n",
133 | "with open(\"/gpfs/gibbs/pi/zhao/tl688/cpsc_finalproject/genept_data/GenePT/ensem_emb_gpt3.5all.pickle\", \"rb\") as fp:\n",
134 | " GPT_3_5_gene_embeddings = pickle.load(fp)\n",
135 | "gene_names= list(adata.var.index)\n",
136 | "count_missing = 0\n",
137 | "EMBED_DIM = 1536 # embedding dim from GPT-3.5\n",
138 | "lookup_embed = np.zeros(shape=(len(gene_names),EMBED_DIM))\n",
139 | "for i, gene in enumerate(gene_names):\n",
140 | " if gene in GPT_3_5_gene_embeddings:\n",
141 | " lookup_embed[i,:] = GPT_3_5_gene_embeddings[gene].flatten()\n",
142 | " else:\n",
143 | " count_missing+=1\n",
144 | "# lookup_embed = np.random.rand(lookup_embed.shape[0], lookup_embed.shape[1])\n",
145 | "# genePT_w_emebed = np.dot(adata.X,lookup_embed)/len(gene_names)\n",
146 | "genePT_w_emebed = adata.X @ lookup_embed/len(gene_names)\n",
147 | "# genePT_w_emebed = adata.X @ lookup_embed\n",
148 | "print(f\"Unable to match {count_missing} out of {len(gene_names)} genes in the GenePT-w embedding\")"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "metadata": {},
155 | "outputs": [],
156 | "source": [
157 | "model = cpa.CPA(adata=adata,\n",
158 | " split_key='split_1ct_MEC',\n",
159 | " train_split='train',\n",
160 | " valid_split='valid',\n",
161 | " test_split='ood',\n",
162 | " gene_embeddings = lookup_embed, # add the embeddings\n",
163 | " use_gene_emb = True,\n",
164 | " **ae_hparams,\n",
165 | " )"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "metadata": {},
172 | "outputs": [],
173 | "source": [
174 | "model.train(max_epochs=20,\n",
175 | " use_gpu=True,\n",
176 | " batch_size=128,\n",
177 | " plan_kwargs=trainer_params,\n",
178 | " early_stopping_patience=10,\n",
179 | " check_val_every_n_epoch=5,\n",
180 | " save_path='./cpa_out_gpt35/',\n",
181 | " )"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "metadata": {},
188 | "outputs": [],
189 | "source": [
190 | "# to load model\n",
191 | "# model = cpa.CPA.load(dir_path='/gpfs/gibbs/pi/zhao/tl688/cpsc_finalproject/genept_data/GenePT/cpa_out_gemb_new/',\n",
192 | "# adata=adata, use_gpu=True)"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "metadata": {},
199 | "outputs": [],
200 | "source": [
201 | "cpa.pl.plot_history(model)"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": null,
207 | "metadata": {},
208 | "outputs": [],
209 | "source": [
210 | "## Latent space UMAP visualization\n",
211 | "latent_outputs = model.get_latent_representation(adata, batch_size=1024)\n",
212 | "\n",
213 | "sc.settings.verbosity = 3\n",
214 | "\n",
215 | "latent_basal_adata = latent_outputs['latent_basal']\n",
216 | "latent_adata = latent_outputs['latent_after']\n",
217 | "\n",
218 | "sc.pp.neighbors(latent_basal_adata)\n",
219 | "sc.tl.umap(latent_basal_adata)\n",
220 | "\n",
221 | "\n",
222 | "sc.pl.umap(latent_basal_adata, color=['condition_ID'], frameon=False, wspace=0.2)"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "metadata": {},
229 | "outputs": [],
230 | "source": [
231 | "\n",
232 | "sc.pp.neighbors(latent_adata)\n",
233 | "sc.tl.umap(latent_adata)\n",
234 | "\n",
235 | "sc.pl.umap(latent_adata, color=['condition_ID'], frameon=False, wspace=0.2)"
236 | ]
237 | },
238 | {
239 | "cell_type": "code",
240 | "execution_count": null,
241 | "metadata": {},
242 | "outputs": [],
243 | "source": [
244 | "#save data\n",
245 | "latent_basal_adata.write_h5ad(\"cpa_geneptnew_example_basal.h5ad\")\n",
246 | "\n",
247 | "latent_adata.write_h5ad(\"cpa_geneptnew_example_perturb.h5ad\")"
248 | ]
249 | },
250 | {
251 | "cell_type": "markdown",
252 | "metadata": {},
253 | "source": [
254 | "# Make prediction"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "metadata": {},
261 | "outputs": [],
262 | "source": [
263 | "model.predict(adata, batch_size=1024)\n",
264 | "adata.var_names = adata.var['ensembl_id-0'].values"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": null,
270 | "metadata": {},
271 | "outputs": [],
272 | "source": [
273 | "import numpy as np\n",
274 | "import pandas as pd\n",
275 | "from sklearn.metrics import r2_score\n",
276 | "from collections import defaultdict\n",
277 | "from tqdm import tqdm\n",
278 | "\n",
279 | "n_top_degs = [10, 20, 50, None] # None means all genes\n",
280 | "\n",
281 | "results = defaultdict(list)\n",
282 | "ctrl_adata = adata[adata.obs['condition_ID'] == 'CHEMBL504'].copy()\n",
283 | "for cat in tqdm(adata.obs['cov_drug_dose'].unique()):\n",
284 | " if 'CHEMBL504' not in cat:\n",
285 | " cat_adata = adata[adata.obs['cov_drug_dose'] == cat].copy()\n",
286 | "\n",
287 | " deg_cat = f'{cat}'\n",
288 | " deg_list = adata.uns['rank_genes_groups_cov'][deg_cat]\n",
289 | "\n",
290 | " x_true = cat_adata.layers['counts'].toarray()\n",
291 | " x_pred = cat_adata.obsm['CPA_pred']\n",
292 | " x_ctrl = ctrl_adata.layers['counts'].toarray()\n",
293 | "\n",
294 | " x_true = np.log1p(x_true)\n",
295 | " x_pred = np.log1p(x_pred)\n",
296 | " x_ctrl = np.log1p(x_ctrl)\n",
297 | "\n",
298 | " for n_top_deg in n_top_degs:\n",
299 | " if n_top_deg is not None:\n",
300 | " degs = np.where(np.isin(adata.var_names, deg_list[:n_top_deg]))[0]\n",
301 | " else:\n",
302 | " degs = np.arange(adata.n_vars)\n",
303 | " n_top_deg = 'all'\n",
304 | "\n",
305 | " x_true_deg = x_true[:, degs]\n",
306 | " x_pred_deg = x_pred[:, degs]\n",
307 | " x_ctrl_deg = x_ctrl[:, degs]\n",
308 | "\n",
309 | " r2_mean_deg = r2_score(x_true_deg.mean(0), x_pred_deg.mean(0))\n",
310 | " r2_var_deg = r2_score(x_true_deg.var(0), x_pred_deg.var(0))\n",
311 | "\n",
312 | " r2_mean_lfc_deg = r2_score(x_true_deg.mean(0) - x_ctrl_deg.mean(0), x_pred_deg.mean(0) - x_ctrl_deg.mean(0))\n",
313 | " r2_var_lfc_deg = r2_score(x_true_deg.var(0) - x_ctrl_deg.var(0), x_pred_deg.var(0) - x_ctrl_deg.var(0))\n",
314 | "\n",
315 | " cov, cond, dose = cat.split('_')\n",
316 | "\n",
317 | " results['cell_type'].append(cov)\n",
318 | " results['condition'].append(cond)\n",
319 | " results['dose'].append(dose)\n",
320 | " results['n_top_deg'].append(n_top_deg)\n",
321 | " results['r2_mean_deg'].append(r2_mean_deg)\n",
322 | " results['r2_var_deg'].append(r2_var_deg)\n",
323 | " results['r2_mean_lfc_deg'].append(r2_mean_lfc_deg)\n",
324 | " results['r2_var_lfc_deg'].append(r2_var_lfc_deg)\n",
325 | "\n",
326 | "df = pd.DataFrame(results)"
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": null,
332 | "metadata": {},
333 | "outputs": [],
334 | "source": [
335 | "df[df['n_top_deg'] == 20]"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": null,
341 | "metadata": {},
342 | "outputs": [],
343 | "source": [
344 | "for cat in adata.obs[\"cov_drug_dose\"].unique():\n",
345 | " if \"CHEMBL504\" not in cat:\n",
346 | " cat_adata = adata[adata.obs[\"cov_drug_dose\"] == cat].copy()\n",
347 | "\n",
348 | " cat_adata.X = np.log1p(cat_adata.layers[\"counts\"].A)\n",
349 | " cat_adata.obsm[\"CPA_pred\"] = np.log1p(cat_adata.obsm[\"CPA_pred\"])\n",
350 | "\n",
351 | " deg_list = adata.uns[\"rank_genes_groups_cov\"][f'{cat}'][:20]\n",
352 | "\n",
353 | " print(cat, f\"{cat_adata.shape}\")\n",
354 | " cpa.pl.mean_plot(\n",
355 | " cat_adata,\n",
356 | " pred_obsm_key=\"CPA_pred\",\n",
357 | " path_to_save=None,\n",
358 | " deg_list=deg_list,\n",
359 | " # gene_list=deg_list[:5],\n",
360 | " show=True,\n",
361 | " verbose=True,\n",
362 | " )"
363 | ]
364 | },
365 | {
366 | "cell_type": "markdown",
367 | "metadata": {},
368 | "source": [
369 | "# Display drug information"
370 | ]
371 | },
372 | {
373 | "cell_type": "code",
374 | "execution_count": null,
375 | "metadata": {},
376 | "outputs": [],
377 | "source": [
378 | "cpa_api = cpa.ComPertAPI(adata, model,\n",
379 | " de_genes_uns_key='rank_genes_groups_cov',\n",
380 | " pert_category_key='cov_drug_dose',\n",
381 | " control_group='CHEMBL504',\n",
382 | " )"
383 | ]
384 | },
385 | {
386 | "cell_type": "code",
387 | "execution_count": null,
388 | "metadata": {},
389 | "outputs": [],
390 | "source": [
391 | "cpa_plots = cpa.pl.CompertVisuals(cpa_api, fileprefix=None)"
392 | ]
393 | },
394 | {
395 | "cell_type": "code",
396 | "execution_count": null,
397 | "metadata": {},
398 | "outputs": [],
399 | "source": [
400 | "drug_adata = cpa_api.get_pert_embeddings()\n",
401 | "drug_adata.shape"
402 | ]
403 | },
404 | {
405 | "cell_type": "code",
406 | "execution_count": null,
407 | "metadata": {},
408 | "outputs": [],
409 | "source": [
410 | "cpa_plots.plot_latent_embeddings(drug_adata.X, kind='perturbations', titlename='Drugs')"
411 | ]
412 | }
413 | ],
414 | "metadata": {
415 | "language_info": {
416 | "name": "python"
417 | }
418 | },
419 | "nbformat": 4,
420 | "nbformat_minor": 2
421 | }
422 |
--------------------------------------------------------------------------------
/Perturbation Analysis/gears/__init__.py:
--------------------------------------------------------------------------------
1 | from .gears import GEARS
2 | from .pertdata import PertData
--------------------------------------------------------------------------------
/Perturbation Analysis/gears/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/gears/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/gears/__pycache__/data_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/gears/__pycache__/data_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/gears/__pycache__/gears.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/gears/__pycache__/gears.cpython-38.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/gears/__pycache__/inference.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/gears/__pycache__/inference.cpython-38.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/gears/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/gears/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/gears/__pycache__/pertdata.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/gears/__pycache__/pertdata.cpython-38.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/gears/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/gears/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/gears/__pycache__/version.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/Perturbation Analysis/gears/__pycache__/version.cpython-38.pyc
--------------------------------------------------------------------------------
/Perturbation Analysis/gears/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import Sequential, Linear, ReLU
5 |
6 | from torch_geometric.nn import SGConv
7 |
class MLP(torch.nn.Module):
    """Feed-forward stack of ``Linear`` layers with optional batch norm.

    Each consecutive pair of widths in ``sizes`` yields a ``Linear`` layer,
    followed by ``BatchNorm1d`` (when ``batch_norm`` is truthy) and ``ReLU``.
    Only the very last ``ReLU`` is removed, so the network ends with the final
    ``Linear`` layer, or with its ``BatchNorm1d`` when batch norm is enabled.

    Args:
        sizes: Layer widths, input width first and output width last.
        batch_norm: Insert ``BatchNorm1d`` after every ``Linear`` layer,
            including the last one.
        last_layer_act: Stored on ``self.activation``; it is not applied
            anywhere in ``forward``.
    """

    def __init__(self, sizes, batch_norm=True, last_layer_act="linear"):
        super().__init__()
        modules = []
        for in_dim, out_dim in zip(sizes[:-1], sizes[1:]):
            modules.append(torch.nn.Linear(in_dim, out_dim))
            if batch_norm:
                modules.append(torch.nn.BatchNorm1d(out_dim))
            modules.append(torch.nn.ReLU())

        # Drop the trailing ReLU so the output of the stack is not rectified.
        del modules[-1:]

        self.activation = last_layer_act
        self.network = torch.nn.Sequential(*modules)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        return self.network(x)
27 |
28 |
class GEARS_Model(torch.nn.Module):
    """
    GEARS perturbation-response model.

    Combines a per-gene base embedding, a gene "positional" embedding refined
    by SGConv layers over the co-expression graph, and a perturbation
    embedding refined by SGConv layers over the GO graph. The fused embedding
    is decoded by a shared MLP, a gene-specific linear layer, a cross-gene MLP
    and a second gene-specific linear layer; the result is added to the input
    expression ``data.x``.

    When ``args['gene_emb']`` is not ``None`` the base gene embedding is a
    linear projection of that external matrix instead of a lookup table.
    """

    def __init__(self, args):
        """Build all sub-modules from the ``args`` configuration mapping.

        Keys read from ``args``: ``num_genes``, ``num_perts``, ``hidden_size``,
        ``uncertainty``, ``num_go_gnn_layers``, ``decoder_hidden_size``,
        ``num_gene_gnn_layers``, ``no_perturb``, ``cell_fitness_pred``,
        ``gene_emb``, ``device``, ``G_coexpress``, ``G_coexpress_weight``,
        ``G_go`` and ``G_go_weight``.
        """
        super(GEARS_Model, self).__init__()
        self.args = args
        self.num_genes = args['num_genes']
        self.num_perts = args['num_perts']
        hidden_size = args['hidden_size']
        self.uncertainty = args['uncertainty']
        self.num_layers = args['num_go_gnn_layers']
        # Stored only; not read anywhere else in this class.
        self.indv_out_hidden_size = args['decoder_hidden_size']
        self.num_layers_gene_pos = args['num_gene_gnn_layers']
        self.no_perturb = args['no_perturb']
        self.cell_fitness_pred = args['cell_fitness_pred']
        # Stored only; forward() uses a literal 0.2 when mixing in pos_emb.
        self.pert_emb_lambda = 0.2
        # External gene-embedding matrix (one row per gene), or None.
        self.gene_emb_input = args['gene_emb']

        # perturbation positional embedding added only to the perturbed genes
        # (this layer is not called in forward())
        self.pert_w = nn.Linear(1, hidden_size)

        # gene/global perturbation embedding dictionary lookup
        if self.gene_emb_input is None:
            self.gene_emb = nn.Embedding(self.num_genes, hidden_size, max_norm=True)
        else:
            # Project the external embedding (its column count) to hidden_size.
            self.gene_emb = nn.Linear(self.gene_emb_input.shape[1], hidden_size)
            # self.gene_emb = nn.Linear(self.gene_emb_input.shape[1], hidden_size)
            # self.gene_emb = MLP([self.gene_emb_input.shape[1], hidden_size, hidden_size], last_layer_act='ReLU')
            # Learnable per-gene embedding added to the projected external one.
            # NOTE(review): nesting inside this else-branch was reconstructed
            # from whitespace-collapsed source — confirm against the repository.
            self.gene_emb_part2 = nn.Embedding(self.num_genes, hidden_size, max_norm=True)
        self.pert_emb = nn.Embedding(self.num_perts, hidden_size, max_norm=True)

        # transformation layer
        self.emb_trans = nn.ReLU()
        self.pert_base_trans = nn.ReLU()
        self.transform = nn.ReLU()
        self.emb_trans_v2 = MLP([hidden_size, hidden_size, hidden_size], last_layer_act='ReLU')
        self.pert_fuse = MLP([hidden_size, hidden_size, hidden_size], last_layer_act='ReLU')

        # gene co-expression GNN
        # presumably edge index / edge weights of the graph — verify with caller
        self.G_coexpress = args['G_coexpress'].to(args['device'])
        self.G_coexpress_weight = args['G_coexpress_weight'].to(args['device'])

        self.emb_pos = nn.Embedding(self.num_genes, hidden_size, max_norm=True)
        self.layers_emb_pos = torch.nn.ModuleList()
        for i in range(1, self.num_layers_gene_pos + 1):
            self.layers_emb_pos.append(SGConv(hidden_size, hidden_size, 1))

        ### perturbation gene ontology GNN
        self.G_sim = args['G_go'].to(args['device'])
        self.G_sim_weight = args['G_go_weight'].to(args['device'])

        self.sim_layers = torch.nn.ModuleList()
        for i in range(1, self.num_layers + 1):
            self.sim_layers.append(SGConv(hidden_size, hidden_size, 1))

        # decoder shared MLP
        self.recovery_w = MLP([hidden_size, hidden_size*2, hidden_size], last_layer_act='linear')

        # gene specific decoder
        self.indv_w1 = nn.Parameter(torch.rand(self.num_genes,
                                               hidden_size, 1))
        self.indv_b1 = nn.Parameter(torch.rand(self.num_genes, 1))
        self.act = nn.ReLU()
        # The uniform values above are overwritten by Xavier-normal init.
        nn.init.xavier_normal_(self.indv_w1)
        nn.init.xavier_normal_(self.indv_b1)

        # Cross gene MLP
        self.cross_gene_state = MLP([self.num_genes, hidden_size,
                                     hidden_size])
        # final gene specific decoder
        # hidden_size+1: one per-gene scalar concatenated with the cross-gene state
        self.indv_w2 = nn.Parameter(torch.rand(1, self.num_genes,
                                               hidden_size+1))
        self.indv_b2 = nn.Parameter(torch.rand(1, self.num_genes))
        nn.init.xavier_normal_(self.indv_w2)
        nn.init.xavier_normal_(self.indv_b2)

        # batchnorms
        self.bn_emb = nn.BatchNorm1d(hidden_size)
        self.bn_pert_base = nn.BatchNorm1d(hidden_size)
        self.bn_pert_base_trans = nn.BatchNorm1d(hidden_size)

        # uncertainty mode
        if self.uncertainty:
            self.uncertainty_w = MLP([hidden_size, hidden_size*2, hidden_size, 1], last_layer_act='linear')

        # The cell-fitness head is always built (the guard below is disabled).
        #if self.cell_fitness_pred:
        self.cell_fitness_mlp = MLP([self.num_genes, hidden_size*2, hidden_size, 1], last_layer_act='linear')

    def forward(self, data):
        """Predict post-perturbation expression for a batch of graphs.

        Args:
            data: Batch object exposing ``x`` (input expression), ``pert_idx``
                (per-sample perturbation indices, with ``-1`` entries skipped)
                and ``batch`` (its unique values give the number of graphs).

        Returns:
            A tensor with one row of ``num_genes`` values per graph. When
            ``self.uncertainty`` is set, a tuple ``(prediction, out_logvar)``;
            otherwise, when ``self.cell_fitness_pred`` is set, a tuple
            ``(prediction, cell_fitness_mlp(prediction))``. With
            ``self.no_perturb`` the input ``x`` is returned regrouped into
            rows of ``num_genes``.
        """
        x, pert_idx = data.x, data.pert_idx
        if self.no_perturb:
            # Baseline: echo the input expression, one row per num_genes values.
            out = x.reshape(-1,1)
            out = torch.split(torch.flatten(out), self.num_genes)
            return torch.stack(out)
        else:
            num_graphs = len(data.batch.unique())

            ## get base gene embeddings
            if self.gene_emb_input is None:
                emb = self.gene_emb(torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.args['device']))
            else:
                # Project the external embedding once, then tile it per graph.
                emb = self.gene_emb(torch.FloatTensor(self.gene_emb_input).to(self.args['device'])).repeat(num_graphs,1)
                # NOTE(review): nesting inside this else-branch was reconstructed
                # from whitespace-collapsed source — confirm against the repository.
                emb = emb + self.gene_emb_part2(torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.args['device']))
                # print(emb.shape)
            emb = self.bn_emb(emb)
            base_emb = self.emb_trans(emb)

            # Positional gene embedding refined over the co-expression graph;
            # ReLU is applied between layers but not after the last one.
            pos_emb = self.emb_pos(torch.LongTensor(list(range(self.num_genes))).repeat(num_graphs, ).to(self.args['device']))
            for idx, layer in enumerate(self.layers_emb_pos):
                pos_emb = layer(pos_emb, self.G_coexpress, self.G_coexpress_weight)
                if idx < len(self.layers_emb_pos) - 1:
                    pos_emb = pos_emb.relu()

            base_emb = base_emb + 0.2 * pos_emb
            base_emb = self.emb_trans_v2(base_emb)

            ## get perturbation index and embeddings

            # Collect [sample index, perturbation index] pairs, skipping -1.
            pert_index = []
            for idx, i in enumerate(pert_idx):
                for j in i:
                    if j != -1:
                        pert_index.append([idx, j])
            # After the transpose: row 0 = sample indices, row 1 = perturbation indices.
            pert_index = torch.tensor(pert_index).T

            pert_global_emb = self.pert_emb(torch.LongTensor(list(range(self.num_perts))).to(self.args['device']))

            ## augment global perturbation embedding with GNN
            for idx, layer in enumerate(self.sim_layers):
                pert_global_emb = layer(pert_global_emb, self.G_sim, self.G_sim_weight)
                if idx < self.num_layers - 1:
                    pert_global_emb = pert_global_emb.relu()

            ## add global perturbation embedding to each gene in each cell in the batch
            base_emb = base_emb.reshape(num_graphs, self.num_genes, -1)

            if pert_index.shape[0] != 0:
                ### in case all samples in the batch are controls, then there is no indexing for pert_index.
                # Sum the embeddings of all perturbations applied to each sample.
                pert_track = {}
                for i, j in enumerate(pert_index[0]):
                    if j.item() in pert_track:
                        pert_track[j.item()] = pert_track[j.item()] + pert_global_emb[pert_index[1][i]]
                    else:
                        pert_track[j.item()] = pert_global_emb[pert_index[1][i]]

                if len(list(pert_track.values())) > 0:
                    if len(list(pert_track.values())) == 1:
                        # circumvent when batch size = 1 with single perturbation and cannot feed into MLP
                        # (the single row is duplicated; only emb_total[0] is used below)
                        emb_total = self.pert_fuse(torch.stack(list(pert_track.values()) * 2))
                    else:
                        emb_total = self.pert_fuse(torch.stack(list(pert_track.values())))

                    # Add each sample's fused perturbation embedding to all of its genes.
                    for idx, j in enumerate(pert_track.keys()):
                        base_emb[j] = base_emb[j] + emb_total[idx]

            base_emb = base_emb.reshape(num_graphs * self.num_genes, -1)
            base_emb = self.bn_pert_base(base_emb)

            ## apply the first MLP
            base_emb = self.transform(base_emb)
            out = self.recovery_w(base_emb)
            out = out.reshape(num_graphs, self.num_genes, -1)
            # Gene-specific linear layer: weight each hidden unit, sum, add bias.
            out = out.unsqueeze(-1) * self.indv_w1
            w = torch.sum(out, axis = 2)
            out = w + self.indv_b1

            # Cross gene
            # Summarise all per-gene scalars of a graph into one state vector,
            # then tile that vector so every gene sees it.
            cross_gene_embed = self.cross_gene_state(out.reshape(num_graphs, self.num_genes, -1).squeeze(2))
            cross_gene_embed = cross_gene_embed.repeat(1, self.num_genes)

            cross_gene_embed = cross_gene_embed.reshape([num_graphs,self.num_genes, -1])
            cross_gene_out = torch.cat([out, cross_gene_embed], 2)

            cross_gene_out = cross_gene_out * self.indv_w2
            cross_gene_out = torch.sum(cross_gene_out, axis=2)
            out = cross_gene_out + self.indv_b2
            # Residual connection: the decoded value is added to the input expression.
            out = out.reshape(num_graphs * self.num_genes, -1) + x.reshape(-1,1)
            out = torch.split(torch.flatten(out), self.num_genes)

            ## uncertainty head
            if self.uncertainty:
                out_logvar = self.uncertainty_w(base_emb)
                out_logvar = torch.split(torch.flatten(out_logvar), self.num_genes)
                return torch.stack(out), torch.stack(out_logvar)

            # Only reached when the uncertainty head is disabled.
            if self.cell_fitness_pred:
                return torch.stack(out), self.cell_fitness_mlp(torch.stack(out))

            return torch.stack(out)
220 |
221 |
--------------------------------------------------------------------------------
/Perturbation Analysis/gears/version.py:
--------------------------------------------------------------------------------
1 | """GEARS version file
2 | """
3 | # Based on NiLearn package
4 | # License: simplified BSD
5 |
6 | # PEP0440 compatible formatted version, see:
7 | # https://www.python.org/dev/peps/pep-0440/
8 | #
9 | # Generic release markers:
10 | # X.Y
11 | # X.Y.Z # For bug fix releases
12 | #
13 | # Admissible pre-release markers:
14 | # X.YaN # Alpha release
15 | # X.YbN # Beta release
16 | # X.YrcN # Release Candidate
17 | # X.Y # Final release
18 | #
19 | # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
20 | # 'X.Y.dev0' is the canonical version of 'X.Y.dev'
21 | #
22 | __version__ = '0.0.4' # pragma: no cover
23 |
--------------------------------------------------------------------------------
/Perturbation Analysis/gears_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from gears import PertData, GEARS\n",
10 | "\n",
11 | "# get data\n",
12 | "pert_data = PertData('./data')\n",
13 | "# pert_data = PertData('./data_folder')\n",
14 | "# load dataset in paper: norman, adamson, dixit.\n",
15 | "pert_data.load(data_name = 'dixit')\n",
16 | "# specify data split\n",
17 | "pert_data.prepare_split(split = 'simulation', seed = 1)\n",
18 | "# get dataloader with batch size\n",
19 | "pert_data.get_dataloader(batch_size = 32, test_batch_size = 128)\n",
20 | "\n",
21 | "from sklearn.model_selection import train_test_split\n",
22 | "from gears.inference import evaluate, compute_metrics, deeper_analysis, non_dropout_analysis\n",
23 | "from gears.utils import create_cell_graph_dataset_for_prediction"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "import gears \n",
33 | "import pytorch_lightning\n",
34 | "import seaborn as sns\n",
35 | "sns.set_style(\"whitegrid\")\n",
36 | "import pandas as pd\n",
37 | "import numpy as np \n",
38 | "import matplotlib.pyplot as plt\n",
39 | "%matplotlib inline\n",
40 | "import pickle\n",
41 | "import scanpy as sc\n",
42 | "# import sentence_transformers\n",
43 | "plt.style.use('ggplot')\n",
44 | "#plt.style.use('seaborn-v0_8-dark-palette')\n",
45 | "plt.rcParams['axes.facecolor'] = 'white'\n",
46 | "\n",
47 | "np.random.seed(202310)\n",
48 | "pytorch_lightning.seed_everything(202310)"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "with open(\"/gpfs/gibbs/pi/zhao/tl688/GenePT/data_embedding/GPT_3_5_gene_embeddings.pickle\", \"rb\") as fp:\n",
58 | " GPT_3_5_gene_embeddings = pickle.load(fp)\n",
59 | "with open(\"/gpfs/gibbs/pi/zhao/tl688/cpsc_finalproject/genept_data/GenePT/ensem_emb_gpt3.5all.pickle\", \"rb\") as fp:\n",
60 | " GPT_3_5_gene_embeddings = pickle.load(fp)\n",
61 | "gene_names= list(pert_data.adata.var['gene_name'].values)\n",
62 | "count_missing = 0\n",
63 | "EMBED_DIM = 1536 # embedding dim from GPT-3.5\n",
64 | "lookup_embed = np.zeros(shape=(len(gene_names),EMBED_DIM))\n",
65 | "for i, gene in enumerate(gene_names):\n",
66 | " if gene in GPT_3_5_gene_embeddings:\n",
67 | " lookup_embed[i,:] = GPT_3_5_gene_embeddings[gene].flatten()\n",
68 | " else:\n",
69 | " count_missing+=1\n"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "# set up and train a model\n",
79 | "gears_model = GEARS(pert_data, device = 'cuda:0', gene_emb = lookup_embed)\n",
80 | "# gears_model = GEARS(pert_data, device = 'cuda:0')\n",
81 | "gears_model.model_initialize(hidden_size = 64)\n",
82 | "\n",
83 | "gears_model.train(epochs = 20)\n",
84 | "\n",
85 | "# # save/load model\n",
86 | "# gears_model.save_model('gears_dixit_new')\n",
87 | "# gears_model.load_pretrained('gears_dixit_new')\n",
88 | "\n",
89 | "test_res = evaluate(gears_model.dataloader['test_loader'], gears_model.model, gears_model.config['uncertainty'], gears_model.device)\n",
90 | "test_metrics, test_pert_res = compute_metrics(test_res)"
91 | ]
92 | }
93 | ],
94 | "metadata": {
95 | "language_info": {
96 | "name": "python"
97 | }
98 | },
99 | "nbformat": 4,
100 | "nbformat_minor": 2
101 | }
102 |
--------------------------------------------------------------------------------
/elmo dalle2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/elmo dalle2.png
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | #
scELMo: Embeddings from Language Models are Good Learners for Single-cell Data Analysis
2 |
3 |
4 |
5 | # News!
6 |
7 | We have uploaded gene embeddings from GPT-4o and drug embeddings from GPT-3.5 to our website; please check them out if you would like to try them!
8 |
9 | # Installation
10 |
11 | We rely on OpenAI API for query.
12 |
13 | ```
14 | pip install openai
15 | ```
16 |
17 | The descriptions and tutorials for OpenAI API can be found in this [link](https://platform.openai.com/).
18 |
19 | We rely on these packages for zero-shot learning analysis.
20 |
21 | ```
22 | pip install scib scib_metrics==0.3.3 pickle mygene scanpy==1.9.3 scikit-learn
23 | ```
24 |
25 | Installing hnswlib from the original Github profile to avoid potential errors.
26 | ```
27 | apt-get install -y python-setuptools python-pip #may not need it for HPC base
28 | git clone https://github.com/nmslib/hnswlib.git
29 | cd hnswlib
30 | pip install .
31 | ```
32 | All the packages above are enough for testing tasks based on zero-shot learning.
33 |
34 | We rely on PyTorch for fine-tuning.
35 |
36 | ```
37 | conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
38 | conda install lightning -c conda-forge
39 | ```
40 |
41 | For the perturbation analysis, please install the related packages following the instructions on their websites and use the modified versions provided in the **Perturbation Analysis** folder: [CINEMAOT](https://github.com/vandijklab/CINEMA-OT/tree/main), [CPA](https://github.com/theislab/cpa) and [GEARS](https://github.com/snap-stanford/GEARS/tree/master).
42 |
43 | To generate gene embeddings from sequence models (as seq2emb), please refer to [seq2cells](https://github.com/GSK-AI/seq2cells) to install the related packages.
44 |
45 |
46 | For users who cannot access the OpenAI API, we provide an alternative solution based on [deepseekv2](https://www.deepseek.com/). Please refer to the **Get outputs from LLMs** folder for more information.
47 |
48 | # Tutorials
49 |
50 | Please use the example ipynb notebook in each folder as instructions. Evaluations are included in the notebooks. The demo tutorial can be finished on a normal computer within 10 minutes with a prepared environment.
51 |
52 | # Datasets
53 |
54 | All of the datasets and their download information are included in the Supplementary file 3. A demo dataset for clustering can be found in this [link](https://drive.google.com/file/d/1hHVutJ3tsAhkhTJ-wCNe9OfXubw2m2gN/view?usp=sharing).
55 |
56 | # Database for scELMo
57 |
58 | We are maintaining a [website](https://sites.google.com/yale.edu/scelmolib) containing embeddings of different information generated by LLM. We are happy to discuss if you have any requests or comments.
59 |
60 | # Acknowledgement
61 |
62 | We referred to the code of the following packages to implement scELMo. Many thanks to these great developers:
63 |
64 | [GenePT](https://github.com/yiqunchen/GenePT), [seq2cells](https://github.com/GSK-AI/seq2cells), [CINEMAOT](https://github.com/vandijklab/CINEMA-OT/tree/main), [CPA](https://github.com/theislab/cpa) and [GEARS](https://github.com/snap-stanford/GEARS/tree/master).
65 |
66 | # Open for contribution
67 |
68 | We are happy to see if you have more exciting ideas about the extension of scELMo. Feel free to contact us for discussion:
69 |
70 | Tianyu Liu (tianyu.liu@yale.edu)
71 |
72 | # Citation
73 | ```
74 | @article{liu2023scelmo,
75 | title={scELMo: Embeddings from Language Models are Good Learners for Single-cell Data Analysis},
76 | author={Liu, Tianyu and Chen, Tianqi and Zheng, Wangjie and Luo, Xiao and Zhao, Hongyu},
77 | journal={bioRxiv},
78 | pages={2023--12},
79 | year={2023},
80 | publisher={Cold Spring Harbor Laboratory}
81 | }
82 | ```
83 |
84 | # Related work
85 |
86 | - [spEMO](https://github.com/HelloWorldLTY/spEMO)
87 | - [scLAMBDA](https://github.com/gefeiwang/scLAMBDA)
--------------------------------------------------------------------------------
/reproductivity/repro_instruction.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Figure 1\n",
8 | "\n",
9 | "To reproduce figure 1, we utilize draw.io to visualize the model structure."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "# Figure 2"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | "To reproduce (a), please refer the folder **Get outputs from LLMs**. For local LLMs, we utilize huggingface for testing.\n",
24 | "\n",
25 | "To reproduce (b), please use the package time (import time) or magic instruction (%%time) for testing.\n",
26 | "\n",
27 | "To reproduce (c), please refer the folder **Clustering**."
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "# Figure 3"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": [
41 | "To reproduce (a)-(d), please refer the folder **Batch Effect Correction**. "
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "metadata": {},
47 | "source": [
48 | "# Figure 4"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "To reproduce (a)-(c), please refer to the folder **In silico treatment**. For (c), we record the output of the classifier for comparison."
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "# Figure 5"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "metadata": {},
68 | "source": [
69 | "To reproduce (a)-(d), please refer the folder **Perturbation Analysis**. For GEARS, the model performance might be affected by random seed, due to the implementation of pyg."
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {},
75 | "source": [
76 | "# Table 1"
77 | ]
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "metadata": {},
82 | "source": [
83 | "To reproduce Table 1, please refer the folder **Cell-type Annotation** for both zero-shot and fine-tuning setting."
84 | ]
85 | }
86 | ],
87 | "metadata": {
88 | "language_info": {
89 | "name": "python"
90 | }
91 | },
92 | "nbformat": 4,
93 | "nbformat_minor": 2
94 | }
95 |
--------------------------------------------------------------------------------
/seq2emb/README.md:
--------------------------------------------------------------------------------
1 | # Preprocessing sequence embeddings (These codes come from seq2cells)
2 |
3 | ### 0) Intro - Workflow
4 |
5 | Given some regions of interest (ROI), e.g. transcription start sites (TSS) the
6 | aim of the sequence pre-processing
7 | is to obtain:
8 |
9 | 1) A **query file** that specifies for each ROI: the DNA sequence
10 | window surrounding it and the location of the region of interest within this
11 | window.
12 | 2) Pre-computed DNA sequence **embeddings** for each ROI computed with the
13 | Enformer trunk
14 | 3) Gene (region) IDs that specify their intersection with Enformer training,
15 | test and validation sequences for splitting the dataset.
16 |
17 | ### 1) Query file
18 |
19 | Enformer embeds and predicts over 896 bins of 128 bp covering the central
20 | 114,688 bp of the sequence queries of length 196,608 bp.
21 | To extract embeddings of genomic ROIs, we construct sequence
22 | queries of length 196,608 bp and identify the corresponding Enformer output
23 | window
24 | within which the ROI lies so the correct embedding can be extracted.
25 |
26 | Using `create_seq_window_queries.py`
27 |
28 | This script will take regions of interest, stitch them into patches if desired
29 | and
30 | construct sequence windows adhering to chromosome boundaries and create queries
31 | for the sequence model.
32 | Genomic position and the index (bin_id) of the prediction bin with which the
33 | rois are intersecting are listed: 0-based!
34 | The subsequent script calculating DNA sequence embeddings can then use the
35 | bin_ids to extract the embeddings of interest.
36 |
37 | Stitching: if enabled will group the rois based on a supplied grouping
38 | variable (e.g. a gene name or id)
39 | ROIs with the same grouping id will be grouped into patches. Patches are
40 | split if they stretch over more than the supplied threshold (50kb default) and
41 | sequence windows are constructed over the center of patches.
42 | The position and bin id of the rois are listed in a comma separated string.
43 |
44 | Notes:
45 |
46 | * The stitching functionality is implemented but we do not use it for single
47 | cell expression predictions so far. To replicate the manuscript work run
48 | without stitching.
49 |
50 | * ROIs only accept a single position, if larger regions of interests should be
51 | supplied then please center the coordinate first.
52 |
53 | * By default, this script will center the sequence windows on ROIs or at the
54 | center of a stitched patch. Thus allowing predictions with a maximal
55 | sequence context reach for every roi.
56 |
57 | * If the number of prediction bins is even, such as with the default
58 | Enformer setting, then the center of the sequence window is covered
59 | by the border of two bins. In that case the sequence window is shifted by
60 | minus half a bin size to center the ROI within a single bin.
61 |
62 | #### Inputs:
63 |
64 | 1) A plain text file of ROI, where every line specifies a
65 | ROI supplied via the `--in` argument. Common formats are bed
66 | files or vcf file without header. Important: the genomic coordinates may be
67 | provided in bed-like (0-based, half open format) or as single column
68 | (1-based) vcf-like format.
69 | The coordinate handling is controlled by the
70 | `position_col` and `position_base` arguments (see `--help`)
71 |
72 | ```angular2html
73 | chr1 65418 65419 ENST00000641515.2 . + ENSG00000186092.7 OR4F5
74 | chr1 451677 451678 ENST00000426406.4 . - ENSG00000284733.2 OR4F29
75 | chr1 686653 686654 ENST00000332831.5 . - ENSG00000284662.2 OR4F16
76 | chr1 923922 923923 ENST00000616016.5 . + ENSG00000187634.13 SAMD11
77 | ```
78 |
79 | 2) Reference genome in fasta file with .fai index present in same directory.
80 |
81 | ```angular2html
82 | >chr1
83 | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN
84 | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN
85 | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN
86 | ```
87 |
88 | #### Usage:
89 |
90 | ```bash
91 | python create_seq_window_queries.py \
92 | --in ./preprocessing_example_files/gencode.v41.basic.annotation.protein.coding.ensembl_canonical.tss.hg38.h10.bed \
93 | --ref_genome ./hg38.fa \
94 | --out ./query_tss_example.tsv \
95 | --chromosome_col 1\
96 | --position_col 3\
97 | --position_base 1 \
98 | --strand_col 6 \
99 | --group_id_col 7 \
100 | --additional_id_col 8 \
101 | --no-stitch
102 | ```
103 |
104 | #### Output
105 |
106 | Output is a tab-separated query file that lists the chrom start end strand of
107 | the sequence window the ids of the stitched patch and the grouping and
108 | additional_id, the center of the sequence window the number of regions of
109 | interest within the distance between multiple rois in the sequence and
110 | the strands, position and bin id of the rois, comma separated if multiple
111 | ones are available.
112 |
113 | ```angular2html
114 | chr seq_start seq_end seq_strand patch_id group_id add_id center num_roi stretch strands_roi positions_roi bins_roi
115 | chr1 1 196608 + ENSG00000186092.7_0 ENSG00000186092.7 OR4F5 98369 1 0 ['+'] 65419 191
116 | chr1 353310 549917 - ENSG00000284733.2_0 ENSG00000284733.2 OR4F29 451678 1 0 ['-'] 451678 448
117 | chr1 588286 784893 - ENSG00000284662.2_0 ENSG00000284662.2 OR4F16 686654 1 0 ['-'] 686654 448
118 | chr1 825555 1022162 + ENSG00000187634.13_0 ENSG00000187634.13 SAMD11 923923 1 0 ['+'] 923923 448
119 | chr1 860888 1057495 - ENSG00000188976.11_0 ENSG00000188976.11 NOC2L 959256 1 0 ['-'] 959256 448
120 | chr1 862216 1058823 + ENSG00000187961.15_0 ENSG00000187961.15 KLHL17 960584 1 0 ['+'] 960584 448
121 | chr1 868114 1064721 + ENSG00000187583.11_0 ENSG00000187583.11 PLEKHN1 966482 1 0 ['+'] 966482 448
122 | chr1 883725 1080332 - ENSG00000187642.10_0 ENSG00000187642.10 PERM1 982093 1 0 ['-'] 982093 448
123 | chr1 901729 1098336 - ENSG00000188290.11_0 ENSG00000188290.11 HES4 1000097 1 0 ['-'] 1000097 448
124 | ```
125 |
126 | ## 2) Sequence embeddings
127 |
128 | The next step is to pre-compute the sequence embeddings over the ROIs now
129 | specified in the query file.
130 |
131 | Using `calc_embeddings_and_targets.py`
132 |
133 | This script will take a query file as produced by
134 | `create_seq_window_queries.py` and compute embeddings and optionally predicted
135 | Enformer targets over the ROI.
136 |
137 | Main idea here is that ROI are always centered on the
138 | sequence model query window as much as possible to allow a balanced, maximal
139 | sequence context for each prediction.
140 |
141 | Ideally only a single region of interest or regions very close together are
142 | supplied per query. Larger sets should be split in the prior pre-processing
143 | step. E.g. split multiple clusters of TSS more than ~ 50 kb apart into
144 | separate entities for summary later.
145 |
146 | Embeddings and targets from multiple ROIs or with adjacent bins specified are
147 | aggregated according to the specified methods. Default: Embeddings - mean,
148 | Targets - sum.
149 |
150 | #### Notes
151 |
152 | * If the ROI / patch is located on the minus strand the reverse
153 | complement of the plus strand will be used as sequence input.
154 | * If the reverse_complement is forced via `--rc_force` the reverse_complement
155 | is applied to plus strand patches and minus strand patches are processed
156 | from the plus strand. The position of ROIs are always
157 | mirrored where necessary to ensure the correct targets/embeddings are
158 | extracted.
159 | * If the reverse complement augmentation is toggled on via `--rc_aug` then
160 | the reverse complement is applied randomly in 50 % of instances.
161 | * `--rc_force` overwrites `--rc_aug`
162 | * Shift augmentations are chosen randomly from the selected range of bp shifts;
163 | select a single bp shift if you want to control the shift precisely.
164 | * Note: preprocessing with multiple ROIs per query is supported but all
165 | single cell work carried out by us was using a single ROI (TSS of
166 | canonical transcript).
167 |
168 | #### Input
169 |
170 | 1) Query file as produced by `create_seq_window_queries.py` which is
171 | a raw text file
172 | including a header column that specifies the sequence windows to be
173 | processed by the seq model and the positions of the regions of interest
174 | within that sequence to be extracted (roi). Positions and bins of
175 | multiple ROI per query are comma separated in one string. Example format:
176 |
177 | ```angular2html
178 | chr seq_start seq_end seq_strand patch_id group_id add_id center num_roi stretch strands_roi positions_roi bins_roi
179 | chr1 1 196608 + ENSG00000186092.7_0 ENSG00000186092.7 OR4F5 98369 1 0 ['+'] 65419 191
180 | chr1 353310 549917 - ENSG00000284733.2_0 ENSG00000284733.2 OR4F29 451678 1 0 ['-'] 451678 448
181 | chr1 588286 784893 - ENSG00000284662.2_0 ENSG00000284662.2 OR4F16 686654 1 0 ['-'] 686654 448
182 | chr1 825555 1022162 + ENSG00000187634.13_0 ENSG00000187634.13 SAMD11 923923 1 0 ['+'] 923923 448
183 | chr1 860888 1057495 - ENSG00000188976.11_0 ENSG00000188976.11 NOC2L 959256 1 0 ['-'] 959256 448
184 | chr1 862216 1058823 + ENSG00000187961.15_0 ENSG00000187961.15 KLHL17 960584 1 0 ['+'] 960584 448
185 | chr1 868114 1064721 + ENSG00000187583.11_0 ENSG00000187583.11 PLEKHN1 966482 1 0 ['+'] 966482 448
186 | chr1 883725 1080332 - ENSG00000187642.10_0 ENSG00000187642.10 PERM1 982093 1 0 ['-'] 982093 448
187 | chr1 901729 1098336 - ENSG00000188290.11_0 ENSG00000188290.11 HES4 1000097 1 0 ['-'] 1000097 448
188 | ```
189 |
190 | 2) Reference genome in fasta format. Needs to be indexed (same name file
191 | with .fa.fai ending present)
192 |
193 | ```angular2html
194 | >chr1
195 | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN
196 | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN
197 | NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN
198 | ```
199 |
200 | #### Usage
201 |
202 | ```bash
203 | python calc_embeddings_and_targets.py \
204 | --in_query ./preprocessing_example_files/query_tss_example.tsv \
205 | --ref_genome hg38.fa \
206 | --out_name enformer_out \
207 | --position_base 1 \
208 | --add_bins 0 \
209 | --store_text \
210 | --store_h5 \
211 | --targets '4675:5312' # for all Enformer cage-seq targets
212 | ```
213 |
214 | #### Output
215 |
216 | Outputs are one or two tab-separated
217 | text files storing the embeddings and optionally the targets, and/or an hdf5 file
218 | storing the
219 | embeddings and targets as pandas data frames under the 'emb' and 'tar' handles,
220 | respectively.
221 | The header columns in the embedding file indicate the embedding dimensions.
222 | The header columns in the target text file / data frame
223 | correspond to the selected target ids (0-based) of Enformer targets
224 | (see the
225 | [published Basenji2 targets](https://github.com/calico/basenji/tree/master/manuscripts/cross2020)
226 | ).
227 | Targets are subset to the selected targets; the indices of the selected targets
228 | are stored in the header of the target output file (0-based).
229 |
230 | Example raw text outputs:
231 |
232 | ```bash
233 | head -n 3 enformer_out*tsv | cut -f 1,2,3
234 | ==> enformer_out_emb.tsv <==
235 | 0 1 2
236 | -0.11201313883066177 -0.0001226698950631544 -0.10420460253953934
237 | -0.1380479633808136 -8.836987944960129e-06 -0.14271216094493866
238 |
239 | ==> enformer_out_tar.tsv <==
240 | 4675 4676 4677
241 | 0.021540187299251556 0.012503976002335548 0.012968547642230988
242 | 0.01947534829378128 0.007085299119353294 0.007071667350828648
243 | ```
244 |
245 | ## 3) Intersect regions of interest with Enformer train / test / valid regions
246 |
247 | For splitting genes into training, test and validation set we intersect the
248 | position of their TSS with the regions over which Enformer is trained to
249 | predict chromatin features and CAGE-seq coverage. See
250 | [Kelley 2020](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1008050#sec010)
251 | for a description of the train, test, valid split region construction. The genes
252 | whose TSS intersects with test and validation regions are extracted as test and
253 | validation set for the single cell work. Where a TSS intersects with multiple
254 | Enformer regions we select the one where the TSS is most central.
255 |
256 | ### Notes
257 | By default the Enformer input sequences are of length 196,608 bp.
258 | These regions were taken from the Basenji2 work with regions of length
259 | 131,072 bp and extended by 32,768 bp to each side.
260 | The 131,072 bp sequences were shared by the authors.
261 | By default we trim the shared sequences to the central
262 | 114,688 bp, because Enformer is only trained to predict over
263 | those 896 * 128 bp bins of each sequence window.
264 | The pruning can be disabled via the `--no_prune` flag. This will intersect
265 | the TSS with the 131,072 bp sequences.
266 | Alternatively, using `--extend` flag the sequence windows can be extended to
267 | the full 196,608 bp.
268 |
269 | #### Input
270 |
271 | 1) Query file as produced by `create_seq_window_queries.py` which is
272 | a raw text file
273 | including a header column that specifies the sequence windows to be
274 | processed by the seq model and the positions of the regions of interest
275 | within that sequence to be extracted (roi). Positions and bins of
276 | multiple ROI per query are comma separated in one string.
277 | The 'patch_id' column is used for unique TSS/ROI identification.
278 | Example format:
279 |
280 | ```angular2html
281 | chr seq_start seq_end seq_strand patch_id group_id add_id center num_roi stretch strands_roi positions_roi bins_roi
282 | chr1 1 196608 + ENSG00000186092.7_0 ENSG00000186092.7 OR4F5 98369 1 0 ['+'] 65419 191
283 | chr1 353310 549917 - ENSG00000284733.2_0 ENSG00000284733.2 OR4F29 451678 1 0 ['-'] 451678 448
284 | chr1 588286 784893 - ENSG00000284662.2_0 ENSG00000284662.2 OR4F16 686654 1 0 ['-'] 686654 448
285 | chr1 825555 1022162 + ENSG00000187634.13_0 ENSG00000187634.13 SAMD11 923923 1 0 ['+'] 923923 448
286 | chr1 860888 1057495 - ENSG00000188976.11_0 ENSG00000188976.11 NOC2L 959256 1 0 ['-'] 959256 448
287 | chr1 862216 1058823 + ENSG00000187961.15_0 ENSG00000187961.15 KLHL17 960584 1 0 ['+'] 960584 448
288 | chr1 868114 1064721 + ENSG00000187583.11_0 ENSG00000187583.11 PLEKHN1 966482 1 0 ['+'] 966482 448
289 | chr1 883725 1080332 - ENSG00000187642.10_0 ENSG00000187642.10 PERM1 982093 1 0 ['-'] 982093 448
290 | chr1 901729 1098336 - ENSG00000188290.11_0 ENSG00000188290.11 HES4 1000097 1 0 ['-'] 1000097 448
291 | ```
292 |
293 | 2) Enformer sequences with train, test, validation assignment. The regions
294 | were [shared](https://console.cloud.google.com/storage/browser/basenji_barnyard/data)
295 | by the Basenji2/Enformer authors. They are also stored with the files
296 | required for pre-processing here ... #TODO
297 |
298 | ```angular2html
299 | chr18 936578 1051266 train
300 | chr4 113639139 113753827 train
301 | chr11 18435912 18550600 train
302 | chr16 85813873 85928561 train
303 | chr3 158394380 158509068 train
304 | chr7 136791743 136906431 train
305 | chr8 132166506 132281194 valid
306 | chr21 35647195 35761883 valid
307 | chr16 24529786 24644474 test
308 | chr8 18655640 18770328 test
309 | ```
310 |
311 | Using `intersect_queries_with_enformer_regions.py`
312 |
313 | Run as
314 |
315 | ```bash
316 | python intersect_queries_with_enformer_regions.py \
317 | --query query_gencode_v41_protein_coding_canonical_tss_hg38_nostitch.tsv \
318 | --enf_seqs sequences.bed \
319 | --strip
320 | ```
321 |
322 | #### Output
323 |
324 | Three raw text files with the gene IDs belonging to train, test and
325 | validation set respectively. Those are used for
326 | tagging the genes in `add_embeddings_to_anndata.py`.
327 |
328 | ```bash
329 | head -n 3 query_enf_intersect_*.txt
330 | ==> query_enf_intersect_test.txt <==
331 | ENSG00000003096
332 | ENSG00000004776
333 | ENSG00000004777
334 |
335 | ==> query_enf_intersect_train.txt <==
336 | ENSG00000000457
337 | ENSG00000000460
338 | ENSG00000000938
339 |
340 | ==> query_enf_intersect_valid.txt <==
341 | ENSG00000000003
342 | ENSG00000000005
343 | ENSG00000000419
344 | ```
345 |
--------------------------------------------------------------------------------
/seq2emb/preprocessing_example_files/enformer_out.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HelloWorldLTY/scELMo/e1d51c22e4ef8c7c343d0c376480010a11124c08/seq2emb/preprocessing_example_files/enformer_out.h5
--------------------------------------------------------------------------------
/seq2emb/preprocessing_example_files/gencode.v41.basic.annotation.protein.coding.ensembl_canonical.tss.hg38.h10.bed:
--------------------------------------------------------------------------------
1 | chr1 65418 65419 ENST00000641515.2 . + ENSG00000186092.7 OR4F5
2 | chr1 451677 451678 ENST00000426406.4 . - ENSG00000284733.2 OR4F29
3 | chr1 686653 686654 ENST00000332831.5 . - ENSG00000284662.2 OR4F16
4 | chr1 923922 923923 ENST00000616016.5 . + ENSG00000187634.13 SAMD11
5 | chr1 959255 959256 ENST00000327044.7 . - ENSG00000188976.11 NOC2L
6 | chr1 960583 960584 ENST00000338591.8 . + ENSG00000187961.15 KLHL17
7 | chr1 966481 966482 ENST00000379410.8 . + ENSG00000187583.11 PLEKHN1
8 | chr1 982092 982093 ENST00000433179.4 . - ENSG00000187642.10 PERM1
9 | chr1 1000096 1000097 ENST00000304952.11 . - ENSG00000188290.11 HES4
10 | chr1 1013496 1013497 ENST00000649529.1 . + ENSG00000187608.10 ISG15
11 |
--------------------------------------------------------------------------------
/seq2emb/preprocessing_example_files/query_tss_example.tsv:
--------------------------------------------------------------------------------
1 | chr seq_start seq_end seq_strand patch_id group_id add_id center num_roi stretch strands_roi positions_roi bins_roi
2 | chr1 1 196608 + ENSG00000186092.7_0 ENSG00000186092.7 OR4F5 98369 1 0 ['+'] 65419 191
3 | chr1 353310 549917 - ENSG00000284733.2_0 ENSG00000284733.2 OR4F29 451678 1 0 ['-'] 451678 448
4 | chr1 588286 784893 - ENSG00000284662.2_0 ENSG00000284662.2 OR4F16 686654 1 0 ['-'] 686654 448
5 | chr1 825555 1022162 + ENSG00000187634.13_0 ENSG00000187634.13 SAMD11 923923 1 0 ['+'] 923923 448
6 | chr1 860888 1057495 - ENSG00000188976.11_0 ENSG00000188976.11 NOC2L 959256 1 0 ['-'] 959256 448
7 | chr1 862216 1058823 + ENSG00000187961.15_0 ENSG00000187961.15 KLHL17 960584 1 0 ['+'] 960584 448
8 | chr1 868114 1064721 + ENSG00000187583.11_0 ENSG00000187583.11 PLEKHN1 966482 1 0 ['+'] 966482 448
9 | chr1 883725 1080332 - ENSG00000187642.10_0 ENSG00000187642.10 PERM1 982093 1 0 ['-'] 982093 448
10 | chr1 901729 1098336 - ENSG00000188290.11_0 ENSG00000188290.11 HES4 1000097 1 0 ['-'] 1000097 448
11 | chr1 915129 1111736 + ENSG00000187608.10_0 ENSG00000187608.10 ISG15 1013497 1 0 ['+'] 1013497 448
12 |
--------------------------------------------------------------------------------
/seq2emb/pseudobulk_anndata.py:
--------------------------------------------------------------------------------
1 | """
2 | Pseudobulk a single cell AnnData object by cell type.
3 | =========================================
4 | Copyright 2023 GlaxoSmithKline Research & Development Limited. All rights reserved.
5 |
6 | Licensed under the Apache License, Version 2.0 (the "License");
7 | you may not use this file except in compliance with the License.
8 | You may obtain a copy of the License at
9 |
10 | http://www.apache.org/licenses/LICENSE-2.0
11 |
12 | Unless required by applicable law or agreed to in writing, software
13 | distributed under the License is distributed on an "AS IS" BASIS,
14 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | See the License for the specific language governing permissions and
16 | limitations under the License.
17 | =========================================
18 | ..Input::
19 | A single cell (RNA) AnnData object with .obs being the genes and
20 | .var being the individual cells [gene x cell].
21 | Expects a .var column matching the cell_type_col_name argument.
22 | Expects .obs to be indexed by gene IDs or symbols.
23 | Expects a .obs column matching the gene_col_name argument if one was
24 | provided that is not 'index'. If index is provided will use the gene ID
25 | index instead of a gene name.
26 | Expects a layer matching the layer argument to be present if specified.
27 |
28 | ..Arguments::
29 | -h, --help Show this help message and exit
30 | --in IN_FILE Input file is an anndata object saved as h5ad file.
31 | --genes GENES
32 | List gene ids or symbols to compute the pseudobulk
33 | aggregate for. Must match the entries in gene_col_name
34 | of the anndata object. Default = ''.
35 | --gene_col_name GENE_COL_NAME
36 | Name of .obs column where gene names can be found
37 | that should be used for the aggregation. If set to
38 | 'index' will use the .obs index instead.
39 | Default='index'
40 | --cell_type_col_name CELL_TYPE_COL
41 | Name of the .var column that indicates the cell types
42 | that will be used for the pseudobulking.
43 | Default='cell types'
44 | --method METHOD
45 | Method to use for pseudobulking, supports:
46 | 'mean' - take the mean of the reads per gene
47 | per cell type
48 | 'sum' - take the sum of the reads per gene per cell
49 | type
50 | 'count_exp' - count the cells that express the gene at
51 | or above an expression threshold provided per gene and
52 | cell type
53 | 'perc_exp' - calculate the fraction of cells that
54 | express
55 | the gene at or above an expression threshold provided per
56 | gene and cell type.
57 | Default = 'mean'
58 | --expr_threshold EXP_THRESHOLD
59 | Threshold at or above which a gene should be
60 | considered as expressed. Matching the observed counts
61 | in the anndata object.
62 | --layer LAYER
63 | If provided will use the anndata layer instead of the .X
64 | counts.
65 |
66 | ..Usage::
67 | python ./pseudobulk_anndata.py \
68 | --in my_anndata.h5ad \
69 | --out my_pseudobulked_anndata.h5ad \
70 | --gene_col_name 'index' \
71 | --cell_type_col_name 'cell types'\
72 | --method 'mean'
73 |
74 | ..Output:: Output is AnnData object stored as .h5ad file under the --out
75 | location, with .obs being the genes and .var being the individual
76 | cell types [gene x cell types]. Where observed counts were aggregated
77 | according to the chosen method.
78 | """
79 | import argparse
80 | import logging
81 |
82 | import scanpy as sc
83 |
84 | from seq2cells.utils.anndata_utils import pseudo_bulk
85 |
# Command-line interface of the pseudobulking script. The parser is built at
# module level and consumed by the ``__main__`` section of this file.
parser = argparse.ArgumentParser(
    description="Pseudobulk an AnnData object by cell type."
)
parser.add_argument(
    "--in",
    dest="in_file",
    type=str,
    required=True,
    help="Input anndata file in .h5ad format.",
)
parser.add_argument(
    "--genes",
    dest="genes",
    nargs="+",
    # The empty string (not None) is the "no genes given" sentinel that the
    # script section compares against, so it must stay as is.
    default="",
    required=False,
    help="List gene ids or symbols to compute the pseudobulk aggregate for. "
    "Must match the entries in gene_col_name of the anndata object.",
)
# --out is mandatory, so it deliberately carries no default value.
parser.add_argument(
    "--out",
    dest="out_file",
    type=str,
    required=True,
    help="Path and name for storing the pseudobulked anndata .h5ad",
)
parser.add_argument(
    "--gene_col_name",
    dest="gene_col_name",
    default="index",
    type=str,
    required=False,
    help="Name of .obs column where gene names can be found that should be "
    "used for the aggregation. If set to 'index' will use the .obs "
    "index instead. Default='index'",
)
parser.add_argument(
    "--cell_type_col_name",
    dest="cell_type_col_name",
    default="cell types",
    type=str,
    required=False,
    help="Name of the .var column that indicates the cell types "
    "that will be used for the pseudobulking. "
    "Default='cell types'",
)
parser.add_argument(
    "--method",
    dest="method",
    default="mean",
    type=str,
    # Let argparse reject unsupported methods with a proper usage error.
    choices=["mean", "sum", "perc_exp", "count_exp"],
    required=False,
    # Adjacent string literals are concatenated verbatim, so every fragment
    # ends with an explicit separator to keep the rendered help readable.
    help="Method to use for pseudobulking, supports: "
    "'mean' - take the mean of the reads per gene per cell type; "
    "'sum' - take the sum of the reads per gene per cell type; "
    "'count_exp' - count the cells that express the gene at or above an "
    "expression threshold provided per gene and cell type; "
    "'perc_exp' - calculate the fraction of cells that express the "
    "gene at or above an expression threshold provided per gene and cell type. "
    "Default = 'mean'",
)
parser.add_argument(
    "--expr_threshold",
    dest="expr_threshold",
    default=0.5,
    type=float,
    required=False,
    help="Threshold at or above which a gene should be considered as expressed. "
    "Matching the observed counts in the anndata object. Default = 0.5",
)
parser.add_argument(
    "--layer",
    dest="layer",
    default=None,
    type=str,
    required=False,
    help="If provided will use the anndata layer instead of the .X counts.",
)
# store_true flags already default to False, no set_defaults() call needed.
parser.add_argument(
    "--mem_friendly",
    dest="mem_friendly",
    action="store_true",
    help="Flag to run in memory friendly mode. "
    "Takes on the order of 10 times longer.",
)
parser.add_argument(
    "--debug",
    dest="debug",
    action="store_true",
    help="Flag to switch on debugging mode.",
)
176 |
177 |
if __name__ == "__main__":
    # Script entry point: parse the CLI arguments, read the input AnnData,
    # aggregate it per cell type via pseudo_bulk() and write the result.

    # fetch arguments
    args = parser.parse_args()

    # Only the log level depends on --debug; the logger itself is always
    # created because it is used unconditionally further down.
    if args.debug:
        logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # set scanpy verbosity
    # verbosity: errors (0), warnings (1), info (2), hints (3)
    sc.settings.verbosity = 3 if args.debug else 1

    # Validate the aggregation method with an explicit usage error rather
    # than an assert, which is stripped when Python runs with -O.
    valid_methods = ("mean", "sum", "perc_exp", "count_exp")
    if args.method not in valid_methods:
        parser.error("Invalid aggregation method selected!")

    # read anndata
    adata = sc.read_h5ad(args.in_file)

    # No --genes supplied (empty default): pass an empty list, which is
    # reported as "all" genes (presumably pseudo_bulk treats an empty list as
    # "use every gene" - confirm against seq2cells.utils.anndata_utils).
    if not args.genes:
        genes = []
        num_genes = "all"
    else:
        genes = args.genes
        num_genes = len(genes)

    # run pseudobulking
    logger.info("Pseudo bulking %s genes ...", num_genes)

    # --mem_friendly maps directly onto pseudo_bulk's mem_efficient_mode.
    pseudo_adata = pseudo_bulk(
        adata,
        genes=genes,
        cell_type_col=args.cell_type_col_name,
        gene_col=args.gene_col_name,
        mode=args.method,
        expr_threshold=args.expr_threshold,
        mem_efficient_mode=args.mem_friendly,
        layer=args.layer,
    )

    logger.info("Writing results to %s", args.out_file)
    pseudo_adata.write(args.out_file)
240 |
--------------------------------------------------------------------------------