├── Dataset.py ├── PLM.py ├── README.md ├── module.py ├── mutate.csv ├── neutralize.csv ├── pipline.png ├── requirements.txt └── train.py /Dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import os 3 | from pathlib import Path 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import pandas as pd 8 | from Bio import PDB 9 | from Bio.PDB import PDBParser 10 | from torch_geometric.data import Data 11 | from torch_geometric.data import InMemoryDataset 12 | 13 | protein_letters_3to1_extended = { 14 | "A5N": "N", "A8E": "V", "A9D": "S", "AA3": "A", "AA4": "A", "AAR": "R", 15 | "ABA": "A", "ACL": "R", "AEA": "C", "AEI": "D", "AFA": "N", "AGM": "R", 16 | "AGQ": "Y", "AGT": "C", "AHB": "N", "AHL": "R", "AHO": "A", "AHP": "A", 17 | "AIB": "A", "AKL": "D", "AKZ": "D", "ALA": "A", "ALC": "A", "ALM": "A", 18 | "ALN": "A", "ALO": "T", "ALS": "A", "ALT": "A", "ALV": "A", "ALY": "K", 19 | "AME": "M", "AN6": "L", "AN8": "A", "API": "K", "APK": "K", "AR2": "R", 20 | "AR4": "E", "AR7": "R", "ARG": "R", "ARM": "R", "ARO": "R", "AS7": "N", 21 | "ASA": "D", "ASB": "D", "ASI": "D", "ASK": "D", "ASL": "D", "ASN": "N", 22 | "ASP": "D", "ASQ": "D", "AYA": "A", "AZH": "A", "AZK": "K", "AZS": "S", 23 | "AZY": "Y", "AVJ": "H", "A30": "Y", "A3U": "F", "ECC": "Q", "ECX": "C", 24 | "EFC": "C", "EHP": "F", "ELY": "K", "EME": "E", "EPM": "M", "EPQ": "Q", 25 | "ESB": "Y", "ESC": "M", "EXY": "L", "EXA": "K", "E0Y": "P", "E9V": "H", 26 | "E9M": "W", "EJA": "C", "EUP": "T", "EZY": "G", "E9C": "Y", "EW6": "S", 27 | "EXL": "W", "I2M": "I", "I4G": "G", "I58": "K", "IAM": "A", "IAR": "R", 28 | "ICY": "C", "IEL": "K", "IGL": "G", "IIL": "I", "ILE": "I", "ILG": "E", 29 | "ILM": "I", "ILX": "I", "ILY": "K", "IML": "I", "IOR": "R", "IPG": "G", 30 | "IT1": "K", "IYR": "Y", "IZO": "M", "IC0": "G", "M0H": "C", "M2L": "K", 31 | "M2S": "M", "M30": "G", "M3L": "K", "M3R": "K", "MA ": "A", "MAA": "A", 32 | "MAI": "R", "MBQ": "Y", "MC1": "S", "MCL": "K", "MCS": "C", "MD3": "C", 33 | "MD5": "C", "MD6": "G", "MDF": "Y", "ME0": "M", "MEA": "F", "MEG": "E", 34 | "MEN": "N", "MEQ": "Q", "MET": "M", "MEU": "G", "MFN": "E", "MGG": "R", 35 | "MGN": "Q", "MGY": "G", "MH1": "H", "MH6": "S", "MHL": "L", "MHO": "M", 36 | "MHS": "H", "MHU": "F", "MIR": "S", "MIS": "S", "MK8": "L", "ML3": "K", 37 | "MLE": "L", "MLL": "L", "MLY": "K", "MLZ": "K", "MME": "M", "MMO": "R", 38 | "MNL": "L", "MNV": "V", "MP8": "P", "MPQ": "G", "MSA": "G", "MSE": "M", 39 | "MSL": "M", "MSO": "M", "MT2": "M", "MTY": "Y", "MVA": "V", "MYK": "K", 40 | "MYN": "R", "QCS": "C", "QIL": "I", "QMM": "Q", "QPA": "C", "QPH": "F", 41 | "Q3P": "K", "QVA": "C", "QX7": "A", "Q2E": "W", "Q75": "M", "Q78": "F", 42 | "QM8": "L", "QMB": "A", "QNQ": "C", "QNT": "C", "QNW": "C", "QO2": "C", 43 | "QO5": "C", "QO8": "C", "QQ8": "Q", "U2X": "Y", "U3X": "F", "UF0": "S", 44 | "UGY": "G", "UM1": "A", "UM2": "A", "UMA": "A", "UQK": "A", "UX8": "W", 45 | "UXQ": "F", "YCM": "C", "YOF": "Y", "YPR": "P", "YPZ": "Y", "YTH": "T", 46 | "Y1V": "L", "Y57": "K", "YHA": "K", "200": "F", "23F": "F", "23P": "A", 47 | "26B": "T", "28X": "T", "2AG": "A", "2CO": "C", "2FM": "M", "2GX": "F", 48 | "2HF": "H", "2JG": "S", "2KK": "K", "2KP": "K", "2LT": "Y", "2LU": "L", 49 | "2ML": "L", "2MR": "R", "2MT": "P", "2OR": "R", "2P0": "P", "2QZ": "T", 50 | "2R3": "Y", "2RA": "A", "2RX": "S", "2SO": "H", "2TY": "Y", "2VA": "V", 51 | "2XA": "C", "2ZC": "S", "6CL": "K", "6CW": "W", "6GL": "A", "6HN": "K", 52 | "60F": "C", "66D": 
"I", "6CV": "A", "6M6": "C", "6V1": "C", "6WK": "C", 53 | "6Y9": "P", "6DN": "K", "DA2": "R", "DAB": "A", "DAH": "F", "DBS": "S", 54 | "DBU": "T", "DBY": "Y", "DBZ": "A", "DC2": "C", "DDE": "H", "DDZ": "A", 55 | "DI7": "Y", "DHA": "S", "DHN": "V", "DIR": "R", "DLS": "K", "DM0": "K", 56 | "DMH": "N", "DMK": "D", "DNL": "K", "DNP": "A", "DNS": "K", "DNW": "A", 57 | "DOH": "D", "DON": "L", "DP1": "R", "DPL": "P", "DPP": "A", "DPQ": "Y", 58 | "DYS": "C", "D2T": "D", "DYA": "D", "DJD": "F", "DYJ": "P", "DV9": "E", 59 | "H14": "F", "H1D": "M", "H5M": "P", "HAC": "A", "HAR": "R", "HBN": "H", 60 | "HCM": "C", "HGY": "G", "HHI": "H", "HIA": "H", "HIC": "H", "HIP": "H", 61 | "HIQ": "H", "HIS": "H", "HL2": "L", "HLU": "L", "HMR": "R", "HNC": "C", 62 | "HOX": "F", "HPC": "F", "HPE": "F", "HPH": "F", "HPQ": "F", "HQA": "A", 63 | "HR7": "R", "HRG": "R", "HRP": "W", "HS8": "H", "HS9": "H", "HSE": "S", 64 | "HSK": "H", "HSL": "S", "HSO": "H", "HT7": "W", "HTI": "C", "HTR": "W", 65 | "HV5": "A", "HVA": "V", "HY3": "P", "HYI": "M", "HYP": "P", "HZP": "P", 66 | "HIX": "A", "HSV": "H", "HLY": "K", "HOO": "H", "H7V": "A", "L5P": "K", 67 | "LRK": "K", "L3O": "L", "LA2": "K", "LAA": "D", "LAL": "A", "LBY": "K", 68 | "LCK": "K", "LCX": "K", "LDH": "K", "LE1": "V", "LED": "L", "LEF": "L", 69 | "LEH": "L", "LEM": "L", "LEN": "L", "LET": "K", "LEU": "L", "LEX": "L", 70 | "LGY": "K", "LLO": "K", "LLP": "K", "LLY": "K", "LLZ": "K", "LME": "E", 71 | "LMF": "K", "LMQ": "Q", "LNE": "L", "LNM": "L", "LP6": "K", "LPD": "P", 72 | "LPG": "G", "LPS": "S", "LSO": "K", "LTR": "W", "LVG": "G", "LVN": "V", 73 | "LWY": "P", "LYF": "K", "LYK": "K", "LYM": "K", "LYN": "K", "LYO": "K", 74 | "LYP": "K", "LYR": "K", "LYS": "K", "LYU": "K", "LYX": "K", "LYZ": "K", 75 | "LAY": "L", "LWI": "F", "LBZ": "K", "P1L": "C", "P2Q": "Y", "P2Y": "P", 76 | "P3Q": "Y", "PAQ": "Y", "PAS": "D", "PAT": "W", "PBB": "C", "PBF": "F", 77 | "PCA": "Q", "PCC": "P", "PCS": "F", "PE1": "K", "PEC": "C", "PF5": "F", 78 | "PFF": "F", "PG1": "S", "PGY": "G", "PHA": "F", "PHD": "D", "PHE": "F", 79 | "PHI": "F", "PHL": "F", "PHM": "F", "PKR": "P", "PLJ": "P", "PM3": "F", 80 | "POM": "P", "PPN": "F", "PR3": "C", "PR4": "P", "PR7": "P", "PR9": "P", 81 | "PRJ": "P", "PRK": "K", "PRO": "P", "PRS": "P", "PRV": "G", "PSA": "F", 82 | "PSH": "H", "PTH": "Y", "PTM": "Y", "PTR": "Y", "PVH": "H", "PXU": "P", 83 | "PYA": "A", "PYH": "K", "PYX": "C", "PH6": "P", "P9S": "C", "P5U": "S", 84 | "POK": "R", "T0I": "Y", "T11": "F", "TAV": "D", "TBG": "V", "TBM": "T", 85 | "TCQ": "Y", "TCR": "W", "TEF": "F", "TFQ": "F", "TH5": "T", "TH6": "T", 86 | "THC": "T", "THR": "T", "THZ": "R", "TIH": "A", "TIS": "S", "TLY": "K", 87 | "TMB": "T", "TMD": "T", "TNB": "C", "TNR": "S", "TNY": "T", "TOQ": "W", 88 | "TOX": "W", "TPJ": "P", "TPK": "P", "TPL": "W", "TPO": "T", "TPQ": "Y", 89 | "TQI": "W", "TQQ": "W", "TQZ": "C", "TRF": "W", "TRG": "K", "TRN": "W", 90 | "TRO": "W", "TRP": "W", "TRQ": "W", "TRW": "W", "TRX": "W", "TRY": "W", 91 | "TS9": "I", "TSY": "C", "TTQ": "W", "TTS": "Y", "TXY": "Y", "TY1": "Y", 92 | "TY2": "Y", "TY3": "Y", "TY5": "Y", "TY8": "Y", "TY9": "Y", "TYB": "Y", 93 | "TYC": "Y", "TYE": "Y", "TYI": "Y", "TYJ": "Y", "TYN": "Y", "TYO": "Y", 94 | "TYQ": "Y", "TYR": "Y", "TYS": "Y", "TYT": "Y", "TYW": "Y", "TYY": "Y", 95 | "T8L": "T", "T9E": "T", "TNQ": "W", "TSQ": "F", "TGH": "W", "X2W": "E", 96 | "XCN": "C", "XPR": "P", "XSN": "N", "XW1": "A", "XX1": "K", "XYC": "A", 97 | "XA6": "F", "11Q": "P", "11W": "E", "12L": "P", "12X": "P", "12Y": "P", 98 | "143": "C", "1AC": "A", "1L1": 
"A", "1OP": "Y", "1PA": "F", "1PI": "A", 99 | "1TQ": "W", "1TY": "Y", "1X6": "S", "56A": "H", "5AB": "A", "5CS": "C", 100 | "5CW": "W", "5HP": "E", "5OH": "A", "5PG": "G", "51T": "Y", "54C": "W", 101 | "5CR": "F", "5CT": "K", "5FQ": "A", "5GM": "I", "5JP": "S", "5T3": "K", 102 | "5MW": "K", "5OW": "K", "5R5": "S", "5VV": "N", "5XU": "A", "55I": "F", 103 | "999": "D", "9DN": "N", "9NE": "E", "9NF": "F", "9NR": "R", "9NV": "V", 104 | "9E7": "K", "9KP": "K", "9WV": "A", "9TR": "K", "9TU": "K", "9TX": "K", 105 | "9U0": "K", "9IJ": "F", "B1F": "F", "B27": "T", "B2A": "A", "B2F": "F", 106 | "B2I": "I", "B2V": "V", "B3A": "A", "B3D": "D", "B3E": "E", "B3K": "K", 107 | "B3U": "H", "B3X": "N", "B3Y": "Y", "BB6": "C", "BB7": "C", "BB8": "F", 108 | "BB9": "C", "BBC": "C", "BCS": "C", "BCX": "C", "BFD": "D", "BG1": "S", 109 | "BH2": "D", "BHD": "D", "BIF": "F", "BIU": "I", "BL2": "L", "BLE": "L", 110 | "BLY": "K", "BMT": "T", "BNN": "F", "BOR": "R", "BP5": "A", "BPE": "C", 111 | "BSE": "S", "BTA": "L", "BTC": "C", "BTK": "K", "BTR": "W", "BUC": "C", 112 | "BUG": "V", "BYR": "Y", "BWV": "R", "BWB": "S", "BXT": "S", "F2F": "F", 113 | "F2Y": "Y", "FAK": "K", "FB5": "A", "FB6": "A", "FC0": "F", "FCL": "F", 114 | "FDL": "K", "FFM": "C", "FGL": "G", "FGP": "S", "FH7": "K", "FHL": "K", 115 | "FHO": "K", "FIO": "R", "FLA": "A", "FLE": "L", "FLT": "Y", "FME": "M", 116 | "FOE": "C", "FP9": "P", "FPK": "P", "FT6": "W", "FTR": "W", "FTY": "Y", 117 | "FVA": "V", "FZN": "K", "FY3": "Y", "F7W": "W", "FY2": "Y", "FQA": "K", 118 | "F7Q": "Y", "FF9": "K", "FL6": "D", "JJJ": "C", "JJK": "C", "JJL": "C", 119 | "JLP": "K", "J3D": "C", "J9Y": "R", "J8W": "S", "JKH": "P", "N10": "S", 120 | "N7P": "P", "NA8": "A", "NAL": "A", "NAM": "A", "NBQ": "Y", "NC1": "S", 121 | "NCB": "A", "NEM": "H", "NEP": "H", "NFA": "F", "NIY": "Y", "NLB": "L", 122 | "NLE": "L", "NLN": "L", "NLO": "L", "NLP": "L", "NLQ": "Q", "NLY": "G", 123 | "NMC": "G", "NMM": "R", "NNH": "R", "NOT": "L", "NPH": "C", "NPI": "A", 124 | "NTR": "Y", "NTY": "Y", "NVA": "V", "NWD": "A", "NYB": "C", "NYS": "C", 125 | "NZH": "H", "N80": "P", "NZC": "T", "NLW": "L", "N0A": "F", "N9P": "A", 126 | "N65": "K", "R1A": "C", "R4K": "W", "RE0": "W", "RE3": "W", "RGL": "R", 127 | "RGP": "E", "RT0": "P", "RVX": "S", "RZ4": "S", "RPI": "R", "RVJ": "A", 128 | "VAD": "V", "VAF": "V", "VAH": "V", "VAI": "V", "VAL": "V", "VB1": "K", 129 | "VH0": "P", "VR0": "R", "V44": "C", "V61": "F", "VPV": "K", "V5N": "H", 130 | "V7T": "K", "Z01": "A", "Z3E": "T", "Z70": "H", "ZBZ": "C", "ZCL": "F", 131 | "ZU0": "T", "ZYJ": "P", "ZYK": "P", "ZZD": "C", "ZZJ": "A", "ZIQ": "W", 132 | "ZPO": "P", "ZDJ": "Y", "ZT1": "K", "30V": "C", "31Q": "C", "33S": "F", 133 | "33W": "A", "34E": "V", "3AH": "H", "3BY": "P", "3CF": "F", "3CT": "Y", 134 | "3GA": "A", "3GL": "E", "3MD": "D", "3MY": "Y", "3NF": "Y", "3O3": "E", 135 | "3PX": "P", "3QN": "K", "3TT": "P", "3XH": "G", "3YM": "Y", "3WS": "A", 136 | "3WX": "P", "3X9": "C", "3ZH": "H", "7JA": "I", "73C": "S", "73N": "R", 137 | "73O": "Y", "73P": "K", "74P": "K", "7N8": "F", "7O5": "A", "7XC": "F", 138 | "7ID": "D", "7OZ": "A", "C1S": "C", "C1T": "C", "C1X": "K", "C22": "A", 139 | "C3Y": "C", "C4R": "C", "C5C": "C", "C6C": "C", "CAF": "C", "CAS": "C", 140 | "CAY": "C", "CCS": "C", "CEA": "C", "CGA": "E", "CGU": "E", "CGV": "C", 141 | "CHP": "G", "CIR": "R", "CLE": "L", "CLG": "K", "CLH": "K", "CME": "C", 142 | "CMH": "C", "CML": "C", "CMT": "C", "CR5": "G", "CS0": "C", "CS1": "C", 143 | "CS3": "C", "CS4": "C", "CSA": "C", "CSB": "C", "CSD": "C", "CSE": "C", 144 
| "CSJ": "C", "CSO": "C", "CSP": "C", "CSR": "C", "CSS": "C", "CSU": "C", 145 | "CSW": "C", "CSX": "C", "CSZ": "C", "CTE": "W", "CTH": "T", "CWD": "A", 146 | "CWR": "S", "CXM": "M", "CY0": "C", "CY1": "C", "CY3": "C", "CY4": "C", 147 | "CYA": "C", "CYD": "C", "CYF": "C", "CYG": "C", "CYJ": "K", "CYM": "C", 148 | "CYQ": "C", "CYR": "C", "CYS": "C", "CYW": "C", "CZ2": "C", "CZZ": "C", 149 | "CG6": "C", "C1J": "R", "C4G": "R", "C67": "R", "C6D": "R", "CE7": "N", 150 | "CZS": "A", "G01": "E", "G8M": "E", "GAU": "E", "GEE": "G", "GFT": "S", 151 | "GHC": "E", "GHG": "Q", "GHW": "E", "GL3": "G", "GLH": "Q", "GLJ": "E", 152 | "GLK": "E", "GLN": "Q", "GLQ": "E", "GLU": "E", "GLY": "G", "GLZ": "G", 153 | "GMA": "E", "GME": "E", "GNC": "Q", "GPL": "K", "GSC": "G", "GSU": "E", 154 | "GT9": "C", "GVL": "S", "G3M": "R", "G5G": "L", "G1X": "Y", "G8X": "P", 155 | "K1R": "C", "KBE": "K", "KCX": "K", "KFP": "K", "KGC": "K", "KNB": "A", 156 | "KOR": "M", "KPI": "K", "KPY": "K", "KST": "K", "KYN": "W", "KYQ": "K", 157 | "KCR": "K", "KPF": "K", "K5L": "S", "KEO": "K", "KHB": "K", "KKD": "D", 158 | "K5H": "C", "K7K": "S", "OAR": "R", "OAS": "S", "OBS": "K", "OCS": "C", 159 | "OCY": "C", "OHI": "H", "OHS": "D", "OLD": "H", "OLT": "T", "OLZ": "S", 160 | "OMH": "S", "OMT": "M", "OMX": "Y", "OMY": "Y", "ONH": "A", "ORN": "A", 161 | "ORQ": "R", "OSE": "S", "OTH": "T", "OXX": "D", "OYL": "H", "O7A": "T", 162 | "O7D": "W", "O7G": "V", "O2E": "S", "O6H": "W", "OZW": "F", "S12": "S", 163 | "S1H": "S", "S2C": "C", "S2P": "A", "SAC": "S", "SAH": "C", "SAR": "G", 164 | "SBG": "S", "SBL": "S", "SCH": "C", "SCS": "C", "SCY": "C", "SD4": "N", 165 | "SDB": "S", "SDP": "S", "SEB": "S", "SEE": "S", "SEG": "A", "SEL": "S", 166 | "SEM": "S", "SEN": "S", "SEP": "S", "SER": "S", "SET": "S", "SGB": "S", 167 | "SHC": "C", "SHP": "G", "SHR": "K", "SIB": "C", "SLL": "K", "SLZ": "K", 168 | "SMC": "C", "SME": "M", "SMF": "F", "SNC": "C", "SNN": "N", "SOY": "S", 169 | "SRZ": "S", "STY": "Y", "SUN": "S", "SVA": "S", "SVV": "S", "SVW": "S", 170 | "SVX": "S", "SVY": "S", "SVZ": "S", "SXE": "S", "SKH": "K", "SNM": "S", 171 | "SNK": "H", "SWW": "S", "WFP": "F", "WLU": "L", "WPA": "F", "WRP": "W", 172 | "WVL": "V", "02K": "A", "02L": "N", "02O": "A", "02Y": "A", "033": "V", 173 | "037": "P", "03Y": "C", "04U": "P", "04V": "P", "05N": "P", "07O": "C", 174 | "0A0": "D", "0A1": "Y", "0A2": "K", "0A8": "C", "0A9": "F", "0AA": "V", 175 | "0AB": "V", "0AC": "G", "0AF": "W", "0AG": "L", "0AH": "S", "0AK": "D", 176 | "0AR": "R", "0BN": "F", "0CS": "A", "0E5": "T", "0EA": "Y", "0FL": "A", 177 | "0LF": "P", "0NC": "A", "0PR": "Y", "0QL": "C", "0TD": "D", "0UO": "W", 178 | "0WZ": "Y", "0X9": "R", "0Y8": "P", "4AF": "F", "4AR": "R", "4AW": "W", 179 | "4BF": "Y", "4CF": "F", "4CY": "M", "4DP": "W", "4FB": "P", "4FW": "W", 180 | "4HL": "Y", "4HT": "W", "4IN": "W", "4MM": "M", "4PH": "F", "4U7": "A", 181 | "41H": "F", "41Q": "N", "42Y": "S", "432": "S", "45F": "P", "4AK": "K", 182 | "4D4": "R", "4GJ": "C", "4KY": "P", "4L0": "P", "4LZ": "Y", "4N7": "P", 183 | "4N8": "P", "4N9": "P", "4OG": "W", "4OU": "F", "4OV": "S", "4OZ": "S", 184 | "4PQ": "W", "4SJ": "F", "4WQ": "A", "4HH": "S", "4HJ": "S", "4J4": "C", 185 | "4J5": "R", "4II": "F", "4VI": "R", "823": "N", "8SP": "S", "8AY": "A", 186 | } 187 | 188 | 189 | def process_pdb_file(input_file, output_file, chains): 190 | amino_acids = {'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLU', 'GLN', 'GLY', 'HIS', 'ILE', 'LEU', 'LYS', 'MET', 'PHE', 191 | 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL'} 192 | with open(input_file, 'r') 
as f, open(output_file, 'w') as out: 193 | for line in f.readlines(): 194 | if line.startswith('ATOM'): 195 | chain = line[21] 196 | residue = line[17:20].strip() 197 | if chain in chains and residue in amino_acids: 198 | out.write(line) 199 | 200 | 201 | def extract_from_pdb_by_chain(pdb_file, chain_id): 202 | parser = PDBParser(QUIET=True) 203 | structure = parser.get_structure('complex', pdb_file) 204 | chain = structure[0][chain_id] 205 | 206 | sequence = '' 207 | 208 | for residue in chain: 209 | if PDB.is_aa(residue): 210 | residue_name = residue.get_resname() 211 | 212 | single_aa = protein_letters_3to1_extended[residue_name] 213 | sequence += single_aa 214 | return sequence 215 | 216 | 217 | class AffinityDataset(InMemoryDataset): 218 | def __init__(self, root): 219 | super(AffinityDataset, self).__init__(root) 220 | self.data, self.slices = torch.load(self.processed_paths[0]) 221 | 222 | @property 223 | def raw_file_names(self): 224 | return [] 225 | 226 | @property 227 | def processed_file_names(self): 228 | return ['AffinityDataset.pt'] 229 | 230 | def download(self): 231 | pass 232 | 233 | def _normalize(self, tensor, dim=-1): 234 | return torch.nan_to_num( 235 | torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True))) 236 | 237 | def get_atom_pos(self, amino_types, atom_names, atom_amino_id, atom_pos): 238 | # atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1 239 | mask_n = np.char.equal(atom_names, b'N') 240 | # print('mask_n', mask_n) 241 | mask_ca = np.char.equal(atom_names, b'CA') 242 | mask_c = np.char.equal(atom_names, b'C') 243 | mask_cb = np.char.equal(atom_names, b'CB') 244 | mask_g = np.char.equal(atom_names, b'CG') | np.char.equal(atom_names, b'SG') | np.char.equal(atom_names, 245 | b'OG') | np.char.equal( 246 | atom_names, b'CG1') | np.char.equal(atom_names, b'OG1') 247 | mask_d = np.char.equal(atom_names, b'CD') | np.char.equal(atom_names, b'SD') | np.char.equal(atom_names, 248 | b'CD1') | np.char.equal( 249 | atom_names, b'OD1') | np.char.equal(atom_names, b'ND1') 250 | mask_e = np.char.equal(atom_names, b'CE') | np.char.equal(atom_names, b'NE') | np.char.equal(atom_names, b'OE1') 251 | mask_z = np.char.equal(atom_names, b'CZ') | np.char.equal(atom_names, b'NZ') 252 | mask_h = np.char.equal(atom_names, b'NH1') 253 | 254 | pos_n = np.full((len(amino_types), 3), np.nan) 255 | # print('pos_n', pos_n) 256 | pos_n[atom_amino_id[mask_n]] = atom_pos[mask_n] 257 | pos_n = torch.FloatTensor(pos_n) 258 | 259 | pos_ca = np.full((len(amino_types), 3), np.nan) 260 | pos_ca[atom_amino_id[mask_ca]] = atom_pos[mask_ca] 261 | pos_ca = torch.FloatTensor(pos_ca) 262 | 263 | pos_c = np.full((len(amino_types), 3), np.nan) 264 | pos_c[atom_amino_id[mask_c]] = atom_pos[mask_c] 265 | pos_c = torch.FloatTensor(pos_c) 266 | 267 | # if data only contain pos_ca, we set the position of C and N as the position of CA 268 | pos_n[torch.isnan(pos_n)] = pos_ca[torch.isnan(pos_n)] 269 | pos_c[torch.isnan(pos_c)] = pos_ca[torch.isnan(pos_c)] 270 | 271 | # if data only contain pos_n, we set the position of CA and C as the position of N 272 | pos_ca[torch.isnan(pos_ca)] = pos_n[torch.isnan(pos_ca)] 273 | pos_c[torch.isnan(pos_c)] = pos_n[torch.isnan(pos_c)] 274 | 275 | # if data only contain pos_c, we set the position of N ad CA as the position of C 276 | pos_ca[torch.isnan(pos_ca)] = pos_c[torch.isnan(pos_ca)] 277 | pos_n[torch.isnan(pos_n)] = pos_c[torch.isnan(pos_n)] 278 | 279 | pos_cb = np.full((len(amino_types), 3), np.nan) 280 | 
pos_cb[atom_amino_id[mask_cb]] = atom_pos[mask_cb] 281 | pos_cb = torch.FloatTensor(pos_cb) 282 | 283 | pos_g = np.full((len(amino_types), 3), np.nan) 284 | pos_g[atom_amino_id[mask_g]] = atom_pos[mask_g] 285 | pos_g = torch.FloatTensor(pos_g) 286 | 287 | pos_d = np.full((len(amino_types), 3), np.nan) 288 | pos_d[atom_amino_id[mask_d]] = atom_pos[mask_d] 289 | pos_d = torch.FloatTensor(pos_d) 290 | 291 | pos_e = np.full((len(amino_types), 3), np.nan) 292 | pos_e[atom_amino_id[mask_e]] = atom_pos[mask_e] 293 | pos_e = torch.FloatTensor(pos_e) 294 | 295 | pos_z = np.full((len(amino_types), 3), np.nan) 296 | pos_z[atom_amino_id[mask_z]] = atom_pos[mask_z] 297 | pos_z = torch.FloatTensor(pos_z) 298 | 299 | pos_h = np.full((len(amino_types), 3), np.nan) 300 | pos_h[atom_amino_id[mask_h]] = atom_pos[mask_h] 301 | pos_h = torch.FloatTensor(pos_h) 302 | 303 | return pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h 304 | 305 | def side_chain_embs(self, pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h): 306 | v1, v2, v3, v4, v5, v6, v7 = pos_ca - pos_n, pos_cb - pos_ca, pos_g - pos_cb, pos_d - pos_g, pos_e - pos_d, pos_z - pos_e, pos_h - pos_z 307 | 308 | angle1 = torch.unsqueeze(self.compute_dihedrals(v1, v2, v3), 1) 309 | angle2 = torch.unsqueeze(self.compute_dihedrals(v2, v3, v4), 1) 310 | angle3 = torch.unsqueeze(self.compute_dihedrals(v3, v4, v5), 1) 311 | angle4 = torch.unsqueeze(self.compute_dihedrals(v4, v5, v6), 1) 312 | angle5 = torch.unsqueeze(self.compute_dihedrals(v5, v6, v7), 1) 313 | 314 | side_chain_angles = torch.cat((angle1, angle2, angle3, angle4), 1) 315 | side_chain_embs = torch.cat((torch.sin(side_chain_angles), torch.cos(side_chain_angles)), 1) 316 | 317 | return side_chain_embs 318 | 319 | def bb_embs(self, X): 320 | X = torch.reshape(X, [3 * X.shape[0], 3]) 321 | dX = X[1:] - X[:-1] 322 | U = self._normalize(dX, dim=-1) 323 | u0 = U[:-2] 324 | u1 = U[1:-1] 325 | u2 = U[2:] 326 | 327 | angle = self.compute_dihedrals(u0, u1, u2) 328 | angle = F.pad(angle, [1, 2]) 329 | angle = torch.reshape(angle, [-1, 3]) 330 | angle_features = torch.cat([torch.cos(angle), torch.sin(angle)], 1) 331 | return angle_features 332 | 333 | def compute_dihedrals(self, v1, v2, v3): 334 | n1 = torch.cross(v1, v2) 335 | n2 = torch.cross(v2, v3) 336 | a = (n1 * n2).sum(dim=-1) 337 | b = torch.nan_to_num((torch.cross(n1, n2) * v2).sum(dim=-1) / v2.norm(dim=1)) 338 | torsion = torch.nan_to_num(torch.atan2(b, a)) 339 | return torsion 340 | 341 | def get_data_from_pdb(self, pdb): 342 | amino_acids_mapping = {'GLY': 0, 'PRO': 1, 'CYS': 2, 'TRP': 3, 'PHE': 4, 'ASP': 5, 'SER': 6, 'ASN': 7, 'GLU': 8, 343 | 'HIS': 9, 'ALA': 10, 'ARG': 11, 'THR': 12, 'MET': 13, 'TYR': 14, 'GLN': 15, 'LEU': 16, 344 | 'ILE': 17, 'LYS': 18, 'VAL': 19} 345 | 346 | atom_names = [] 347 | atom_pos = [] 348 | for line in open(pdb, 'r').readlines(): 349 | if line.startswith('ATOM'): 350 | atom_num = int(line[6:11].strip()) 351 | atom_name = line[12:16].strip() 352 | residue_name = line[17:20].strip() 353 | chain_id = line[21] 354 | res_num = int(line[22:26].strip()) 355 | x = float(line[30:38].strip()) 356 | y = float(line[38:46].strip()) 357 | z = float(line[46:54].strip()) 358 | 359 | atom_names.append(bytes(atom_name, 'ascii')) 360 | atom_pos.append([x, y, z]) 361 | 362 | with open(pdb, 'r') as f: 363 | lines = f.readlines() 364 | residue_list = [] 365 | atom_amino_id = [] 366 | current_residue = None 367 | amino_id = 0 368 | 369 | for line in lines: 370 | if line.startswith('ATOM') or 
line.startswith('HETATM'): 371 | residue_name = line[17:20].strip() 372 | residue_number = int(line[22:26].strip()) 373 | 374 | if current_residue is None or residue_number != current_residue[1]: 375 | if current_residue is not None: 376 | residue_list.append(current_residue[0]) 377 | amino_id += 1 378 | current_residue = (residue_name, residue_number, amino_id) 379 | atom_amino_id.append(amino_id) 380 | 381 | if current_residue is not None: 382 | residue_list.append(current_residue[0]) 383 | amino_types = [amino_acids_mapping[x] for x in residue_list] 384 | 385 | return np.array(amino_types), np.array(atom_names), np.array(atom_amino_id), np.array(atom_pos) 386 | 387 | def calculate_pdb(self, pdb_file): 388 | amino_types, atom_names, atom_amino_id, atom_pos = self.get_data_from_pdb(pdb=pdb_file) 389 | pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h = self.get_atom_pos(amino_types, atom_names, 390 | atom_amino_id, 391 | atom_pos=atom_pos) 392 | side_chain_embs = self.side_chain_embs(pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h) 393 | side_chain_embs[torch.isnan(side_chain_embs)] = 0 394 | bb_embs = self.bb_embs( 395 | torch.cat((torch.unsqueeze(pos_n, 1), torch.unsqueeze(pos_ca, 1), torch.unsqueeze(pos_c, 1)), 1)) 396 | bb_embs[torch.isnan(bb_embs)] = 0 397 | 398 | data = Data( 399 | side_chain_embs=side_chain_embs, 400 | bb_embs=bb_embs, 401 | x=torch.unsqueeze(torch.tensor(amino_types), 1), 402 | coords_ca=pos_ca, 403 | coords_n=pos_n, 404 | coords_c=pos_c 405 | ) 406 | assert len(data.x) == len(data.coords_ca) == len(data.coords_n) == len(data.coords_c) == len( 407 | data.side_chain_embs) == len(data.bb_embs) 408 | 409 | return torch.unsqueeze(torch.tensor(amino_types), 1), pos_c, pos_ca, pos_n, bb_embs, side_chain_embs 410 | 411 | def process(self): 412 | data_list = [] 413 | df = pd.read_csv('mutate.csv') 414 | for i in range(0, df.shape[0]): 415 | index = i 416 | print(i) 417 | antibody_path = df['antibody_path'][i] 418 | antigen_path = df['antigen_path'][i] 419 | 420 | ab_x, ab_pos_c, ab_pos_ca, ab_pos_n, ab_bb_embs, ab_side_chain_embs = self.calculate_pdb( 421 | pdb_file=antibody_path) 422 | 423 | ag_x, ag_pos_c, ag_pos_ca, ag_pos_n, ag_bb_embs, ag_side_chain_embs = self.calculate_pdb( 424 | pdb_file=antigen_path) 425 | 426 | affinity = torch.tensor(df['delta_g'][i]) 427 | 428 | data_list.append(Data( 429 | x=ab_x, 430 | y=affinity, 431 | label=index, 432 | side_chain_embs=ab_side_chain_embs, 433 | bb_embs=ab_bb_embs, 434 | coords_ca=ab_pos_ca, 435 | coords_n=ab_pos_n, 436 | coords_c=ab_pos_c, 437 | antigen=Data( 438 | x=ag_x, 439 | y=affinity, 440 | label=index, 441 | side_chain_embs=ag_side_chain_embs, 442 | bb_embs=ag_bb_embs, 443 | coords_ca=ag_pos_ca, 444 | coords_n=ag_pos_n, 445 | coords_c=ag_pos_c, 446 | ) 447 | )) 448 | 449 | data, slices = self.collate(data_list) 450 | torch.save((data, slices), self.processed_paths[0]) 451 | 452 | -------------------------------------------------------------------------------- /PLM.py: -------------------------------------------------------------------------------- 1 | """ 2 | PLM: Esm2_150M, ProtBert 3 | ALM: Ablang, AntiBERTy, BERT2DAb 4 | """ 5 | import multiprocessing 6 | 7 | import pandas as pd 8 | from sklearn.ensemble import RandomForestRegressor 9 | from transformers import T5Tokenizer, T5EncoderModel 10 | import torch 11 | import re 12 | from pathlib import Path 13 | import numpy as np 14 | from sklearn.linear_model import LinearRegression 15 | from sklearn.model_selection import KFold, 
train_test_split 16 | from sklearn.metrics import mean_absolute_error, mean_squared_error 17 | from sklearn.decomposition import PCA 18 | from scipy.stats import pearsonr 19 | import esm 20 | import os 21 | 22 | 23 | def extract_from_ProtTrans(): 24 | df = pd.read_csv('7_26_mutate_resample.csv') 25 | 26 | affinity = df['delta_g'].values 27 | # df['sequence1'] = df['Sequence'] 28 | # df['sequence2'] = df['Target'] 29 | # df['sequence1'] = df['seq_ab'] 30 | # df['sequence2'] = df['seq_ag'] 31 | # affinity = df['delta_g'].values 32 | df['sequence1'] = df['antibody_Hchain_sequence'].fillna('') + df['antibody_Lchain_sequence'].fillna('') 33 | df['sequence2'] = df['antigen_sequence'] 34 | 35 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 36 | 37 | tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False) 38 | 39 | model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(device) 40 | model.full() if device == 'cpu' else model.half() 41 | 42 | features = [] 43 | 44 | for i in range(0, df.shape[0]): 45 | print(i) 46 | # prepare your protein sequences as a list 47 | sequence_examples = [df['sequence1'][i], df['sequence2'][i]] 48 | 49 | sequence_examples = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in sequence_examples] 50 | 51 | ids = tokenizer(sequence_examples, add_special_tokens=True, padding="longest") 52 | 53 | input_ids = torch.tensor(ids['input_ids']).to(device) 54 | attention_mask = torch.tensor(ids['attention_mask']).to(device) 55 | 56 | # generate embeddings 57 | with torch.no_grad(): 58 | embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask) 59 | 60 | emb_0 = embedding_repr.last_hidden_state[0, :len(df['sequence1'][i])] # shape (sequence_length x 1024) 61 | # embedding_path = r"E:\ProgrammingSpace\Gitee\data\DeepAntibody\ProtTrans_embeddings\{}.npy".format(str(df['Index'][i])) 62 | # np.save(embedding_path, emb_0.cpu().numpy()) 63 | emb_1 = embedding_repr.last_hidden_state[1, :len(df['sequence2'][i])] 64 | 65 | emb_0_per_protein = emb_0.mean(dim=0) # shape (1024) 66 | emb_1_per_protein = emb_1.mean(dim=0) 67 | 68 | feature1 = emb_0_per_protein.cpu().numpy() 69 | feature2 = emb_1_per_protein.cpu().numpy() 70 | 71 | feature = np.concatenate((feature1, feature2)) 72 | features.append(feature) 73 | 74 | current_path = Path.cwd() 75 | embedding_path = current_path.joinpath('mutate_embeddings') 76 | embedding_file = embedding_path / 'ProtTrans_mutate_726.npy' 77 | affinity_file = embedding_path / 'affinity_726.npy' 78 | np.save(str(embedding_file), np.array(features)) 79 | np.save(str(affinity_file), affinity) 80 | 81 | 82 | def extract_from_ESM2_650M(): 83 | """ 84 | ESM-2能处理的最大蛋白质序列长度为1024,超过1024的序列需要切片处理 85 | embedding为1280 DIM 86 | :return: 87 | """ 88 | df = pd.read_csv('7_26_mutate_resample.csv') 89 | # df['sequence1'] = df['Sequence'] 90 | # df['sequence2'] = df['Target'] 91 | # df['sequence1'] = df['seq_ab'] 92 | # df['sequence2'] = df['seq_ag'] 93 | df['sequence1'] = df['antibody_Hchain_sequence'].fillna('') + df['antibody_Lchain_sequence'].fillna('') 94 | df['sequence2'] = df['antigen_sequence'] 95 | 96 | def split_sequences(s, max_length=1024): 97 | if len(s) <= max_length: 98 | return [('s1', s)] 99 | else: 100 | num_splits = len(s) // max_length 101 | splits = [] 102 | for i in range(num_splits): 103 | start = i * max_length 104 | end = (i + 1) * max_length 105 | splits.append(('s{}'.format(i + 1), s[start:end])) 106 | 107 | remaining = len(s) % 
max_length 108 | if remaining > 0: 109 | splits.append(('s{}'.format(num_splits + 1), s[-remaining:])) 110 | 111 | return splits 112 | 113 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 114 | 115 | # Load ESM-2 model 116 | model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() 117 | model = model.to(device) 118 | batch_converter = alphabet.get_batch_converter() 119 | model.eval() 120 | 121 | antigen_features = [] 122 | antibody_features = [] 123 | 124 | for i in range(0, df.shape[0]): 125 | print(i) 126 | 127 | # ================================================================================================= 128 | # 提取抗原特征 129 | antigens = split_sequences(df['sequence2'][i]) 130 | # 判断切片长度 131 | if len(antigens) < 2: 132 | batch_labels, batch_strs, batch_tokens = batch_converter(antigens) 133 | batch_lens = (batch_tokens != alphabet.padding_idx).sum(1) 134 | batch_tokens = batch_tokens.to(device) 135 | 136 | with torch.no_grad(): 137 | batch_tokens = batch_tokens.to(device) 138 | results = model(batch_tokens, repr_layers=[33], return_contacts=True) 139 | token_representations = results["representations"][33] 140 | 141 | sequence_representations = [] 142 | for i, tokens_len in enumerate(batch_lens): 143 | sequence_representations.append(token_representations[i, 1: tokens_len - 1].mean(0)) 144 | 145 | antigen_features.append(sequence_representations[0].cpu().numpy()) 146 | # 遍历切片list 147 | else: 148 | representations = [] 149 | 150 | for antigen in antigens: 151 | # 装箱 152 | _, _, batch_tokens = batch_converter([antigen]) 153 | batch_lens = (batch_tokens != alphabet.padding_idx).sum(1) 154 | batch_tokens = batch_tokens.to(device) 155 | with torch.no_grad(): 156 | batch_tokens = batch_tokens.to(device) 157 | results = model(batch_tokens, repr_layers=[33], return_contacts=True) 158 | token_representations = results["representations"][33] 159 | 160 | sequence_representations = [] 161 | for i, tokens_len in enumerate(batch_lens): 162 | sequence_representations.append(token_representations[i, 1: tokens_len - 1].mean(0)) 163 | 164 | representations.append(sequence_representations[0]) 165 | 166 | # 初始化总和矩阵 167 | sum = torch.zeros_like(representations[0]) 168 | for t in representations: 169 | sum += t 170 | 171 | average_representation = sum / len(representations) 172 | print(average_representation.shape) 173 | antigen_features.append(average_representation.cpu().numpy()) 174 | 175 | # =============================================================================================== 176 | # 提取抗体特征 177 | antibodies = split_sequences(df['sequence1'][i]) 178 | # 判断切片长度 179 | if len(antibodies) < 2: 180 | batch_labels, batch_strs, batch_tokens = batch_converter(antibodies) 181 | batch_lens = (batch_tokens != alphabet.padding_idx).sum(1) 182 | batch_tokens = batch_tokens.to(device) 183 | 184 | with torch.no_grad(): 185 | batch_tokens = batch_tokens.to(device) 186 | results = model(batch_tokens, repr_layers=[33], return_contacts=True) 187 | token_representations = results["representations"][33] 188 | 189 | sequence_representations = [] 190 | for i, tokens_len in enumerate(batch_lens): 191 | sequence_representations.append(token_representations[i, 1: tokens_len - 1].mean(0)) 192 | antibody_features.append(sequence_representations[0].cpu().numpy()) 193 | # 遍历切片list 194 | else: 195 | representations = [] 196 | 197 | for antibody in antibodies: 198 | # 装箱 199 | _, _, batch_tokens = batch_converter([antibody]) 200 | batch_lens = (batch_tokens != alphabet.padding_idx).sum(1) 201 | 
batch_tokens = batch_tokens.to(device) 202 | with torch.no_grad(): 203 | batch_tokens = batch_tokens.to(device) 204 | results = model(batch_tokens, repr_layers=[33], return_contacts=True) 205 | token_representations = results["representations"][33] 206 | 207 | sequence_representations = [] 208 | for i, tokens_len in enumerate(batch_lens): 209 | sequence_representations.append(token_representations[i, 1: tokens_len - 1].mean(0)) 210 | 211 | representations.append(sequence_representations[0]) 212 | 213 | # 初始化总和矩阵 214 | sum = torch.zeros_like(representations[0]) 215 | for t in representations: 216 | sum += t 217 | 218 | average_representation = sum / len(representations) 219 | antibody_features.append(average_representation.cpu().numpy()) 220 | 221 | # =============================================================================================== 222 | # 拼接抗原和抗体的特征 223 | # print(antigen_features) 224 | # print(len(antigen_features)) 225 | # print(len(antigen_features[0])) 226 | # 227 | # print('=====') 228 | # print(len(antibody_features)) 229 | # print(antibody_features) 230 | # print(len(antibody_features[0])) 231 | 232 | features = np.concatenate((antibody_features, antigen_features), axis=1) 233 | print(features.shape) 234 | current_path = Path.cwd() 235 | embedding_path = current_path.joinpath('mutate_embeddings') 236 | embedding_file = embedding_path / 'ESM2_mutate_726.npy' 237 | print(embedding_file) 238 | np.save(str(embedding_file), np.array(features)) 239 | 240 | 241 | def extract_from_AbLang(): 242 | """ 243 | length_a_sequence * 480 244 | :return: 245 | """ 246 | import ablang2 247 | ablang = ablang2.pretrained(model_to_use='ablang2-paired', random_init=False, device='cuda') 248 | df = pd.read_csv('7_26_mutate_resample.csv') 249 | # df['sequence1'] = df['Sequence'] 250 | # df['sequence2'] = df['Target'] 251 | # df['sequence1'] = df['seq_ab'] 252 | # df['sequence2'] = df['seq_ag'] 253 | # affinity = df['delta_g'].values 254 | df['sequence1'] = df['antibody_Hchain_sequence'].fillna('') + df['antibody_Lchain_sequence'].fillna('') 255 | df['sequence2'] = df['antigen_sequence'] 256 | # Download and initialise the model 257 | antibody_feature = [] 258 | antigen_feature = [] 259 | for i in range(0, df.shape[0]): 260 | seq = [df['sequence1'][i]] 261 | 262 | # Tokenize input sequences 263 | tokenized_seq = ablang.tokenizer(seq, pad=True, w_extra_tkns=False, device="cuda") 264 | 265 | # Generate rescodings 266 | with torch.no_grad(): 267 | rescoding = ablang.AbRep(tokenized_seq).last_hidden_states 268 | 269 | rep = torch.mean(rescoding[0], dim=0, keepdim=True) 270 | antibody_feature.append(rep.cpu().numpy()[0]) 271 | 272 | # print(rescoding) 273 | # print(rescoding.shape) # 274 | # print(rescoding[0].shape) 275 | 276 | for i in range(0, df.shape[0]): 277 | print(i) 278 | seq = [df['sequence2'][i]] 279 | 280 | # Tokenize input sequences 281 | tokenized_seq = ablang.tokenizer(seq, pad=True, w_extra_tkns=False, device="cuda") 282 | 283 | # Generate rescodings 284 | with torch.no_grad(): 285 | rescoding = ablang.AbRep(tokenized_seq).last_hidden_states 286 | 287 | rep = torch.mean(rescoding[0], dim=0, keepdim=True) 288 | antigen_feature.append(rep.cpu().numpy()[0]) 289 | 290 | features = np.concatenate((antibody_feature, antigen_feature), axis=1) 291 | print(features.shape) 292 | 293 | current_path = Path.cwd() 294 | embedding_path = current_path.joinpath('mutate_embeddings') 295 | embedding_file = embedding_path / 'AbLang_mutate_726.npy' 296 | 297 | np.save(str(embedding_file), 
np.array(features)) 298 | 299 | 300 | def extract_from_AntiBERTy(): 301 | """ 302 | (length_a_sequence + 2) * 512 303 | :return: 304 | """ 305 | from antiberty import AntiBERTyRunner 306 | 307 | df = pd.read_csv('7_26_mutate_resample.csv') 308 | # df['sequence1'] = df['Sequence'] 309 | # df['sequence2'] = df['Target'] 310 | # df['sequence1'] = df['seq_ab'] 311 | # df['sequence2'] = df['seq_ag'] 312 | # affinity = df['delta_g'].values 313 | df['sequence1'] = df['antibody_Hchain_sequence'].fillna('') + df['antibody_Lchain_sequence'].fillna('') 314 | df['sequence2'] = df['antigen_sequence'] 315 | antiberty = AntiBERTyRunner() 316 | 317 | antibody_feature = [] 318 | antigen_feature = [] 319 | 320 | for i in range(0, df.shape[0]): 321 | print(i) 322 | if len(df['sequence1'][i]) > 510: 323 | sequences = [df['sequence1'][i][:510]] 324 | else: 325 | sequences = [df['sequence1'][i]] 326 | embeddings = antiberty.embed(sequences) 327 | 328 | # print(embeddings) 329 | # print(len(embeddings)) # len of sequences 330 | # print(embeddings[0].shape) 331 | # print(embeddings[0]) 332 | 333 | average_embeddings = torch.mean(embeddings[0], dim=0) 334 | # print(average_embeddings.shape) 335 | antibody_feature.append(average_embeddings.cpu().numpy()) 336 | 337 | for i in range(0, df.shape[0]): 338 | print(i) 339 | 340 | if len(df['sequence2'][i]) <= 512: 341 | sequences = [df['sequence2'][i]] 342 | embeddings = antiberty.embed(sequences) 343 | average_embeddings = torch.mean(embeddings[0], dim=0) 344 | antigen_feature.append(average_embeddings.cpu().numpy()) 345 | 346 | else: 347 | num_segments = len(df['sequence2'][i]) // 512 348 | remainder = len(df['sequence2'][i]) % 512 349 | sequences = [] 350 | for i in range(num_segments): 351 | sequences.append(df['sequence2'][i][i * 512: (i + 1) * 512]) 352 | if remainder > 0: 353 | sequences.append(df['sequence2'][i][num_segments * 512:]) 354 | temp = [] 355 | for sequence in sequences: 356 | embeddings = antiberty.embed([sequence]) 357 | average_embeddings = torch.mean(embeddings[0], dim=0) 358 | temp.append(average_embeddings.cpu().numpy()) 359 | antigen_feature.append(np.mean(np.array(temp), axis=0)) 360 | 361 | features = np.concatenate((antibody_feature, antigen_feature), axis=1) 362 | print(features.shape) 363 | 364 | current_path = Path.cwd() 365 | embedding_path = current_path.joinpath('mutate_embeddings') 366 | embedding_file = embedding_path / 'AntiBERTy_mutate_726.npy' 367 | 368 | np.save(str(embedding_file), np.array(features)) 369 | 370 | 371 | def extract_from_BERT2DAb(): 372 | from transformers import BertTokenizer, BertModel 373 | import ast 374 | df = pd.read_csv(r'D:\wild_89.csv') 375 | 376 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 377 | 378 | # # ============================================================================================= 379 | # 提取重链特征 380 | tokenizer_H = BertTokenizer.from_pretrained("w139700701/BERT2DAb_H") 381 | model_H = BertModel.from_pretrained("w139700701/BERT2DAb_H") 382 | model_H.to(device) 383 | 384 | hchains = [] 385 | for i in range(0, df.shape[0]): 386 | print(i) 387 | if pd.isnull(df['a'][i]): 388 | # 没有轻链则填充空 389 | hchains.append(np.zeros(768)) 390 | else: 391 | 392 | H_chain = ast.literal_eval(df['a'][i]) 393 | if len(H_chain) <= 126: 394 | encoded_input = tokenizer_H.encode_plus( 395 | H_chain, 396 | padding=True, 397 | add_special_tokens=True, 398 | return_tensors="pt" 399 | ) 400 | # 将编码后的文本数据转换为张量并移动到设备上 401 | input_ids = encoded_input["input_ids"].to(device) 402 | # 
print(input_ids) 403 | attention_mask = encoded_input["attention_mask"].to(device) 404 | # 获取模型的输出(嵌入向量) 405 | with torch.no_grad(): 406 | outputs = model_H(input_ids, attention_mask=attention_mask) 407 | 408 | # 获取嵌入向量 409 | embeddings = outputs.last_hidden_state 410 | # print('================={}============='.format(embeddings.shape)) 411 | # print(embeddings) 412 | # print(embeddings.shape) 413 | # 嵌入的第一行与最后一行是起始子和终止子 414 | average_embeddings = torch.mean(embeddings[0], dim=0, keepdim=True) 415 | # print(average_embeddings) 416 | # print(average_embeddings.shape) 417 | 418 | else: 419 | A_chains = [H_chain[i:i + 126] for i in range(0, len(H_chain), 126)] 420 | embeds = [] 421 | for chain in A_chains: 422 | encoded_input = tokenizer_H.encode_plus( 423 | chain, 424 | padding=True, 425 | add_special_tokens=True, 426 | return_tensors="pt" 427 | ) 428 | input_ids = encoded_input["input_ids"].to(device) 429 | attention_mask = encoded_input["attention_mask"].to(device) 430 | with torch.no_grad(): 431 | outputs = model_H(input_ids, attention_mask=attention_mask) 432 | embeddings = outputs.last_hidden_state 433 | average = torch.mean(embeddings[0], dim=0, keepdim=True) 434 | embeds.append(average) 435 | average_embeddings = torch.mean(torch.stack(embeds, dim=0), dim=0) 436 | hchains.append(average_embeddings[0].detach().cpu().numpy()) 437 | # 438 | # # ============================================================================================= 439 | # 提取轻链特征 440 | lchains = [] 441 | tokenizer_L = BertTokenizer.from_pretrained("w139700701/BERT2DAb_L") 442 | model_L = BertModel.from_pretrained("w139700701/BERT2DAb_L") 443 | model_L.to(device) 444 | for i in range(0, df.shape[0]): 445 | print(i) 446 | if pd.isnull(df['b'][i]): 447 | print('=============================================') 448 | # 没有轻链则填充空 449 | lchains.append(np.zeros(768)) 450 | else: 451 | L_chian = ast.literal_eval(df['b'][i]) 452 | if len(L_chian) <= 126: 453 | encoded_input = tokenizer_L.encode_plus( 454 | L_chian, 455 | padding=True, 456 | add_special_tokens=True, 457 | return_tensors="pt" 458 | ) 459 | input_ids = encoded_input["input_ids"].to(device) 460 | attention_mask = encoded_input["attention_mask"].to(device) 461 | with torch.no_grad(): 462 | outputs = model_L(input_ids, attention_mask=attention_mask) 463 | embeddings = outputs.last_hidden_state 464 | average_embeddings = torch.mean(embeddings[0], dim=0, keepdim=True) 465 | 466 | else: 467 | A_chains = [L_chian[i:i + 126] for i in range(0, len(L_chian), 126)] 468 | embeds = [] 469 | for chain in A_chains: 470 | encoded_input = tokenizer_L.encode_plus( 471 | chain, 472 | padding=True, 473 | add_special_tokens=True, 474 | return_tensors="pt" 475 | ) 476 | input_ids = encoded_input["input_ids"].to(device) 477 | attention_mask = encoded_input["attention_mask"].to(device) 478 | with torch.no_grad(): 479 | outputs = model_L(input_ids, attention_mask=attention_mask) 480 | embeddings = outputs.last_hidden_state 481 | average = torch.mean(embeddings[0], dim=0, keepdim=True) 482 | embeds.append(average) 483 | average_embeddings = torch.mean(torch.stack(embeds, dim=0), dim=0) 484 | lchains.append(average_embeddings[0].detach().cpu().numpy()) 485 | 486 | # ============================================================================================= 487 | # 提取抗原特征 488 | tokenizer_A = BertTokenizer.from_pretrained("w139700701/BERT2DAb_H") 489 | model_A = BertModel.from_pretrained("w139700701/BERT2DAb_H") 490 | model_A.to(device) 491 | 492 | achains = [] 493 | for i in range(0, 
df.shape[0]): 494 | print(i) # 80 495 | A_chain = ast.literal_eval(df['c'][i]) 496 | 497 | if len(A_chain) <= 126: 498 | # print(len(A_chain)) 499 | 500 | encoded_input = tokenizer_A.encode_plus( 501 | A_chain, 502 | padding=True, 503 | add_special_tokens=True, 504 | return_tensors="pt" 505 | ) 506 | 507 | input_ids = encoded_input["input_ids"].to(device) 508 | # print(input_ids) 509 | # print(len(input_ids)) 510 | attention_mask = encoded_input["attention_mask"].to(device) 511 | 512 | with torch.no_grad(): 513 | outputs = model_A(input_ids, attention_mask=attention_mask) 514 | 515 | embeddings = outputs.last_hidden_state 516 | 517 | average_embeddings = torch.mean(embeddings[0], dim=0, keepdim=True) 518 | 519 | 520 | else: 521 | A_chains = [A_chain[i:i + 126] for i in range(0, len(A_chain), 126)] 522 | embeds = [] 523 | for chain in A_chains: 524 | encoded_input = tokenizer_A.encode_plus( 525 | chain, 526 | padding=True, 527 | add_special_tokens=True, 528 | return_tensors="pt" 529 | ) 530 | 531 | input_ids = encoded_input["input_ids"].to(device) 532 | 533 | attention_mask = encoded_input["attention_mask"].to(device) 534 | 535 | with torch.no_grad(): 536 | outputs = model_A(input_ids, attention_mask=attention_mask) 537 | 538 | embeddings = outputs.last_hidden_state 539 | 540 | average = torch.mean(embeddings[0], dim=0, keepdim=True) 541 | embeds.append(average) 542 | average_embeddings = torch.mean(torch.stack(embeds, dim=0), dim=0) 543 | 544 | # print(average_embeddings) 545 | achains.append(average_embeddings[0].detach().cpu().numpy()) 546 | # ============================================================================================= 547 | 548 | features = np.concatenate((hchains, lchains, achains), axis=1) 549 | print(features.shape) 550 | 551 | current_path = Path.cwd() 552 | embedding_path = current_path.joinpath('wild_embeddings') 553 | embedding_file = embedding_path / 'BERT2DAb_features810.npy' 554 | affinity_file = embedding_path / 'affinity_89.npy' 555 | affinity = df['delta_g'].values 556 | np.save(str(affinity_file), affinity) 557 | np.save(str(embedding_file), np.array(features)) 558 | 559 | 560 | def compare_different_embeddings(): 561 | current_path = Path.cwd() 562 | embedding_path = current_path.joinpath('alphaseq_embeddings') 563 | 564 | embedding_paths = [ 565 | # 'ESM2_alphaseq.npy', 566 | 'ProtTrans_alphaseq.npy', 567 | 'AbLang_alphaseq.npy', 568 | # 'AntiBERTy_alphaseq.npy', 569 | # 'BERT2DAb_features.npy' 570 | ] 571 | 572 | original_affinity = embedding_path / 'affinity.npy' 573 | original_affinity = np.load(str(original_affinity)) 574 | 575 | for e in embedding_paths: 576 | e = embedding_path / e 577 | # 读取特征 578 | embedding = np.load(str(e)) 579 | print(embedding.shape) 580 | pca = PCA(n_components=0.99) 581 | embedding = pca.fit_transform(X=embedding) 582 | 583 | mae_scores = [] 584 | rmse_scores = [] 585 | pearson_scores = [] 586 | 587 | for _ in range(1): 588 | kf = KFold(n_splits=10, shuffle=True, random_state=42) 589 | for fold_idx, (train_idx, test_idx) in enumerate(kf.split(range(embedding.shape[0]))): 590 | # print('fold{}'.format(fold_idx+1)) 591 | model = RandomForestRegressor(n_estimators=50, n_jobs=multiprocessing.cpu_count(), max_depth=15, 592 | random_state=42) 593 | X_train, X_test = embedding[train_idx], embedding[test_idx] 594 | y_train, y_test = original_affinity[train_idx], original_affinity[test_idx] 595 | 596 | model.fit(X_train, y_train) 597 | y_pred_fold = model.predict(X_test) 598 | 599 | mae_scores.append(mean_absolute_error(y_test, 
y_pred_fold))
600 | rmse_scores.append(np.sqrt(mean_squared_error(y_test, y_pred_fold)))
601 | pearson_scores.append(pearsonr(y_test, y_pred_fold)[0])
602 |
603 | mae_mean = np.mean(mae_scores)
604 | mae_std = np.std(mae_scores)
605 |
606 | rmse_mean = np.mean(rmse_scores)
607 | rmse_std = np.std(rmse_scores)
608 |
609 | pearson_mean = np.mean(pearson_scores)
610 | pearson_std = np.std(pearson_scores)
611 |
612 | # Print the evaluation results
613 | print(str(e)[:-4])
614 | print("RMSE: {}±{}".format(round(rmse_mean, 4), round(rmse_std, 4)))
615 | print("MAE: {}±{}".format(round(mae_mean, 4), round(mae_std, 4)))
616 | print("PCC: {}±{}".format(round(pearson_mean, 4), round(pearson_std, 4)))
617 |
618 |
619 | if __name__ == "__main__":
620 | # extract_from_ProtTrans()
621 | # extract_from_ESM2_650M()
622 | # extract_from_AbLang()
623 | # extract_from_AntiBERTy()
624 | # extract_from_BERT2DAb()
625 |
626 | compare_different_embeddings()
627 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | ⚠️ Project under development
2 |
3 | # MuLAAIP: Multi-Modality Representation Learning for Antibody-Antigen Interaction Prediction
4 |
5 | ![MuLAAIP Pipeline](pipline.png)
6 |
7 | This repository contains the official implementation of **MuLAAIP**, a novel deep learning framework for predicting antibody-antigen interactions (AAI) by integrating **3D structural** and **1D sequence** data. Our approach addresses critical challenges in AAI prediction, including structural data scarcity, sequence-structure dependency modeling, and imbalanced label distributions.
8 |
9 | ## 📁 Benchmark Datasets
10 |
11 | ### Dataset Summary
12 | | Dataset | Type | Samples | Description |
13 | |--------|------|---------|-------------|
14 | | **Wild-type/Mutant-type Affinity** | Affinity Labeling | 1,191 / 1,742 pairs | Antibody-antigen binding affinity |
15 | | **Alphaseq** | Affinity Labeling | 248k antibodies | Antibody-antigen binding affinity |
16 | | **SARS-CoV-2 Neutralization** | Binary Classification | 310 pairs (228+/82-) | Neutralization activity labels |
17 |
18 | > All missing experimental structures were predicted using **ESMFold** (https://github.com/facebookresearch/esm).
19 |
20 | ## 📥 Data Acquisition
21 |
22 | ### Download Instructions
23 | 1. **Get Data**:
24 | [Baidu Cloud Link (Password: iuqs)](https://pan.baidu.com/s/1HqXfAUIjGp6h1gh3M2Pa8Q)
25 |
26 |
27 | ## Installation
28 | ```bash
29 | # Clone the repo
30 | git clone https://github.com/trashTian/MuLAAIP.git
31 | cd MuLAAIP
32 |
33 | # Install dependencies
34 | pip install -r requirements.txt
35 | ```
36 | ## Data pre-processing
37 | (1) **1D Sequence Representation: use pre-trained protein (antibody) language models to process the sequence data and obtain embeddings, for example ProtTrans, ESM2, AbLang, AntiBERTy, and BERT2DAb.**
38 | ```
39 | python PLM.py
40 | ```
41 | > The embeddings used in our experiments have already been computed and saved locally
42 |
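Each extractor in `PLM.py` writes the embeddings and the matching labels as NumPy arrays. As a quick sanity check they can be loaded directly; the snippet below is a minimal sketch, not part of the original scripts, and the file names are the ones `extract_from_ProtTrans` writes for the mutant-type set, so adjust them to the extractor and dataset you actually ran:

```python
import numpy as np

# Per-pair sequence embeddings (antibody embedding concatenated with antigen embedding)
features = np.load('mutate_embeddings/ProtTrans_mutate_726.npy')
# Binding-affinity labels (delta_g), in the same row order as the input CSV
affinity = np.load('mutate_embeddings/affinity_726.npy')
print(features.shape, affinity.shape)
```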
43 | (2) **3D Structural Representation: construct the fine-grained structural graphs.**
44 | ```
45 | python Dataset.py
46 | ```
47 |
48 | ## Cross-validation
49 | ```
50 | python train.py
51 | ```
52 |
53 |
54 | ## Cite this work
55 | ```
56 | @article{guo2025multi,
57 | title={Multi-Modality Representation Learning for Antibody-Antigen Interactions Prediction},
58 | author={Guo, Peijin and Li, Minghui and Pan, Hewen and Huang, Ruixiang and Xue, Lulu and Hu, Shengqing and Guo, Zikang and Wan, Wei and Hu, Shengshan},
59 | journal={arXiv preprint arXiv:2503.17666},
60 | year={2025}
61 | }
62 | ```
63 |
64 |
65 |
-------------------------------------------------------------------------------- /module.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 |
3 | from torch_geometric.nn import inits, MessagePassing, SAGPooling
4 | from torch_geometric.nn import radius_graph
5 | from features import d_angle_emb, d_theta_phi_emb
6 | from torch_scatter import scatter
7 | from torch_sparse import matmul
8 | from torch_geometric.nn import TransformerConv, GATConv, GATv2Conv
9 | import torch
10 | from torch import nn
11 | from torch.nn import Embedding
12 | import torch.nn.functional as F
13 | from torch import nn
14 | from torch.nn.parameter import Parameter
15 | import torch
16 | import torch.nn.functional as F
17 |
18 | import numpy as np
19 |
20 | num_aa_type = 20
21 | num_side_chain_embs = 8
22 | num_bb_embs = 6
23 |
24 |
25 | def swish(x):
26 | return x * torch.sigmoid(x)
27 |
28 |
29 | def mish(x):
30 | return x * torch.tanh(F.softplus(x))
31 |
32 |
33 | class Linear(torch.nn.Module):
34 |
35 | def __init__(self, in_channels, out_channels, bias=True, weight_initializer='glorot'):
36 |
37 | super().__init__()
38 | self.in_channels = in_channels
39 | self.out_channels = out_channels
40 | self.weight_initializer = weight_initializer
41 |
42 | self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels))
43 |
44 | if bias:
45 | self.bias = nn.Parameter(torch.Tensor(out_channels))
46 | else:
47 | self.register_parameter('bias', None)
48 |
49 | self.reset_parameters()
50 |
51 | def reset_parameters(self):
52 | if self.weight_initializer == 'glorot':
53 | inits.glorot(self.weight)
54 | elif self.weight_initializer == 'zeros':
55 | inits.zeros(self.weight)
56 | if self.bias is not None:
57 | inits.zeros(self.bias)
58 |
59 | def forward(self, x):
60 | """"""
61 | return F.linear(x, self.weight, self.bias)
62 |
63 |
64 | class TwoLinear(torch.nn.Module):
65 |
66 | def __init__(
67 | self,
68 | in_channels,
69 | middle_channels,
70 | out_channels,
71 | bias=False,
72 | act=False
73 | ):
74 | super(TwoLinear, self).__init__()
75 | self.lin1 = Linear(in_channels, middle_channels, bias=bias)
76 | self.lin2 = Linear(middle_channels, out_channels, bias=bias)
77 | self.act = act
78 |
79 | def reset_parameters(self):
80 | self.lin1.reset_parameters()
81 | self.lin2.reset_parameters()
82 |
83 | def forward(self, x):
84 | x = self.lin1(x)
85 | if self.act:
86 | x = swish(x)
87 | x = self.lin2(x)
88 | if self.act:
89 | x = swish(x)
90 | return x
91 |
92 |
93 | class InteractionBlock(torch.nn.Module):
94 | def __init__(
95 | self,
96 | hidden_channels,
97 | output_channels,
98 | num_radial,
99 | num_spherical,
100 | num_layers,
101 | mid_emb,
102 | act=swish,
103 | # act=mish,
104 | num_pos_emb=16,
105 | dropout=0, 106 | level='allatom' 107 | ): 108 | super(InteractionBlock, self).__init__() 109 | self.act = act 110 | self.dropout = nn.Dropout(dropout) 111 | 112 | self.conv0 = GATConv(hidden_channels, hidden_channels) 113 | self.conv1 = GATConv(hidden_channels, hidden_channels) 114 | self.conv2 = GATConv(hidden_channels, hidden_channels) 115 | 116 | self.lin_feature0 = TwoLinear(num_radial * num_spherical ** 2, mid_emb, hidden_channels) 117 | if level == 'aminoacid': 118 | self.lin_feature1 = TwoLinear(num_radial * num_spherical, mid_emb, hidden_channels) 119 | elif level == 'backbone' or level == 'allatom': 120 | self.lin_feature1 = TwoLinear(3 * num_radial * num_spherical, mid_emb, hidden_channels) 121 | self.lin_feature2 = TwoLinear(num_pos_emb, mid_emb, hidden_channels) 122 | 123 | self.lin_1 = Linear(hidden_channels, hidden_channels) 124 | self.lin_2 = Linear(hidden_channels, hidden_channels) 125 | 126 | self.lin0 = Linear(hidden_channels, hidden_channels) 127 | self.lin1 = Linear(hidden_channels, hidden_channels) 128 | self.lin2 = Linear(hidden_channels, hidden_channels) 129 | 130 | self.lins_cat = torch.nn.ModuleList() 131 | self.lins_cat.append(Linear(3 * hidden_channels, hidden_channels)) 132 | for _ in range(num_layers - 1): 133 | self.lins_cat.append(Linear(hidden_channels, hidden_channels)) 134 | 135 | self.lins = torch.nn.ModuleList() 136 | for _ in range(num_layers - 1): 137 | self.lins.append(Linear(hidden_channels, hidden_channels)) 138 | self.final = Linear(hidden_channels, output_channels) 139 | self.bn = nn.BatchNorm1d(hidden_channels) 140 | self.bn0 = nn.BatchNorm1d(24) 141 | 142 | # all-atom 143 | self.bn1 = nn.BatchNorm1d(36) 144 | # amino 145 | # self.bn1 = nn.BatchNorm1d(12) 146 | 147 | self.bn_pos_emb = nn.BatchNorm1d(16) 148 | self.bn2 = nn.BatchNorm1d(128) 149 | self.bn3 = nn.BatchNorm1d(128) 150 | self.bn4 = nn.BatchNorm1d(128) 151 | self.reset_parameters() 152 | 153 | def reset_parameters(self): 154 | self.conv0.reset_parameters() 155 | self.conv1.reset_parameters() 156 | self.conv2.reset_parameters() 157 | 158 | self.lin_feature0.reset_parameters() 159 | self.lin_feature1.reset_parameters() 160 | self.lin_feature2.reset_parameters() 161 | 162 | self.lin_1.reset_parameters() 163 | self.lin_2.reset_parameters() 164 | 165 | self.lin0.reset_parameters() 166 | self.lin1.reset_parameters() 167 | self.lin2.reset_parameters() 168 | 169 | for lin in self.lins: 170 | lin.reset_parameters() 171 | for lin in self.lins_cat: 172 | lin.reset_parameters() 173 | 174 | self.final.reset_parameters() 175 | 176 | def forward(self, x, feature0, feature1, pos_emb, edge_index, batch): 177 | x = self.bn(x) 178 | feature0 = self.bn0(feature0) 179 | feature1 = self.bn1(feature1) 180 | 181 | x_lin_1 = self.act(self.lin_1(x)) 182 | x_lin_2 = self.act(self.lin_2(x)) 183 | 184 | feature0 = self.lin_feature0(feature0) 185 | h0 = self.conv0(x_lin_1, edge_index, feature0) 186 | h0 = self.lin0(h0) 187 | h0 = self.act(h0) 188 | h0 = self.bn2(h0) 189 | h0 = self.dropout(h0) 190 | 191 | feature1 = self.lin_feature1(feature1) 192 | h1 = self.conv1(x_lin_1, edge_index, feature1) 193 | h1 = self.lin1(h1) 194 | h1 = self.act(h1) 195 | h1 = self.bn3(h1) 196 | h1 = self.dropout(h1) 197 | 198 | feature2 = self.lin_feature2(pos_emb) 199 | h2 = self.conv2(x_lin_1, edge_index, feature2) 200 | h2 = self.lin2(h2) 201 | h2 = self.act(h2) 202 | h2 = self.bn4(h2) 203 | h2 = self.dropout(h2) 204 | 205 | h = torch.cat((h0, h1, h2), 1) 206 | for lin in self.lins_cat: 207 | h = self.act(lin(h)) 208 | 209 | h = 
h + x_lin_2 210 | 211 | for lin in self.lins: 212 | h = self.act(lin(h)) 213 | h = self.final(h) 214 | return h 215 | 216 | 217 | class PairNorm(nn.Module): 218 | def __init__(self, mode='PN', scale=1): 219 | assert mode in ['None', 'PN', 'PN-SI', 'PN-SCS'] 220 | super(PairNorm, self).__init__() 221 | self.mode = mode 222 | self.scale = scale 223 | 224 | def forward(self, x): 225 | if self.mode == 'None': 226 | return x 227 | 228 | col_mean = x.mean(dim=0) 229 | if self.mode == 'PN': 230 | x = x - col_mean 231 | rownorm_mean = (1e-6 + x.pow(2).sum(dim=1).mean()).sqrt() 232 | x = self.scale * x / rownorm_mean 233 | 234 | if self.mode == 'PN-SI': 235 | x = x - col_mean 236 | rownorm_individual = (1e-6 + x.pow(2).sum(dim=1, keepdim=True)).sqrt() 237 | x = self.scale * x / rownorm_individual 238 | 239 | if self.mode == 'PN-SCS': 240 | rownorm_individual = (1e-6 + x.pow(2).sum(dim=1, keepdim=True)).sqrt() 241 | x = self.scale * x / rownorm_individual - col_mean 242 | 243 | return x 244 | 245 | 246 | def get_degree_mat(adj_mat, pow=1, degree_version='v1'): 247 | degree_mat = torch.eye(adj_mat.size()[0]).to(adj_mat.device) 248 | 249 | if degree_version == 'v1': 250 | degree_list = torch.sum((adj_mat > 0), dim=1).float() 251 | elif degree_version == 'v2': 252 | adj_mat_hat = F.relu(adj_mat) 253 | degree_list = torch.sum(adj_mat_hat, dim=1).float() 254 | elif degree_version == 'v3': 255 | degree_list = torch.sum(adj_mat, dim=1).float() 256 | degree_list = F.relu(degree_list) 257 | else: 258 | exit('error degree_version ' + degree_version) 259 | degree_list = torch.pow(degree_list, pow) 260 | degree_mat = degree_mat * degree_list 261 | return degree_mat 262 | 263 | 264 | def get_laplace_mat(adj_mat, type='sym', add_i=False, degree_version='v2'): 265 | if type == 'sym': 266 | # Symmetric normalized Laplacian 267 | if add_i is True: 268 | adj_mat_hat = torch.eye(adj_mat.size()[0]).to(adj_mat.device) + adj_mat 269 | else: 270 | adj_mat_hat = adj_mat 271 | # adj_mat_hat = adj_mat_hat[adj_mat_hat > 0] 272 | degree_mat_hat = get_degree_mat(adj_mat_hat, pow=-0.5, degree_version=degree_version) 273 | # print(degree_mat_hat.dtype, adj_mat_hat.dtype) 274 | laplace_mat = torch.mm(degree_mat_hat, adj_mat_hat) 275 | # print(laplace_mat) 276 | laplace_mat = torch.mm(laplace_mat, degree_mat_hat) 277 | return laplace_mat 278 | elif type == 'rw': 279 | # Random walk normalized Laplacian 280 | adj_mat_hat = torch.eye(adj_mat.size()[0]).to(adj_mat.device) + adj_mat 281 | degree_mat_hat = get_degree_mat(adj_mat_hat, pow=-1) 282 | laplace_mat = torch.mm(degree_mat_hat, adj_mat_hat) 283 | return laplace_mat 284 | 285 | 286 | class GCNConvL(nn.Module): 287 | def __init__(self, 288 | in_channels, 289 | out_channels, 290 | improved=False, 291 | dropout=0.6, 292 | bias=True 293 | ): 294 | super(GCNConvL, self).__init__() 295 | self.in_channels = in_channels 296 | self.out_channels = out_channels 297 | self.dropout = dropout 298 | self.bias = bias 299 | self.weight = Parameter( 300 | torch.Tensor(in_channels, out_channels) 301 | ) 302 | nn.init.xavier_normal_(self.weight) 303 | if bias is True: 304 | self.bias = Parameter(torch.Tensor(out_channels)) 305 | nn.init.zeros_(self.bias) 306 | 307 | def forward(self, node_ft, adj_mat): 308 | laplace_mat = get_laplace_mat(adj_mat, type='sym') 309 | node_state = torch.mm(laplace_mat, node_ft) 310 | node_state = torch.mm(node_state, self.weight) 311 | if self.bias is not None: 312 | node_state = node_state + self.bias 313 | 314 | return node_state 315 | 316 | 317 | class 
StructureGAT(nn.Module): 318 | 319 | def __init__( 320 | self, 321 | level='aminoacid', 322 | num_blocks=4, 323 | hidden_channels=128, 324 | out_channels=1, 325 | mid_emb=64, 326 | num_radial=6, 327 | num_spherical=2, 328 | cutoff=10.0, 329 | max_num_neighbors=32, 330 | int_emb_layers=3, 331 | out_layers=2, 332 | num_pos_emb=16, 333 | dropout=0, 334 | data_augment_eachlayer=False, 335 | euler_noise=False, 336 | ): 337 | super(StructureGAT, self).__init__() 338 | self.cutoff = cutoff 339 | self.max_num_neighbors = max_num_neighbors 340 | self.num_pos_emb = num_pos_emb 341 | self.data_augment_eachlayer = data_augment_eachlayer 342 | self.euler_noise = euler_noise 343 | self.level = level 344 | self.act = swish 345 | 346 | self.feature0 = d_theta_phi_emb(num_radial=num_radial, num_spherical=num_spherical, cutoff=cutoff) 347 | self.feature1 = d_angle_emb(num_radial=num_radial, num_spherical=num_spherical, cutoff=cutoff) 348 | 349 | if level == 'allatom': 350 | self.embedding = torch.nn.Linear(num_aa_type + num_bb_embs + num_side_chain_embs, hidden_channels) 351 | else: 352 | print('No supported model!') 353 | 354 | self.interaction_blocks = torch.nn.ModuleList( 355 | [ 356 | InteractionBlock( 357 | hidden_channels=hidden_channels, 358 | output_channels=hidden_channels, 359 | num_radial=num_radial, 360 | num_spherical=num_spherical, 361 | num_layers=int_emb_layers, 362 | mid_emb=mid_emb, 363 | act=self.act, 364 | num_pos_emb=num_pos_emb, 365 | dropout=dropout, 366 | level=level 367 | ) 368 | for _ in range(num_blocks) 369 | ] 370 | ) 371 | 372 | self.lins_out = torch.nn.ModuleList() 373 | for _ in range(out_layers - 1): 374 | self.lins_out.append(Linear(hidden_channels, hidden_channels)) 375 | self.lin_out = Linear(hidden_channels, out_channels) 376 | 377 | self.relu = nn.ReLU() 378 | self.dropout = nn.Dropout(dropout) 379 | 380 | self.reset_parameters() 381 | 382 | def reset_parameters(self): 383 | self.embedding.reset_parameters() 384 | for interaction in self.interaction_blocks: 385 | interaction.reset_parameters() 386 | for lin in self.lins_out: 387 | lin.reset_parameters() 388 | self.lin_out.reset_parameters() 389 | 390 | def pos_emb(self, edge_index, num_pos_emb=16): 391 | d = edge_index[0] - edge_index[1] 392 | 393 | frequency = torch.exp( 394 | torch.arange(0, num_pos_emb, 2, dtype=torch.float32, device=edge_index.device) 395 | * -(np.log(10000.0) / num_pos_emb) 396 | ) 397 | angles = d.unsqueeze(-1) * frequency 398 | E = torch.cat((torch.cos(angles), torch.sin(angles)), -1) 399 | return E 400 | 401 | def forward(self, batch_data): 402 | 403 | z, pos, batch = torch.squeeze(batch_data.x.long()), batch_data.coords_ca, batch_data.batch 404 | pos_n = batch_data.coords_n 405 | pos_c = batch_data.coords_c 406 | bb_embs = batch_data.bb_embs 407 | side_chain_embs = batch_data.side_chain_embs 408 | device = z.device 409 | 410 | if self.level == 'allatom': 411 | x = torch.cat([torch.squeeze(F.one_hot(z, num_classes=num_aa_type).float()), bb_embs, side_chain_embs], 412 | dim=1) 413 | x = self.embedding(x) 414 | else: 415 | print('No supported model!') 416 | 417 | edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=self.max_num_neighbors) 418 | pos_emb = self.pos_emb(edge_index, self.num_pos_emb) 419 | j, i = edge_index 420 | 421 | dist = (pos[i] - pos[j]).norm(dim=1) 422 | 423 | num_nodes = len(z) 424 | 425 | # Calculate angles theta and phi. 
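        # Geometry of the edge features built just below: for each edge j -> i on the
        # CA radius graph, refi0 = (i - 1) % num_nodes and refi1 = (i + 1) % num_nodes are
        # the sequential neighbours of residue i. `theta` is the angle between the edge
        # vector pos[j] - pos[i] and the reference direction pos[refi0] - pos[i],
        # recovered as atan2(|cross|, dot); `phi` is the signed dihedral about the
        # pos[refi0] - pos[i] axis between the half-plane through refi1 and the
        # half-plane through j. Together with `dist`, these feed d_theta_phi_emb as feature0.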
426 | refi0 = (i - 1) % num_nodes 427 | refi1 = (i + 1) % num_nodes 428 | 429 | a = ((pos[j] - pos[i]) * (pos[refi0] - pos[i])).sum(dim=-1) 430 | b = torch.cross(pos[j] - pos[i], pos[refi0] - pos[i]).norm(dim=-1) 431 | theta = torch.atan2(b, a) 432 | 433 | plane1 = torch.cross(pos[refi0] - pos[i], pos[refi1] - pos[i]) 434 | plane2 = torch.cross(pos[refi0] - pos[i], pos[j] - pos[i]) 435 | a = (plane1 * plane2).sum(dim=-1) 436 | b = (torch.cross(plane1, plane2) * (pos[refi0] - pos[i])).sum(dim=-1) / ((pos[refi0] - pos[i]).norm(dim=-1)) 437 | phi = torch.atan2(b, a) 438 | 439 | feature0 = self.feature0(dist, theta, phi) 440 | 441 | if self.level == 'allatom': 442 | # Calculate Euler angles. 443 | Or1_x = pos_n[i] - pos[i] 444 | Or1_z = torch.cross(Or1_x, torch.cross(Or1_x, pos_c[i] - pos[i])) 445 | Or1_z_length = Or1_z.norm(dim=1) + 1e-7 446 | 447 | Or2_x = pos_n[j] - pos[j] 448 | Or2_z = torch.cross(Or2_x, torch.cross(Or2_x, pos_c[j] - pos[j])) 449 | Or2_z_length = Or2_z.norm(dim=1) + 1e-7 450 | 451 | Or1_Or2_N = torch.cross(Or1_z, Or2_z) 452 | 453 | angle1 = torch.atan2((torch.cross(Or1_x, Or1_Or2_N) * Or1_z).sum(dim=-1) / Or1_z_length, 454 | (Or1_x * Or1_Or2_N).sum(dim=-1)) 455 | angle2 = torch.atan2(torch.cross(Or1_z, Or2_z).norm(dim=-1), (Or1_z * Or2_z).sum(dim=-1)) 456 | angle3 = torch.atan2((torch.cross(Or1_Or2_N, Or2_x) * Or2_z).sum(dim=-1) / Or2_z_length, 457 | (Or1_Or2_N * Or2_x).sum(dim=-1)) 458 | 459 | if self.euler_noise: 460 | euler_noise = torch.clip(torch.empty(3, len(angle1)).to(device).normal_(mean=0.0, std=0.025), min=-0.1, 461 | max=0.1) 462 | angle1 += euler_noise[0] 463 | angle2 += euler_noise[1] 464 | angle3 += euler_noise[2] 465 | 466 | feature1 = torch.cat( 467 | (self.feature1(dist, angle1), self.feature1(dist, angle2), self.feature1(dist, angle3)), 1) 468 | 469 | 470 | for interaction_block in self.interaction_blocks: 471 | if self.data_augment_eachlayer: 472 | gaussian_noise = torch.clip(torch.empty(x.shape).to(device).normal_(mean=0.0, std=0.025), min=-0.1, 473 | max=0.1) 474 | x += gaussian_noise 475 | x = interaction_block(x, feature0, feature1, pos_emb, edge_index, batch) 476 | 477 | y = scatter(x, batch, dim=0) 478 | 479 | return y 480 | 481 | @property 482 | def num_params(self): 483 | return sum(p.numel() for p in self.parameters()) 484 | 485 | class MuLAAIP(nn.Module): 486 | def __init__(self, norm_mode='PN', norm_scale=1): 487 | super(MuLAAIP, self).__init__() 488 | struct_hidden = 128 489 | seq_hidden = 1280 490 | self.dropout_rate = 0 491 | self.ab_struct_model = StructureGAT(num_blocks=4, hidden_channels=128, cutoff=10, level='allatom', dropout=0.1) 492 | self.ag_struct_model = StructureGAT(num_blocks=4, hidden_channels=128, cutoff=10, level='allatom', dropout=0) 493 | 494 | self.fc_ab = nn.Linear(seq_hidden, seq_hidden) 495 | self.fc_ag = nn.Linear(seq_hidden, seq_hidden) 496 | self.bn1 = nn.BatchNorm1d(seq_hidden) 497 | self.bn2 = nn.BatchNorm1d(seq_hidden) 498 | self.bn3 = nn.BatchNorm1d(struct_hidden) 499 | self.bn4 = nn.BatchNorm1d(struct_hidden) 500 | self.bn_struct = nn.BatchNorm1d(2*struct_hidden) 501 | self.bn_seq = nn.BatchNorm1d(2*seq_hidden) 502 | 503 | self.activation = nn.ReLU() 504 | 505 | hidden_size_combine = struct_hidden + seq_hidden 506 | self.norm = PairNorm(mode=norm_mode, scale=norm_scale) 507 | self.ab_gcn1 = GCNConvL(in_channels=seq_hidden, out_channels=seq_hidden) 508 | self.ag_gcn1 = GCNConvL(in_channels=seq_hidden, out_channels=seq_hidden) 509 | 510 | self.ab_gcn2 = GCNConvL(in_channels=seq_hidden, out_channels=seq_hidden) 
511 | self.ag_gcn2 = GCNConvL(in_channels=seq_hidden, out_channels=seq_hidden) 512 | 513 | self.ab_gcn3 = GCNConvL(in_channels=seq_hidden, out_channels=seq_hidden) 514 | self.ag_gcn3 = GCNConvL(in_channels=seq_hidden, out_channels=seq_hidden) 515 | 516 | self.mlp = nn.Sequential( 517 | nn.Linear(2 * hidden_size_combine, hidden_size_combine), 518 | nn.BatchNorm1d(hidden_size_combine), 519 | nn.ReLU(), 520 | nn.Linear(hidden_size_combine, 320), 521 | nn.BatchNorm1d(320), 522 | nn.ReLU(), 523 | nn.Linear(320, 160), 524 | nn.BatchNorm1d(160), 525 | nn.ReLU(), 526 | nn.Linear(160, 32), 527 | nn.BatchNorm1d(32), 528 | nn.ReLU(), 529 | nn.Linear(32, 1) 530 | ) 531 | 532 | def forward(self, batch_ab, batch_ag, seq_emb_ab, seq_emb_ag): 533 | 534 | struct_ab_emb = self.ab_struct_model(batch_ab) 535 | struct_ag_emb = self.ag_struct_model(batch_ag) 536 | combined_embedding = torch.cat((struct_ab_emb, struct_ag_emb), dim=1) 537 | 538 | # bach norm 539 | s_ab = self.bn3(struct_ab_emb) 540 | s_ag = self.bn4(struct_ag_emb) 541 | combined_s = torch.cat((s_ab, s_ag), dim=1) 542 | combined = combined_s + combined_embedding 543 | global_struct_feature = self.activation(combined) 544 | global_struct_feature = self.bn_struct(global_struct_feature) 545 | 546 | ab_in = seq_emb_ab 547 | ag_in = seq_emb_ag 548 | ab_0 = self.bn1(ab_in) 549 | ab_0 = self.activation(ab_0) 550 | ab_0 = self.fc_ab(ab_0) 551 | ab_0 = F.dropout(ab_0, p=self.dropout_rate) 552 | ab_0 = ab_in + ab_0 553 | 554 | w_ab = torch.norm(ab_0, p=2, dim=-1).view(-1, 1) 555 | w_mat_ab = w_ab * w_ab.t() 556 | ab_adj = torch.mm(ab_0, ab_0.t()) / w_mat_ab 557 | ab_1 = self.ab_gcn1(ab_0, ab_adj) 558 | ab_1 = self.norm(ab_1) 559 | ab_1 = ab_0 + ab_1 560 | 561 | ab_2 = self.activation(ab_1) 562 | ab_2 = F.dropout(ab_2, p=self.dropout_rate) 563 | ab_2 = self.ab_gcn2(ab_2, ab_adj) 564 | ab_2 = self.norm(ab_2) 565 | ab_2 = ab_2 + ab_1 566 | 567 | ab_3 = self.activation(ab_2) 568 | ab_3 = F.dropout(ab_3, p=self.dropout_rate) 569 | ab_3 = self.ab_gcn3(ab_3, ab_adj) 570 | ab_3 = self.norm(ab_3) 571 | ab_3 = ab_3 + ab_2 572 | 573 | ag_0 = self.bn2(ag_in) 574 | ag_0 = self.activation(ag_0) 575 | ag_0 = self.fc_ag(ag_0) 576 | ag_0 = F.dropout(ag_0, p=self.dropout_rate) 577 | ag_0 = ag_in + ag_0 578 | 579 | w_ag = torch.norm(ag_0, p=2, dim=-1).view(-1, 1) 580 | w_mat_ag = w_ag * w_ag.t() 581 | ag_adj = torch.mm(ag_0, ag_0.t()) / w_mat_ag 582 | ag_1 = self.ag_gcn1(ag_0, ag_adj) 583 | ag_1 = self.norm(ag_1) 584 | ag_1 = ag_0 + ag_1 585 | 586 | ag_2 = self.activation(ag_1) 587 | ag_2 = F.dropout(ag_2, p=self.dropout_rate) 588 | ag_2 = self.ag_gcn2(ag_2, ag_adj) 589 | ag_2 = self.norm(ag_2) 590 | ag_2 = ag_2 + ag_1 591 | 592 | ag_3 = self.activation(ag_2) 593 | ag_3 = F.dropout(ag_3, p=self.dropout_rate) 594 | ag_3 = self.ag_gcn3(ag_3, ag_adj) 595 | ag_3 = self.norm(ag_3) 596 | ag_3 = ag_3 + ag_2 597 | 598 | x_3 = torch.cat((ab_3, ag_3), dim=1) 599 | x_2 = torch.cat((ab_2, ag_2), dim=1) 600 | x_1 = torch.cat((ab_1, ag_1), dim=1) 601 | 602 | global_seq_feature = x_1 + x_2 + x_3 + torch.cat((ab_in, ag_in), dim=1) 603 | global_seq_feature = self.activation(global_seq_feature) 604 | global_seq_feature = self.bn_seq(global_seq_feature) 605 | 606 | output = self.mlp(torch.cat((global_seq_feature, global_struct_feature), dim=1)) 607 | 608 | return output, ab_adj, ag_adj -------------------------------------------------------------------------------- /pipline.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/trashTian/MuLAAIP/4472c23b414b05d9541fed7fa6b5b696ddc634c9/pipline.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aaindex==1.1.2 2 | ablang2==0.1.0 3 | absl-py==2.0.0 4 | accelerate==0.31.0 5 | aiohttp==3.9.5 6 | aiosignal==1.3.1 7 | antiberty==0.1.3 8 | anyio==4.1.0 9 | appdirs==1.4.4 10 | argon2-cffi==23.1.0 11 | argon2-cffi-bindings==21.2.0 12 | arrow==1.3.0 13 | asttokens==2.4.1 14 | astunparse==1.6.3 15 | async-lru==2.0.4 16 | async-timeout==4.0.3 17 | attrs==23.1.0 18 | Babel==2.13.1 19 | beartype==0.17.2 20 | beautifulsoup4==4.12.2 21 | biopandas==0.5.1.dev0 22 | biopython==1.83 23 | bioservices==1.11.2 24 | bleach==6.1.0 25 | boto3==1.34.61 26 | botocore==1.34.61 27 | cachetools==5.3.2 28 | cattrs==23.2.3 29 | certifi==2023.11.17 30 | cffi==1.16.0 31 | charset-normalizer==3.3.2 32 | click==8.1.7 33 | colorama==0.4.6 34 | colorlog==6.8.0 35 | comm==0.2.0 36 | contourpy==1.2.0 37 | cycler==0.12.1 38 | datasets==2.20.0 39 | debugpy==1.8.0 40 | decorator==5.1.1 41 | deepdiff==6.7.1 42 | defusedxml==0.7.1 43 | dill==0.3.8 44 | docker-pycreds==0.4.0 45 | easydev==0.12.1 46 | einops==0.7.0 47 | et-xmlfile==1.1.0 48 | evaluate==0.4.2 49 | exceptiongroup==1.2.0 50 | executing==2.0.1 51 | fair-esm==2.0.0 52 | fasteners==0.19 53 | fastjsonschema==2.19.0 54 | filelock==3.9.0 55 | flatbuffers==23.5.26 56 | fonttools==4.46.0 57 | fqdn==1.5.1 58 | frozenlist==1.4.1 59 | fsspec==2023.12.0 60 | gast==0.5.4 61 | gevent==23.9.1 62 | gitdb==4.0.11 63 | GitPython==3.1.40 64 | google-auth==2.24.0 65 | google-auth-oauthlib==1.1.0 66 | google-pasta==0.2.0 67 | graphein==1.7.5 68 | graphviz==0.20.1 69 | greenlet==3.0.1 70 | grequests==0.7.0 71 | GridDataFormats==1.0.2 72 | grpcio==1.59.3 73 | h5py==3.10.0 74 | huggingface-hub==0.23.4 75 | idna==3.6 76 | imbalanced-learn==0.11.0 77 | importlib-metadata==7.0.0 78 | importlib-resources==6.1.1 79 | ipykernel==6.27.1 80 | ipython==8.18.1 81 | ipython-genutils==0.2.0 82 | ipywidgets==7.6.5 83 | isoduration==20.11.0 84 | jaxtyping==0.2.24 85 | jedi==0.19.1 86 | Jinja2==3.1.2 87 | jmespath==1.0.1 88 | joblib==1.3.2 89 | json5==0.9.14 90 | jsonpointer==2.4 91 | jsonschema==4.20.0 92 | jsonschema-specifications==2023.11.2 93 | jupyter==1.0.0 94 | jupyter_client==8.6.0 95 | jupyter-console==6.6.3 96 | jupyter_core==5.5.0 97 | jupyter-events==0.9.0 98 | jupyter-lsp==2.2.1 99 | jupyter_server==2.11.1 100 | jupyter_server_terminals==0.4.4 101 | jupyterlab==4.0.9 102 | jupyterlab_pygments==0.3.0 103 | jupyterlab_server==2.25.2 104 | jupyterlab-widgets==3.0.9 105 | keras==2.15.0 106 | Keras-Preprocessing==1.1.2 107 | kiwisolver==1.4.5 108 | libclang==16.0.6 109 | lightning-utilities==0.10.0 110 | lmdb==1.4.1 111 | loguru==0.7.2 112 | looseversion==1.1.2 113 | lxml==4.9.3 114 | Markdown==3.5.1 115 | markdown-it-py==3.0.0 116 | MarkupSafe==2.1.3 117 | matplotlib==3.8.2 118 | matplotlib-inline==0.1.6 119 | MDAnalysis==2.6.1 120 | mdurl==0.1.2 121 | mistune==3.0.2 122 | ml-dtypes==0.2.0 123 | mmtf-python==1.1.3 124 | mpmath==1.3.0 125 | mrcfile==1.4.3 126 | msgpack==1.0.7 127 | multidict==6.0.5 128 | multipledispatch==1.0.0 129 | multiprocess==0.70.16 130 | nbclient==0.9.0 131 | nbconvert==7.11.0 132 | nbformat==5.9.2 133 | nest-asyncio==1.5.8 134 | networkx==3.0 135 | nglview==3.0.8 136 | notebook==7.0.6 137 | notebook_shim==0.2.3 138 | numpy==1.23.5 139 | oauthlib==3.2.2 140 | obonet==1.0.0 141 | 
OpenMM==8.1.1 142 | openpyxl==3.1.2 143 | opt-einsum==3.3.0 144 | ordered-set==4.1.0 145 | overrides==7.4.0 146 | packaging==23.2 147 | pandas==1.5.3 148 | pandocfilters==1.5.0 149 | parso==0.8.3 150 | pathlib==1.0.1 151 | pexpect==4.9.0 152 | Pillow==10.1.0 153 | pip==24.1.1 154 | platformdirs==4.0.0 155 | plotly==5.18.0 156 | prometheus-client==0.19.0 157 | prompt-toolkit==3.0.41 158 | protobuf==4.23.4 159 | psutil==5.9.6 160 | ptyprocess==0.7.0 161 | pure-eval==0.2.2 162 | pyarrow==16.1.0 163 | pyarrow-hotfix==0.6 164 | pyasn1==0.5.1 165 | pyasn1-modules==0.3.0 166 | pycparser==2.21 167 | pydantic==1.10.13 168 | Pygments==2.17.2 169 | pykan==0.0.5 170 | pyparsing==3.1.1 171 | pyrosetta-installer==0.1.1 172 | python-dateutil==2.8.2 173 | python-json-logger==2.0.7 174 | pytz==2023.3.post1 175 | pywin32==306 176 | pywinpty==2.0.12 177 | PyYAML==6.0.1 178 | pyzmq==25.1.1 179 | qtconsole==5.5.1 180 | QtPy==2.4.1 181 | referencing==0.31.1 182 | regex==2023.10.3 183 | requests==2.32.3 184 | requests-cache==1.1.1 185 | requests-oauthlib==1.3.1 186 | rfc3339-validator==0.1.4 187 | rfc3986-validator==0.1.1 188 | rich==13.7.0 189 | rich-click==1.7.2 190 | rotary-embedding-torch==0.5.3 191 | rpds-py==0.13.2 192 | rsa==4.9 193 | s3transfer==0.10.0 194 | safetensors==0.4.1 195 | scikit-learn==1.3.2 196 | scipy==1.11.4 197 | seaborn==0.13.0 198 | Send2Trash==1.8.2 199 | sentencepiece==0.1.99 200 | sentry-sdk==1.38.0 201 | setproctitle==1.3.3 202 | setuptools==69.2.0 203 | six==1.16.0 204 | smmap==5.0.1 205 | sniffio==1.3.0 206 | soupsieve==2.5 207 | stack-data==0.6.3 208 | suds-community==1.1.2 209 | sympy==1.12 210 | tape-proteins==0.5 211 | tenacity==8.2.3 212 | tensorboard==2.15.1 213 | tensorboard-data-server==0.7.2 214 | tensorboardX==2.6.2.2 215 | tensorflow==2.15.0 216 | tensorflow-estimator==2.15.0 217 | tensorflow-intel==2.15.0 218 | tensorflow-io-gcs-filesystem==0.31.0 219 | termcolor==2.4.0 220 | terminado==0.18.0 221 | threadpoolctl==3.2.0 222 | tinycss2==1.2.1 223 | tokenizers==0.15.0 224 | tomli==2.0.1 225 | torch==2.1.1+cu118 226 | torch-cluster==1.6.3+pt21cu118 227 | torch_geometric==2.5.3 228 | torch-scatter==2.1.2+pt21cu118 229 | torch-sparse==0.6.18+pt21cu118 230 | torch-spline-conv==1.2.2+pt21cu118 231 | torch-summary==1.4.5 232 | torchcontrib==0.0.2 233 | torchmetrics==1.2.1 234 | torchviz==0.0.2 235 | tornado==6.4 236 | tqdm==4.66.4 237 | traitlets==5.14.0 238 | transformers==4.35.2 239 | typeguard==2.13.3 240 | types-python-dateutil==2.8.19.14 241 | typing_extensions==4.5.0 242 | tzdata==2023.3 243 | uri-template==1.3.0 244 | url-normalize==1.4.3 245 | urllib3==1.26.18 246 | wandb==0.16.0 247 | wcwidth==0.2.12 248 | webcolors==1.13 249 | webencodings==0.5.1 250 | websocket-client==1.7.0 251 | Werkzeug==3.0.1 252 | wget==3.2 253 | wheel==0.43.0 254 | widgetsnbextension==3.5.2 255 | win32-setctime==1.1.0 256 | wrapt==1.14.1 257 | xarray==2023.11.0 258 | xgboost==2.0.3 259 | XlsxWriter==3.2.0 260 | xmltodict==0.13.0 261 | xxhash==3.4.1 262 | yarl==1.9.4 263 | zipp==3.17.0 264 | zope.event==5.0 265 | zope.interface==6.1 266 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | from pathlib import Path 3 | import numpy as np 4 | from sklearn.model_selection import KFold 5 | import torch 6 | from torch import nn 7 | from torch_geometric.loader import DataLoader 8 | from Dataset import AffinityDataset 9 | import torch.optim as optim 10 
| from torch_geometric.data import Batch 11 | import logging 12 | import sys 13 | from Utils import evaluate_metric, evaluate 14 | from module import MuLAAIP 15 | 16 | def train_epoch(model, device, dataloader, loss_fn, optimizer, param_l2_coef, adj_loss_coef, embedding): 17 | model.train() 18 | losses = 0.0 19 | y_true = [] 20 | y_pred = [] 21 | for batch in dataloader: 22 | batch_ab = batch 23 | batch_ag = Batch.from_data_list(batch.antigen) 24 | 25 | embedding_ab = embedding[batch_ab.label.numpy().tolist(), :1280] 26 | embedding_ag = embedding[batch_ag.label.numpy().tolist(), 1280:] 27 | 28 | batch_ab = batch_ab.to(device) 29 | batch_ag = batch_ag.to(device) 30 | embedding_ab = torch.from_numpy(embedding_ab).to(device) 31 | embedding_ag = torch.from_numpy(embedding_ag).to(device) 32 | 33 | optimizer.zero_grad() 34 | 35 | output, ab_adj_mat, ag_adj_mat = model(batch_ab=batch_ab, batch_ag=batch_ag, seq_emb_ab=embedding_ab, 36 | seq_emb_ag=embedding_ag) 37 | 38 | loss = loss_fn(output.squeeze().to(torch.float32), batch.y.to(torch.float32)) 39 | losses += loss.item() 40 | 41 | param_l2_loss = 0 42 | for name, param in model.named_parameters(): 43 | if 'bias' not in name: 44 | param_l2_loss += torch.norm(param, p=2) 45 | param_l2_loss = param_l2_coef * param_l2_loss 46 | adj_loss = adj_loss_coef * torch.norm(ab_adj_mat) + adj_loss_coef * torch.norm(ag_adj_mat) 47 | loss = loss + adj_loss + param_l2_loss 48 | 49 | y_true.append(batch.y.detach().cpu().numpy()) 50 | y_pred.append(output.detach().cpu().numpy()) 51 | 52 | loss.backward() 53 | optimizer.step() 54 | 55 | return losses, np.concatenate(y_true, axis=0), np.concatenate(y_pred, axis=0).reshape(-1) 56 | 57 | 58 | def valid_epoch(model, device, dataloader, loss_fn, embedding): 59 | model.eval() 60 | losses = 0.0 61 | y_true = [] 62 | y_pred = [] 63 | 64 | with torch.no_grad(): 65 | for batch in dataloader: 66 | batch_ab = batch 67 | batch_ag = Batch.from_data_list(batch.antigen) 68 | 69 | embedding_ab = embedding[batch_ab.label.numpy().tolist(), :1280] 70 | embedding_ag = embedding[batch_ag.label.numpy().tolist(), 1280:] 71 | 72 | batch_ab = batch_ab.to(device) 73 | batch_ag = batch_ag.to(device) 74 | embedding_ab = torch.from_numpy(embedding_ab).to(device) 75 | embedding_ag = torch.from_numpy(embedding_ag).to(device) 76 | 77 | output, ab_adj_mat, ag_adj_mat = model(batch_ab=batch_ab, batch_ag=batch_ag, seq_emb_ab=embedding_ab, 78 | seq_emb_ag=embedding_ag) 79 | 80 | loss = loss_fn(output.squeeze().to(torch.float32), batch.y.to(torch.float32)) 81 | losses += loss.item() 82 | y_true.append(batch.y.detach().cpu().numpy()) 83 | y_pred.append(output.detach().cpu().numpy()) 84 | 85 | return losses, np.concatenate(y_true, axis=0), np.concatenate(y_pred, axis=0).reshape(-1) 86 | 87 | 88 | def run(lr=5e-5, epochs=200, adj_loss_coef=5e-6, param_l2_coef=5e-6, batch_size=32, 89 | device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')): 90 | logging.basicConfig( 91 | filename='DeepAntibody_wild.log', 92 | level=logging.INFO, 93 | format='%(asctime)s %(levelname)s: %(message)s', 94 | datefmt='%Y-%m-%d %H:%M:%S' 95 | ) 96 | 97 | console_handler = logging.StreamHandler(sys.stdout) 98 | console_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s: %(message)s', '%Y-%m-%d %H:%M:%S')) 99 | logging.getLogger().addHandler(console_handler) 100 | 101 | current_path = Path.cwd() 102 | path = current_path / 'Dataset' / 'wild' 103 | dataset = AffinityDataset(root=str(path)) 104 | embedding = np.load(str(current_path / 'wild_embeddings' 
/ 'ProtTrans_wild.npy')) 105 | 106 | spliter = KFold(n_splits=10, shuffle=True, random_state=42) 107 | dataset_size = len(dataset) 108 | 109 | fold_loss_train = {} 110 | fold_loss_val = {} 111 | 112 | for fold, (train_indices, val_indices) in enumerate(spliter.split(range(dataset_size))): 113 | logging.info("Fold {}, train size {}, test size {}".format(fold + 1, len(train_indices), len(val_indices))) 114 | fold_loss_train[str(fold + 1)] = [] 115 | fold_loss_val[str(fold + 1)] = [] 116 | 117 | model = MuLAAIP() 118 | model.to(device) 119 | optimizer = optim.Adam(model.parameters(), lr=lr) 120 | 121 | loss_fn = nn.MSELoss() 122 | 123 | train_data = dataset[train_indices.tolist()] 124 | val_data = dataset[val_indices.tolist()] 125 | 126 | train_data_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True) 127 | val_data_loader = DataLoader(dataset=val_data, batch_size=batch_size, shuffle=True) 128 | 129 | for epoch in range(epochs): 130 | train_loss, y_true, y_pred = train_epoch(model=model, 131 | device=device, 132 | dataloader=train_data_loader, 133 | loss_fn=loss_fn, 134 | optimizer=optimizer, 135 | param_l2_coef=param_l2_coef, 136 | embedding=embedding, 137 | adj_loss_coef=adj_loss_coef) 138 | avg_train_loss = train_loss / len(train_data_loader) 139 | train_mae, train_corr, train_rmse = evaluate_metric(y_true, y_pred) 140 | 141 | val_loss, y_true, y_pred = valid_epoch(model=model, 142 | device=device, 143 | dataloader=val_data_loader, 144 | loss_fn=loss_fn, 145 | embedding=embedding) 146 | avg_val_loss = val_loss / len(val_data_loader) 147 | val_mae, val_corr, val_rmse = evaluate_metric(y_true, y_pred) 148 | 149 | 150 | logging.info( 151 | f"Epoch {epoch + 1}: Train MSE: {avg_train_loss:.4f}, " 152 | f"MAE: {train_mae:.4f}, " 153 | f"PCC: {train_corr:.4f}, " 154 | f"RMSE: {train_rmse:.4f}; " 155 | f"Val MSE: {avg_val_loss:.4f}, " 156 | f"MAE: {val_mae:.4f}, " 157 | f"PCC: {val_corr:.4f}, " 158 | f"RMSE: {val_rmse:.4f}" 159 | ) 160 | 161 | 162 | if __name__ == "__main__": 163 | run() 164 | --------------------------------------------------------------------------------
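For readers tracing the sequence branch of MuLAAIP.forward above: the adjacency handed to each GCNConvL is a dense cosine-similarity matrix over the per-complex sequence embeddings in a batch (seq_emb_ab, seq_emb_ag), and GCNConvL then propagates features through a symmetrically normalized version of that matrix, as built by get_laplace_mat with type='sym' and degree_version='v2'. The standalone sketch below reproduces only that normalization chain; the tensor sizes, the random inputs, and the toy weight matrix are illustrative assumptions and are not values taken from the repository.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_complexes, emb_dim, out_dim = 4, 8, 8       # toy sizes, not the repo's 1280-dim embeddings
x = torch.randn(num_complexes, emb_dim)         # stand-in for one batch of sequence embeddings

# Cosine-similarity adjacency, as in MuLAAIP.forward:
#   A[p, q] = <x_p, x_q> / (||x_p|| * ||x_q||)
w = x.norm(p=2, dim=-1, keepdim=True)           # (N, 1) row norms
adj = (x @ x.t()) / (w @ w.t())

# Symmetric normalization, mirroring get_laplace_mat(type='sym', degree_version='v2'):
# degrees are row sums of ReLU(A), raised to the power -1/2 and placed on a diagonal.
# The diagonal of a cosine-similarity matrix is 1, so every degree is strictly positive.
deg = F.relu(adj).sum(dim=1).pow(-0.5)
d_inv_sqrt = torch.diag(deg)
adj_norm = d_inv_sqrt @ adj @ d_inv_sqrt

# One propagation step in the style of GCNConvL.forward: A_hat @ X @ W (bias omitted).
weight = torch.randn(emb_dim, out_dim) * 0.1    # illustrative weight, not a trained parameter
out = adj_norm @ x @ weight

print(out.shape)                                # torch.Size([4, 8])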