├── DRfold_infer.py ├── PotentialFold ├── Clust.py ├── Cubic.py ├── Optimization.py ├── Potential.py ├── Selection.py ├── __init__.py ├── a2b.py ├── geo.py ├── lbfgs_rosetta.py ├── lib │ ├── base.npy │ ├── ddf.json │ ├── other2.npy │ ├── others.npy │ ├── side.npy │ └── vdw.json ├── operations.py └── rigid.py ├── README.md ├── cfg_95 ├── EvoMSA.py ├── EvoMSA2XYZ.py ├── EvoPair.py ├── Evoformer.py ├── IPA.py ├── RNALM2 │ ├── EvoMSA.py │ ├── EvoPair.py │ ├── Evoformer.py │ ├── Model.py │ └── basic.py ├── Structure.py ├── __init__.py ├── base.npy ├── basic.py ├── data.py ├── newconfig ├── test_modeldir.py └── util.py ├── cfg_96 ├── EvoMSA.py ├── EvoMSA2XYZ.py ├── EvoPair.py ├── Evoformer.py ├── IPA.py ├── RNALM2 │ ├── EvoMSA.py │ ├── EvoPair.py │ ├── Evoformer.py │ ├── Model.py │ └── basic.py ├── Structure.py ├── __init__.py ├── base.npy ├── basic.py ├── data.py ├── newconfig ├── test_modeldir.py └── util.py ├── cfg_97 ├── EvoMSA.py ├── EvoMSA2XYZ.py ├── EvoPair.py ├── Evoformer.py ├── IPA.py ├── RNALM2 │ ├── EvoMSA.py │ ├── EvoPair.py │ ├── Evoformer.py │ ├── Model.py │ └── basic.py ├── Structure.py ├── base.npy ├── basic.py ├── data.py ├── newconfig ├── test_modeldir.py └── util.py ├── cfg_99 ├── EvoMSA.py ├── EvoMSA2XYZ.py ├── EvoPair.py ├── Evoformer.py ├── IPA.py ├── RNALM2 │ ├── EvoMSA.py │ ├── EvoPair.py │ ├── Evoformer.py │ ├── Model.py │ └── basic.py ├── Structure.py ├── __init__.py ├── base.npy ├── basic.py ├── data.py ├── newconfig ├── test_modeldir.py └── util.py ├── cfg_for_folding.json ├── cfg_for_selection.json ├── install.sh ├── script └── refine.py └── test └── seq.fasta /DRfold_infer.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | from subprocess import Popen, PIPE, STDOUT 3 | import numpy as np 4 | import torch 5 | exp_dir = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | device = "cuda" if torch.cuda.is_available() else "cpu" 8 | 9 | 10 | 11 | dlexps = ['cfg_95','cfg_96','cfg_97','cfg_99'] 12 | 13 | 14 | fastafile = os.path.realpath(sys.argv[1]) 15 | outdir = os.path.realpath(sys.argv[2]) 16 | 17 | 18 | 19 | pclu = False 20 | if len(sys.argv) == 4 and sys.argv[3] == '1': 21 | print('will do cluster') 22 | pclu = True 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | if not os.path.isdir(outdir): 32 | os.makedirs(outdir) 33 | ret_dir = os.path.join(outdir,'rets_dir') 34 | if not os.path.isdir(ret_dir): 35 | os.makedirs(ret_dir) 36 | 37 | 38 | folddir = os.path.join(outdir,'folds') 39 | if not os.path.isdir(folddir): 40 | os.makedirs(folddir) 41 | 42 | refdir = os.path.join(outdir,'relax') 43 | if not os.path.isdir(refdir): 44 | os.makedirs(refdir) 45 | 46 | 47 | dlmains = [os.path.join(exp_dir,one_exp,'test_modeldir.py') for one_exp in dlexps] 48 | dirs = [os.path.join(exp_dir,'model_hub',one_exp) for one_exp in dlexps] 49 | if not os.path.isfile(ret_dir+'/done'): 50 | print(ret_dir+'/done', 'is not here. Will generate e2e and geo files.') 51 | for dlmain,one_exp,mdir in zip(dlmains,dlexps,dirs): 52 | cmd = f'python {dlmain} {device} {fastafile} {ret_dir}/{one_exp}_ {mdir}' 53 | print(cmd) 54 | # expdir=os.path.dirname(os.path.abspath(__file__)) 55 | p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT) 56 | output,error = p.communicate() 57 | #print(output,error) 58 | wfile = open(ret_dir+'/done','w') 59 | wfile.write('1') 60 | wfile.close() 61 | else: 62 | print(ret_dir+'/done', 'is here, using existing e2e and geo files.') 63 | 64 | 65 | def get_model_pdb(tdir,opt): 66 | files = os.listdir(tdir) 67 | files = [afile for afile in files if afile.startswith(opt)][0] 68 | return files 69 | 70 | cso_dir = folddir 71 | clufile = os.path.join(folddir,'clu.txt') 72 | config_sel = os.path.join(exp_dir,'cfg_for_selection.json') 73 | foldconfig = os.path.join(exp_dir,'cfg_for_folding.json') 74 | selpython = os.path.join(exp_dir,'PotentialFold','Selection.py') 75 | optpython = os.path.join(exp_dir,'PotentialFold','Optimization.py') 76 | clupy = os.path.join(exp_dir,'PotentialFold','Clust.py') 77 | arena = os.path.join(exp_dir,'Arena','Arena') 78 | 79 | optsaveprefix=os.path.join(cso_dir,f'opt_0') 80 | save_prefix = os.path.join(cso_dir,f'sel_0') 81 | rets = os.listdir(ret_dir) 82 | rets = [afile for afile in rets if afile.endswith('.ret')] 83 | rets = [os.path.join(ret_dir,aret) for aret in rets ] 84 | ret_str = ' '.join(rets) 85 | cmd = f'python {selpython} {fastafile} {config_sel} {save_prefix} {ret_str}' 86 | p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT) 87 | output,error = p.communicate() 88 | #print(output,error) 89 | cmd = f'python {optpython} {fastafile} {optsaveprefix} {ret_dir} {save_prefix} {foldconfig}' 90 | p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT) 91 | output,error = p.communicate() 92 | #print(output,error) 93 | cgpdb = os.path.join(folddir,get_model_pdb(folddir,'opt_0')) 94 | savepdb = os.path.join(refdir,'model_1.pdb') 95 | cmd = f'{arena} {cgpdb} {savepdb} 7' 96 | p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT) 97 | output,error = p.communicate() 98 | #print(output,error) 99 | 100 | if pclu: 101 | cmd = f'python {clupy} {ret_dir} {clufile}' 102 | p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT) 103 | output,error = p.communicate() 104 | 105 | 106 | lines = open(clufile).readlines() 107 | lines = [aline.strip() for aline in lines] 108 | lines = [aline for aline in lines if aline] 109 | 110 | for i in range(1,len(lines)): 111 | rets = lines[i].split() 112 | rets = [os.path.join(ret_dir,aret.replace('.pdb','.ret')) for aret in rets ] 113 | ret_str = ' '.join(rets) 114 | optsaveprefix = os.path.join(cso_dir,f'opt_{str(i+1)}') 115 | save_prefix = os.path.join(cso_dir,f'sel_{str(i+1)}') 116 | cmd = f'python {selpython} {fastafile} {config_sel} {save_prefix} {ret_str}' 117 | print(cmd) 118 | p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT) 119 | output,error = p.communicate() 120 | #print(output,error) 121 | cmd = f'python {optpython} {fastafile} {optsaveprefix} {ret_dir} {save_prefix} {foldconfig}' 122 | p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT) 123 | output,error = p.communicate() 124 | #print(output,error) 125 | cgpdb = os.path.join(folddir,get_model_pdb(folddir,f'opt_{str(i+1)}')) 126 | savepdb = os.path.join(refdir,f'model_{str(i+1)}.pdb') 127 | cmd = f'{arena} {cgpdb} {savepdb} 7' 128 | p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT) 129 | output,error = p.communicate() 130 | -------------------------------------------------------------------------------- /PotentialFold/Cubic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.interpolate import CubicSpline,UnivariateSpline 3 | import os 4 | from torch.autograd import Function 5 | import torch 6 | import math 7 | 8 | 9 | 10 | def fit_dis_cubic(dis_matrix,min_dis,max_dis,num_bin): 11 | dis_region=np.zeros(num_bin) 12 | for i in range(num_bin): 13 | dis_region[i]=min_dis+(i+0.5)*(max_dis-min_dis)*1.0/num_bin 14 | L=dis_matrix.shape[0] 15 | csnp=[] 16 | decsnp=[] 17 | for i in range(L): 18 | css=[] 19 | decss=[] 20 | for j in range(L): 21 | y=-np.log( (dis_matrix[i,j,1:-1]+1e-8) / (dis_matrix[i,j,[-2]]+1e-8) ) 22 | x=dis_region 23 | x[0]=-0.0001 24 | y[0]= max(10,y[1]+4) 25 | cs= CubicSpline(x,y) 26 | decs=cs.derivative() 27 | css.append(cs) 28 | decss.append(decs) 29 | csnp.append(css) 30 | decsnp.append(decss) 31 | return np.array(csnp),np.array(decsnp) 32 | 33 | def dis_cubic(out,min_dis,max_dis,num_bin): 34 | print('fitting cubic distance') 35 | cs,decs=fit_dis_cubic(out,min_dis,max_dis,num_bin) 36 | return cs,decs 37 | 38 | 39 | 40 | def cubic_matrix_torsion(dis_matrix,min_dis,max_dis,num_bin): 41 | dis_region=np.zeros(num_bin) 42 | bin_size=(max_dis-min_dis)/num_bin 43 | for i in range(num_bin): 44 | dis_region[i]=min_dis+(i+0.5)*(max_dis-min_dis)*1.0/num_bin 45 | L=dis_matrix.shape[0] 46 | csnp=[] 47 | decsnp=[] 48 | for i in range(L): 49 | css=[] 50 | decss=[] 51 | for j in range(L): 52 | y=-np.log( dis_matrix[i,j,:-1]+1e-8 ) 53 | x=dis_region 54 | x=np.append(x,x[-1]+bin_size) 55 | y=np.append(y,y[0]) 56 | cs= CubicSpline(x,y,bc_type='periodic') 57 | decs=cs.derivative() 58 | css.append(cs) 59 | decss.append(decs) 60 | csnp.append(css) 61 | decsnp.append(decss) 62 | return np.array(csnp),np.array(decsnp) 63 | def torsion_cubic(out,min_dis,max_dis,num_bin): 64 | print('fitting cubic') 65 | cs,decs=cubic_matrix_torsion(out,min_dis,max_dis,num_bin) 66 | return cs,decs 67 | 68 | def cubic_matrix_angle(dis_matrix,min_dis,max_dis,num_bin): # 0 - np.pi 12 69 | dis_region=np.zeros(num_bin) 70 | bin_size=(max_dis-min_dis)/num_bin 71 | for i in range(num_bin): 72 | dis_region[i]=min_dis+(i+0.5)*(max_dis-min_dis)*1.0/num_bin 73 | L=dis_matrix.shape[0] 74 | csnp=[] 75 | decsnp=[] 76 | for i in range(L): 77 | css=[] 78 | decss=[] 79 | for j in range(L): 80 | y=-np.log( dis_matrix[i,j,:-1]+1e-8 ) 81 | x=dis_region 82 | 83 | x=np.concatenate([[x[0]-bin_size*3,x[0]-bin_size*2,x[0]-bin_size], x,[x[-1]+bin_size,x[-1]+bin_size*2,x[-1]+bin_size*3] ]) 84 | y=np.concatenate([ [y[2],y[1],y[0]],y,[y[-1],y[-2],y[-3]] ]) 85 | 86 | cs= CubicSpline(x,y) 87 | decs=cs.derivative() 88 | 89 | css.append(cs) 90 | decss.append(decs) 91 | csnp.append(css) 92 | decsnp.append(decss) 93 | 94 | return np.array(csnp),np.array(decsnp) 95 | def angle_cubic(out,min_dis,max_dis,num_bin): 96 | 97 | print('fitting angle cubic') 98 | cs,decs=cubic_matrix_angle(out,min_dis,max_dis,num_bin) 99 | 100 | return cs,decs -------------------------------------------------------------------------------- /PotentialFold/Potential.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.interpolate import CubicSpline,UnivariateSpline 3 | import os 4 | from torch.autograd import Function 5 | import torch 6 | import math,json 7 | 8 | 9 | def batched_index_select(input, dim, index): 10 | # https://discuss.pytorch.org/t/batched-index-select/9115/8 11 | views = [input.shape[0]] + \ 12 | [1 if i != dim else -1 for i in range(1, len(input.shape))] 13 | expanse = list(input.shape) 14 | expanse[0] = -1 15 | expanse[dim] = -1 16 | index = index.view(views).expand(expanse) 17 | return torch.gather(input, dim, index).squeeze() 18 | 19 | class cubic_batch_dis_class(Function): 20 | @staticmethod 21 | def forward(ctx,input1,coe,x,min_dis,max_dis,bin_num): 22 | # inoput: B coe: B 3 i 23 | # min_ref=config['min_dis']+((config['max_dis']-config['min_dis'])/config['bin_num'])*1.5 24 | # bin_size = (config['max_dis']-config['min_dis'])/config['bin_num'] 25 | min_ref = min_dis+((max_dis-min_dis)/bin_num)*1.5 26 | bin_size = (max_dis-min_dis)/bin_num 27 | inputi= input1.detach() 28 | selction1=inputi <=min_ref 29 | selction2=inputi > min_ref 30 | ctx.selction1=selction1 31 | ctx.selction2=selction2 32 | ctx.inputi=inputi 33 | ctx.coe=coe 34 | out=inputi*1.0 35 | #out[selction1] = coe[selction1,0,0]*(inputi[selction1]+1e-4)**3 + coe[selction1,1,0]*(inputi[selction1]+1e-4)**2 + coe[selction1,2,0]*(inputi[selction1]+1e-4) + coe[selction1,3,0] 36 | #indexes =(( (inputi[selction2]-min_ref) // 0.5) +1).long() 37 | indexes = (torch.div((inputi[selction2]-min_ref), bin_size, rounding_mode='floor') +1 ).long() 38 | selectedaoe = batched_index_select(coe[selction2],2,indexes) # B 3 39 | selectedx= batched_index_select(x[selction2],1,indexes) # B 40 | input2=inputi[selction2] 41 | out[selction2] = selectedaoe[:,0]*(input2-selectedx)**3 + selectedaoe[:,1]*(input2-selectedx)**2 + selectedaoe[:,2]*(input2-selectedx) + selectedaoe[:,3] 42 | ctx.indexes=indexes 43 | ctx.selectedx=selectedx 44 | ctx.selectedaoe =selectedaoe 45 | return out 46 | @staticmethod 47 | def backward(ctx,grad_output): 48 | inputi=ctx.inputi 49 | grad = inputi+0.0 50 | grad[ctx.selction1] = 3*ctx.coe[ctx.selction1,0,0]*(inputi[ctx.selction1]+1e-4)**2 + 2*ctx.coe[ctx.selction1,1,0]*(ctx.inputi[ctx.selction1]+1e-4) + ctx.coe[ctx.selction1,2,0] 51 | grad[ctx.selction2] = 3*ctx.selectedaoe[:,0]*(inputi[ctx.selction2]- ctx.selectedx)**2 + 2*ctx.selectedaoe[:,1]*(ctx.inputi[ctx.selction2]-ctx.selectedx) + ctx.selectedaoe[:,2] 52 | return grad_output*grad,None,None,None,None,None,None 53 | def cubic_distance(input1,coe,x,min_dis,max_dis,bin_num): 54 | return cubic_batch_dis_class.apply(input1,coe,x,min_dis,max_dis,bin_num) 55 | 56 | 57 | class cubic_batch_torsion_class(Function): 58 | @staticmethod 59 | def forward(ctx,input1,coe,x,num_bin): 60 | # inoput: B coe: B 3 i 61 | x0=x[0][0] 62 | inputi= input1.detach() 63 | inputi[inputi < x0] +=math.pi*2 64 | ctx.inputi=inputi 65 | ctx.coe=coe 66 | out=inputi*1.0 67 | #indexes =(( (inputi-x0) // (2*math.pi/num_bin))).long() 68 | indexes =torch.div((inputi-x0), (2*math.pi/num_bin), rounding_mode='floor').long() 69 | selectedaoe = batched_index_select(coe,2,indexes) # B 3 70 | selectedx= batched_index_select(x,1,indexes) # B 71 | 72 | out = selectedaoe[:,0]*(inputi-selectedx)**3 + selectedaoe[:,1]*(inputi-selectedx)**2 + selectedaoe[:,2]*(inputi-selectedx) + selectedaoe[:,3] 73 | ctx.indexes=indexes 74 | ctx.selectedx=selectedx 75 | ctx.selectedaoe =selectedaoe 76 | #print(out.shape) 77 | #print(out) 78 | return out 79 | @staticmethod 80 | def backward(ctx,grad_output): 81 | inputi=ctx.inputi 82 | grad = inputi+0.0 83 | 84 | grad = 3*ctx.selectedaoe[:,0]*(inputi- ctx.selectedx)**2 + 2*ctx.selectedaoe[:,1]*(ctx.inputi-ctx.selectedx) + ctx.selectedaoe[:,2] 85 | 86 | #print('grad',grad.shape) 87 | #print(grad.sum()) 88 | return grad_output*grad,None,None,None 89 | def cubic_torsion(input1,coe,x,num_bin): 90 | return cubic_batch_torsion_class.apply(input1,coe,x,num_bin) 91 | 92 | 93 | class cubic_batch_angle_class(Function): 94 | @staticmethod 95 | def forward(ctx,input1,coe,x,num_bin=12): 96 | # inoput: B coe: B 3 i 97 | x0=x[0][0] 98 | inputi= input1.detach() 99 | #print(x0) 100 | ctx.inputi=inputi 101 | ctx.coe=coe 102 | out=inputi*1.0 103 | indexes =(( (inputi-x0) // (math.pi/num_bin))).long() 104 | selectedaoe = batched_index_select(coe,2,indexes) # B 3 105 | selectedx= batched_index_select(x,1,indexes) # B 106 | 107 | out = selectedaoe[:,0]*(inputi-selectedx)**3 + selectedaoe[:,1]*(inputi-selectedx)**2 + selectedaoe[:,2]*(inputi-selectedx) + selectedaoe[:,3] 108 | ctx.indexes=indexes 109 | ctx.selectedx=selectedx 110 | ctx.selectedaoe =selectedaoe 111 | #print(out.shape) 112 | #print(out) 113 | return out 114 | @staticmethod 115 | def backward(ctx,grad_output): 116 | inputi=ctx.inputi 117 | grad = inputi+0.0 118 | 119 | grad = 3*ctx.selectedaoe[:,0]*(inputi- ctx.selectedx)**2 + 2*ctx.selectedaoe[:,1]*(ctx.inputi-ctx.selectedx) + ctx.selectedaoe[:,2] 120 | 121 | #print('grad',grad.shape) 122 | #print(grad.sum()) 123 | return grad_output*grad,None,None,None 124 | def cubic_angle(input1,coe,x,num_bin): 125 | return cubic_batch_angle_class.apply(input1,coe,x,num_bin) 126 | 127 | 128 | 129 | def LJpotential(dis,th): 130 | #238325838 131 | r = ( (th+0.5) / (dis+0.5))**6 132 | return (r**2 - 2*r) 133 | #return torch.clamp((r**2 - 2*r),max=228325838) 134 | if __name__ == '__main__': 135 | print(LJpotential( torch.Tensor([0.]) ,2.5)) -------------------------------------------------------------------------------- /PotentialFold/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeyang/DRfold2/54129443ed151032e3d647a279814c5cebdca630/PotentialFold/__init__.py -------------------------------------------------------------------------------- /PotentialFold/geo.py: -------------------------------------------------------------------------------- 1 | import math 2 | from numpy import NINF, arccos, arctan2 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | 7 | def sin_cos_angle(p0,p1,p2): 8 | # [b 3] 9 | b0=p0-p1 10 | b1=p2-p1 11 | 12 | b0=b0 / (torch.norm(b0,dim=-1,keepdim=True)+1e-08) 13 | b1=b1 / (torch.norm(b1,dim=-1,keepdim=True)+1e-08) 14 | recos=torch.sum(b0*b1,-1) 15 | recos=torch.clamp(recos,-0.9999,0.9999) 16 | resin = torch.sqrt(1-recos**2) 17 | return resin,recos 18 | 19 | 20 | def sin_cos_dihedral(p0,p1,p2,p3): 21 | 22 | #p0 = p[:,0:1,:] 23 | #p1 = p[:,1:2,:] 24 | #p2 = p[:,2:3,:] 25 | #p3 = p[:,3:4,:] 26 | b0 = -1.0*(p1 - p0) 27 | b1 = p2 - p1 28 | b2 = p3 - p2 29 | 30 | b1=b1/(torch.norm(b1,dim=-1,keepdim=True)+1e-8) 31 | 32 | v = b0 - torch.einsum('bj,bj->b', b0, b1)[:,None]*b1 33 | w = b2 - torch.einsum('bj,bj->b', b2, b1)[:,None]*b1 34 | x = torch.einsum('bj,bj->b', v, w) 35 | #print(x) 36 | y = torch.einsum('bj,bj->b', torch.cross(b1, v,-1), w) 37 | #print(y.shape) 38 | torsion_L = torch.norm(torch.cat([x[:,None],y[:,None]],dim=-1),dim=-1) 39 | x = x / (torsion_L+1e-8) 40 | y = y / (torsion_L+1e-8) 41 | return y,x #torch.atan2(y,x) 42 | 43 | def dihedral_2d(p0,p1,p2,p3): 44 | # p : [L,L,3] 45 | b0 = -1.0*(p1 - p0) 46 | b1 = p2 - p1 47 | b2 = p3 - p2 48 | #print(b0.shape) 49 | b1=b1/(torch.norm(b1,dim=-1,keepdim=True)+1e-8) 50 | v = b0 - torch.einsum('abj,abj->ab', b0, b1)[...,None]*b1 51 | w = b2 - torch.einsum('abj,abj->ab', b2, b1)[...,None]*b1 52 | x = torch.einsum('abj,abj->ab', v, w) 53 | y = torch.einsum('abj,abj->ab', torch.cross(b1, v,-1), w) 54 | return torch.atan2(y,x) 55 | def dihedral_1d(p0,p1,p2,p3): 56 | # p : [L,L,3] 57 | b0 = -1.0*(p1 - p0) 58 | b1 = p2 - p1 59 | b2 = p3 - p2 60 | print(b0.shape) 61 | b1=b1/(torch.norm(b1,dim=-1,keepdim=True)+1e-8) 62 | v = b0 - torch.einsum('bj,bj->b', b0, b1)[...,None]*b1 63 | w = b2 - torch.einsum('bj,bj->b', b2, b1)[...,None]*b1 64 | x = torch.einsum('bj,bj->b', v, w) 65 | y = torch.einsum('bj,bj->b', torch.cross(b1, v,-1), w) 66 | return torch.atan2(y,x) 67 | 68 | def angle_2d(p0,p1,p2): 69 | # [a b 3] 70 | b0=p0-p1 71 | b1=p2-p1 72 | b0=b0 / (torch.norm(b0,dim=-1,keepdim=True)+1e-08) 73 | b1=b1 / (torch.norm(b1,dim=-1,keepdim=True)+1e-08) 74 | recos=torch.sum(b0*b1,-1) 75 | recos=torch.clamp(recos,-0.9999,0.9999) 76 | return torch.arccos(recos) 77 | def angle_1d(p0,p1,p2): 78 | return angle_2d(p0,p1,p2) 79 | 80 | def distance_2d(p0,p1): 81 | return (p0-p1).norm(dim=-1) 82 | 83 | def get_omg_map(x): 84 | # x: L 4 3 N CA C CB 85 | L=x.shape[0] 86 | cai=x[:,None,1].repeat(1,L,1) 87 | cbi=x[:,None,-1].repeat(1,L,1) 88 | cbj=x[None,:,-1].repeat(L,1,1) 89 | caj=x[None,:,1].repeat(L,1,1) 90 | torsion = dihedral_2d(cai,cbi,cbj,caj) 91 | return torsion 92 | 93 | def get_phi_map(x): 94 | L=x.shape[0] 95 | cai=x[:,None,1].repeat(1,L,1) 96 | cbi=x[:,None,-1].repeat(1,L,1) 97 | cbj=x[None,:,-1].repeat(L,1,1) 98 | return angle_2d(cai,cbi,cbj) 99 | 100 | def get_theta_map(x): 101 | L=x.shape[0] 102 | ni =x[:,None,0].repeat(1,L,1) 103 | cai=x[:,None,1].repeat(1,L,1) 104 | cbi=x[:,None,-1].repeat(1,L,1) 105 | cbj=x[None,:,-1].repeat(L,1,1) 106 | return dihedral_2d(ni,cai,cbi,cbj) 107 | 108 | def get_cadis_map(x): 109 | cai=x[:,None,1] 110 | caj=x[None,:,1] 111 | return distance_2d(cai,caj) 112 | 113 | def get_cbdis_map(x): 114 | cai=x[:,None,-1] 115 | caj=x[None,:,-1] 116 | return distance_2d(cai,caj) 117 | 118 | 119 | def get_all_prot(x): 120 | L=x.shape[0] 121 | ni =x[:,None,0].repeat(1,L,1) 122 | cai=x[:,None,1].repeat(1,L,1) 123 | ci= x[:,None,2].repeat(1,L,1) 124 | cbi=x[:,None,-1].repeat(1,L,1) 125 | 126 | nj =x[None,:,0].repeat(L,1,1) 127 | caj=x[None,:,1].repeat(L,1,1) 128 | cj =x[None,:,2].repeat(L,1,1) 129 | cbj=x[None,:,-1].repeat(L,1,1) 130 | 131 | cbmap=distance_2d(cbi,cbj) 132 | camap=distance_2d(cai,caj) 133 | 134 | 135 | omgmap=dihedral_2d(cai,cbi,cbj,caj) 136 | psimap=angle_2d(cai,cbi,cbj) 137 | thetamap=dihedral_2d(ni,cai,cbi,cbj) 138 | 139 | def get_all(x): 140 | L = x.shape[0] 141 | pi= x[:,None,0].repeat(1,L,1) 142 | ci= x[:,None,1].repeat(1,L,1) 143 | ni= x[:,None,2].repeat(1,L,1) 144 | 145 | pj= x[None,:,0].repeat(L,1,1) 146 | cj= x[None,:,1].repeat(L,1,1) 147 | nj= x[None,:,2].repeat(L,1,1) 148 | 149 | pp = distance_2d(pi,pj) 150 | cc = distance_2d(ci,cj) 151 | nn = distance_2d(ni,nj) 152 | 153 | pnn = angle_2d(pi,ni,nj) 154 | pcc = angle_2d(pi,ci,cj) 155 | cnn = angle_2d(ci,ni,nj) 156 | 157 | pccp = dihedral_2d(pi,ci,cj,pj) 158 | pnnp = dihedral_2d(pi,ni,nj,pj) 159 | cnnc = dihedral_2d(ci,ni,nj,cj) 160 | 161 | 162 | return pp,cc,nn,pnn,pcc,cnn,pccp,pnnp,cnnc 163 | 164 | -------------------------------------------------------------------------------- /PotentialFold/lib/base.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeyang/DRfold2/54129443ed151032e3d647a279814c5cebdca630/PotentialFold/lib/base.npy -------------------------------------------------------------------------------- /PotentialFold/lib/ddf.json: -------------------------------------------------------------------------------- 1 | { 2 | "weight_pp" : 1, 3 | "weight_cc" : 1, 4 | "weight_nn" : 1, 5 | "weight_pccp": 0, 6 | "weight_cnnc": 0, 7 | "weight_pnnp": 0, 8 | "weight_pcc" :0, 9 | "weight_cnn": 0, 10 | "weight_pnn": 0, 11 | "weight_vdw": 1, 12 | "weight_nn_contact":0, 13 | "weight_cc_contact":0, 14 | "weight_beta": 0, 15 | "weight_fape": 1, 16 | "weight_bond": 1000, 17 | 18 | 19 | "pair_weight_power": 0, 20 | "pair_weight_min": 0.3, 21 | "pair_error_power": 1, 22 | "pair_rest_min_dist": 3, 23 | "FAPE_max": 30, 24 | "geo_scale":50 25 | 26 | 27 | } -------------------------------------------------------------------------------- /PotentialFold/lib/other2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeyang/DRfold2/54129443ed151032e3d647a279814c5cebdca630/PotentialFold/lib/other2.npy -------------------------------------------------------------------------------- /PotentialFold/lib/others.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeyang/DRfold2/54129443ed151032e3d647a279814c5cebdca630/PotentialFold/lib/others.npy -------------------------------------------------------------------------------- /PotentialFold/lib/side.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeyang/DRfold2/54129443ed151032e3d647a279814c5cebdca630/PotentialFold/lib/side.npy -------------------------------------------------------------------------------- /PotentialFold/lib/vdw.json: -------------------------------------------------------------------------------- 1 | { 2 | "weight_pp" : 0, 3 | "weight_cc" : 0, 4 | "weight_nn" : 0, 5 | "weight_pccp": 0, 6 | "weight_cnnc": 0, 7 | "weight_pnnp": 0, 8 | "weight_pcc" :0, 9 | "weight_cnn": 0, 10 | "weight_pnn": 0, 11 | "weight_vdw": 0.5, 12 | "weight_nn_contact":0, 13 | "weight_cc_contact":0, 14 | "weight_beta": 0, 15 | "weight_fape": 0, 16 | "weight_bond": 100 17 | } -------------------------------------------------------------------------------- /PotentialFold/operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | import numpy as np 5 | import math,sys,math 6 | from subprocess import Popen, PIPE, STDOUT 7 | from io import BytesIO 8 | import os 9 | from torch.autograd import Function 10 | 11 | def coor_selection(coor,mask): 12 | #[L,n,3],[L,n],byte 13 | return torch.masked_select(coor,mask.bool()).view(-1,3) 14 | 15 | 16 | def pair_distance(x1,x2,eps=1e-6,p=2): 17 | n1=x1.size()[0] 18 | n2=x2.size()[0] 19 | x1_=x1.view(n1,1,-1) 20 | x2_=x2.view(1,n2,-1) 21 | x1_=x1_.expand(n1,n2,-1) 22 | x2_=x2_.expand(n1,n2,-1) 23 | diff = torch.abs(x1_ - x2_) 24 | out = torch.pow(diff + eps, p).sum(dim=2) 25 | return torch.pow(out, 1. / p) 26 | 27 | class torsion(Function): 28 | #PyTorch class to calculate differentiable torsion angle 29 | #https://stackoverflow.com/questions/20305272/dihedral-torsion-angle-from-four-points-in-cartesian-coordinates-in-python 30 | #https://salilab.org/modeller/manual/node492.html 31 | @staticmethod 32 | def forward_(ctx,input1,input2,input3,input4): 33 | #Lx3 34 | # 0 1 2 3 35 | inputi,inputj,inputk,inputl=input1.detach().numpy(),input2.detach().numpy(),input3.detach().numpy(),input4.detach().numpy() 36 | rij,rkj,rkl=inputi-inputj,inputk-inputj,inputk-inputl 37 | corss_ijkj=np.cross(rij,rkj) 38 | cross_kjkl=np.cross(rkj,rkl) 39 | angle=np.sum(corss_ijkj*cross_kjkl,axis=-1) 40 | angle=angle/(np.linalg.norm(corss_ijkj,axis=-1)*np.linalg.norm(cross_kjkl,axis=-1)) 41 | signlamda=np.sign(np.sum(rkj*np.cross(corss_ijkj,cross_kjkl),-1)) 42 | angle[angle<-1]=-1 43 | angle[angle>1]=1 44 | ctx.mj=corss_ijkj 45 | ctx.nk=cross_kjkl 46 | ctx.rkj=rkj 47 | ctx.rij=rij 48 | ctx.rkl=rkl 49 | #ctx.save_for_backward(input1,input2,input3,input4) 50 | return torch.as_tensor(np.arccos(angle)*signlamda,dtype=input1.dtype) 51 | @staticmethod 52 | def forward(ctx,input1,input2,input3,input4): 53 | #Lx3 54 | p0,p1,p2,p3=input1.detach().numpy(),input2.detach().numpy(),input3.detach().numpy(),input4.detach().numpy() 55 | b0_ = -(p1-p0) 56 | b1_ = p2-p1 57 | b2_ = p3-p2 58 | ctx.rkj=b1_+0.0 59 | ctx.rij=b0_+0.0 60 | ctx.rkl=-b2_ +0.0 61 | ctx.mj=np.cross(ctx.rij,ctx.rkj) 62 | ctx.nk=np.cross(ctx.rkj,ctx.rkl) 63 | b1 =b1_ / np.linalg.norm(b1_,axis=-1,keepdims=True) 64 | v = b0_ - (b0_*b1).sum(-1,keepdims=True) *b1 65 | w = b2_ - (b2_*b1).sum(-1,keepdims=True)*b1 66 | x = (v*w).sum(-1) 67 | y = (np.cross(b1, v)*w).sum(-1) 68 | #print(x.shape) 69 | 70 | return torch.as_tensor(np.arctan2(y, x),dtype=input1.dtype) 71 | @staticmethod 72 | def backward(ctx,grad_output): 73 | rij,rkj,rkl=ctx.rij,ctx.rkj,ctx.rkl 74 | rnk,rmj=ctx.nk,ctx.mj 75 | dkj=np.linalg.norm(rkj,axis=-1,keepdims=True) 76 | dmj=np.linalg.norm(rmj,axis=-1,keepdims=True) 77 | dnk=np.linalg.norm(rnk,axis=-1,keepdims=True) 78 | grad1=(dkj/((dmj*dmj)))*rmj 79 | grad4=-(dkj/((dnk*dnk)))*rnk 80 | 81 | grad2=( (rij*rkj).sum(-1,keepdims=True)/((dkj*dkj))-1 )*grad1 - (rkl*rkj).sum(-1,keepdims=True)/((dkj*dkj))*grad4 82 | grad3=( (rkl*rkj).sum(-1,keepdims=True)/((dkj*dkj))-1 )*grad4 - (rij*rkj).sum(-1,keepdims=True)/((dkj*dkj))*grad1 83 | 84 | grad1,grad2,grad3,grad4=torch.from_numpy(grad1),torch.from_numpy(grad2),torch.from_numpy(grad3),torch.from_numpy(grad4) 85 | return grad1*grad_output[:,None],grad2*grad_output[:,None],grad3*grad_output[:,None],grad4*grad_output[:,None] 86 | 87 | 88 | def dihedral(input1,input2,input3,input4): 89 | return torsion.apply(input1,input2,input3,input4) 90 | 91 | 92 | 93 | def angle(p0,p1,p2): 94 | # [b 3] 95 | b0=p0-p1 96 | b1=p2-p1 97 | 98 | b0=b0 / (torch.norm(b0,dim=-1,keepdim=True)+1e-08) 99 | b1=b1 / (torch.norm(b1,dim=-1,keepdim=True)+1e-08) 100 | recos=torch.sum(b0*b1,-1) 101 | recos=torch.clamp(recos,-0.9999,0.9999) 102 | #print(recos.shape) 103 | return torch.acos(recos) 104 | 105 | 106 | def rigidFrom3Points(x): 107 | x1,x2,x3 = x[:,0],x[:,1],x[:,2] 108 | v1=x3-x2 109 | v2=x1-x2 110 | e1=v1/(torch.norm(v1,dim=-1,keepdim=True) + 1e-06) 111 | u2=v2 - e1*(torch.einsum('bn,bn->b',e1,v2)[:,None]) 112 | e2 = u2/(torch.norm(u2,dim=-1,keepdim=True) + 1e-06) 113 | e3=torch.cross(e1,e2,dim=-1) 114 | 115 | return torch.stack([e1,e2,e3],dim=1) 116 | 117 | 118 | 119 | def Kabsch_rigid(bases,x1,x2,x3): 120 | ''' 121 | return the direction from to_q to from_p 122 | ''' 123 | the_dim=1 124 | to_q = torch.stack([x1,x2,x3],dim=the_dim) 125 | biasq=torch.mean(to_q,dim=the_dim,keepdim=True) 126 | q=to_q-biasq 127 | m = torch.einsum('bnz,bny->bzy',bases,q) 128 | u, s, v = torch.svd(m) 129 | vt = torch.transpose(v, 1, 2) 130 | det = torch.det(torch.matmul(u, vt)) 131 | det = det.view(-1, 1, 1) 132 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) 133 | r = torch.matmul(u, vt) 134 | return r,biasq.squeeze() 135 | 136 | 137 | def Get_base(seq,basenpy_standard): 138 | base_num = basenpy_standard.shape[1] 139 | basenpy = np.zeros([len(seq),base_num,3]) 140 | seqnpy = np.array(list(seq)) 141 | basenpy[seqnpy=='A']=basenpy_standard[0] 142 | basenpy[seqnpy=='a']=basenpy_standard[0] 143 | 144 | basenpy[seqnpy=='G']=basenpy_standard[1] 145 | basenpy[seqnpy=='g']=basenpy_standard[1] 146 | 147 | basenpy[seqnpy=='C']=basenpy_standard[2] 148 | basenpy[seqnpy=='c']=basenpy_standard[2] 149 | 150 | basenpy[seqnpy=='U']=basenpy_standard[3] 151 | basenpy[seqnpy=='u']=basenpy_standard[3] 152 | 153 | basenpy[seqnpy=='T']=basenpy_standard[3] 154 | basenpy[seqnpy=='t']=basenpy_standard[3] 155 | return torch.from_numpy(basenpy).double() -------------------------------------------------------------------------------- /PotentialFold/rigid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def C4P_C3P(bp): 5 | if bp in ['A','a']: 6 | return torch.FloatTensor([0.4219, 0.7000, 1.2829]) 7 | if bp in ['G','g']: 8 | return torch.FloatTensor([0.4240, 0.6989, 1.2829]) 9 | if bp in ['C','c']: 10 | return torch.FloatTensor([0.4567, 0.6948, 1.2731]) 11 | if bp in ['U','u']: 12 | return torch.FloatTensor([0.4556, 0.6952, 1.2736]) 13 | 14 | 15 | 16 | def C3P_O3P(bp): 17 | if bp in ['A','a']: 18 | return torch.FloatTensor([0.3506, 1.3587, -0.1989]) 19 | if bp in ['G','g']: 20 | return torch.FloatTensor([0.3519, 1.3587, -0.2005]) 21 | if bp in ['C','c']: 22 | return torch.FloatTensor([0.3747, 1.3505, -0.2116]) 23 | if bp in ['U','u']: 24 | return torch.FloatTensor([0.3747, 1.3497, -0.2093]) 25 | 26 | 27 | def base_table(): 28 | base_dict={} 29 | base_dict['atoms'] =['N1','C2','O2','N2','N3','N4','C4','O4','C5','C6','O6','N6','N7','C8','N9'] 30 | base_dict['a_mask']=torch.FloatTensor( [ 1 , 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1]) 31 | base_dict['g_mask']=torch.FloatTensor( [ 1 , 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1]) 32 | base_dict['c_mask']=torch.FloatTensor( [ 1 , 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0]) 33 | base_dict['u_mask']=torch.FloatTensor( [ 1 , 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0]) 34 | return base_dict 35 | 36 | 37 | def side_mask(seq): 38 | base_dict = base_table() 39 | masks=[] 40 | for bp in seq: 41 | masks.append( torch.cat( [torch.FloatTensor([1]*7),base_dict[bp.lower()+'_mask']] ,dim=0) ) 42 | 43 | return torch.stack(masks,dim=0) 44 | 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DRfold2: Ab initio RNA structure prediction with composite language model and denoised end-to-end learning 2 | 3 | [![Python 3.10+](https://img.shields.io/badge/python-3.10%2B-blue)](https://www.python.org/downloads/) 4 | [![PyTorch](https://img.shields.io/badge/PyTorch-1.11%2B-red)](https://pytorch.org/) 5 | 6 | ## Overview 7 | 8 | DRfold2 is a deep learning method for RNA structure prediction. At its core, DRfold2 utilizes the RNA Composite Language Model (RCLM), which provides enhanced full likelihood approximation capabilities to effectively capture co-evolutionary signals from unsupervised sequence data. 9 | 10 | ### Key Features 11 | 12 | - Advanced RNA Composite Language Model (RCLM) 13 | - End-to-end structure and geometry prediction 14 | - Optimization protocol 15 | 16 | 17 | ## Installation 18 | 19 | ### Prerequisites 20 | 21 | #### Minimal Requirements 22 | - Python (tested on 3.10.4, 3.11.4, and 3.11.7) 23 | - PyTorch (tested on 1.11, 2.01, and 2.21) 24 | - NumPy 25 | - SciPy 26 | - BioPython 27 | 28 | #### Optional Dependencies 29 | - OpenMM (required for structure refinement) - [Installation Guide](https://openmm.org/) 30 | 31 | ### Setup Instructions 32 | 33 | 1. Clone and navigate to the DRfold2 directory: 34 | ```bash 35 | git clone https://github.com/leeyang/DRfold2 36 | cd DRfold2 37 | ``` 38 | 39 | 2. Run the installation script: 40 | ```bash 41 | bash install.sh 42 | ``` 43 | This will download model weights ~1.3GB and install Arena. 44 | 45 | ## Usage 46 | 47 | ### Basic Structure Prediction 48 | 49 | For single model prediction: 50 | ```bash 51 | python DRfold_infer.py [input fasta file] [output_dir] 52 | ``` 53 | 54 | For multiple model prediction (up to 5 models): 55 | ```bash 56 | python DRfold_infer.py [input fasta file] [output_dir] 1 57 | ``` 58 | 59 | ### Parameters 60 | 61 | - `[input fasta file]`: Target sequence in FASTA format 62 | - `[output_dir]`: Directory for saving intermediate and final results 63 | - Final predictions will be saved as `[output dir]/relax/model_*.pdb` 64 | 65 | ### Structure Refinement (Optional) 66 | 67 | To further refine a predicted structure: 68 | ```bash 69 | python script/refine.py [input pdb] [output pdb] 70 | ``` 71 | 72 | ## Example Usage 73 | 74 | ```bash 75 | python DRfold_infer.py test/seq.fasta test/8fza_A/ 1 76 | ``` 77 | 78 | The final results can be found in `test/8fza_A/relax/`. 79 | 80 | **Note:** For long RNA sequences, you may want to clear intermediate results from `test/8fza_A/` to save space. 81 | 82 | ## Performance 83 | 84 | DRfold2 has been extensively tested on non-redundant test sets with various redundancy cut-offs, consistently demonstrating superior performance in: 85 | - 3D structure prediction 86 | - 2D base pair modeling 87 | - Co-evolutionary feature learning from unsupervised data 88 | 89 | ## Bug Reports and Issues 90 | 91 | Please report any issues or bugs on our [GitHub Issues page](https://github.com/leeyang/DRfold2/issues). 92 | 93 | ## Citation 94 | 95 | If you use DRfold2 in your research, please cite: 96 | ```bibtex 97 | @article{li2025drfold2, 98 | title={Ab initio RNA structure prediction with composite language model and denoised end-to-end learning}, 99 | author={Yang Li, Chenjie Feng, Xi Zhang, Yang Zhang.}, 100 | journal={}, 101 | year={2025} 102 | } 103 | ``` 104 | 105 | ## License 106 | 107 | Copyright (c) 2025 Yang Li 108 | 109 | Permission is hereby granted, free of charge, to any person obtaining a copy 110 | of this software and associated documentation files (the "Software"), to deal 111 | in the Software without restriction, including without limitation the rights 112 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 113 | copies of the Software, and to permit persons to whom the Software is 114 | furnished to do so, subject to the following conditions: 115 | 116 | The above copyright notice and this permission notice shall be included in all 117 | copies or substantial portions of the Software. 118 | 119 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 120 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 121 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 122 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 123 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 124 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 125 | SOFTWARE. 126 | -------------------------------------------------------------------------------- /cfg_95/EvoMSA.py: -------------------------------------------------------------------------------- 1 | from numpy import select 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import basic 6 | import math 7 | import os 8 | 9 | 10 | expdir=os.path.dirname(os.path.abspath(__file__)) 11 | lines = open(os.path.join(expdir,'newconfig')).readlines() 12 | attdrop = lines[0].strip().split()[-1] == '1' 13 | denoisee2e = lines[1].strip().split()[-1] == '1' 14 | ss_type = lines[2].strip().split()[-1] 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | def SignedSqrt( x): 24 | x = torch.sqrt(torch.relu(x)) - torch.sqrt(torch.relu(-x)) 25 | return x 26 | class MSARow(nn.Module): 27 | def __init__(self,m_dim,z_dim,N_head=8,c=8): 28 | super(MSARow,self).__init__() 29 | self.N_head = N_head 30 | self.c = c 31 | self.sq_c = 1/math.sqrt(c) 32 | self.norm1=nn.LayerNorm(m_dim) 33 | self.qlinear = basic.LinearNoBias(m_dim,N_head*c) 34 | self.klinear = basic.LinearNoBias(m_dim,N_head*c) 35 | self.vlinear = basic.LinearNoBias(m_dim,N_head*c) 36 | self.norm_z = nn.LayerNorm(z_dim) 37 | self.zlinear = basic.LinearNoBias(z_dim,N_head) 38 | self.glinear = basic.Linear(m_dim,N_head*c) 39 | self.olinear = basic.Linear(N_head*c,m_dim) 40 | 41 | def forward(self,m,z): 42 | # m : N L 32 43 | N,L,D = m.shape 44 | m = self.norm1(m) 45 | q = self.qlinear(m).reshape(N,L,self.N_head,self.c) #s rq h c 46 | k = self.klinear(m).reshape(N,L,self.N_head,self.c) #s rv h c 47 | v = self.vlinear(m).reshape(N,L,self.N_head,self.c) 48 | b = self.zlinear(self.norm_z(z)) 49 | g = torch.sigmoid(self.glinear(m)).reshape(N,L,self.N_head,self.c) 50 | att=torch.einsum('bqhc,bvhc->bqvh',q,k) * (self.sq_c) + b[None,:,:,:] # rq rv h 51 | att=F.softmax(SignedSqrt(att),dim=2) 52 | if attdrop: 53 | if self.training: 54 | att = basic.DropAtt(att,dim=2) 55 | o = torch.einsum('bqvh,bvhc->bqhc',att,v) * g 56 | m_ = self.olinear(o.reshape(N,L,-1)) 57 | return m_ 58 | 59 | class MSACol(nn.Module): 60 | def __init__(self,m_dim,N_head=8,c=8): 61 | super(MSACol,self).__init__() 62 | self.N_head = N_head 63 | self.c = c 64 | self.sq_c = 1/math.sqrt(c) 65 | self.norm1=nn.LayerNorm(m_dim) 66 | self.qlinear = basic.LinearNoBias(m_dim,N_head*c) 67 | self.klinear = basic.LinearNoBias(m_dim,N_head*c) 68 | self.vlinear = basic.LinearNoBias(m_dim,N_head*c) 69 | 70 | self.glinear = basic.Linear(m_dim,N_head*c) 71 | self.olinear = basic.Linear(N_head*c,m_dim) 72 | 73 | def forward(self,m): 74 | # m : N L 32 75 | N,L,D = m.shape 76 | m = self.norm1(m) 77 | q = self.qlinear(m).reshape(N,L,self.N_head,self.c) #s rq h c 78 | k = self.klinear(m).reshape(N,L,self.N_head,self.c) #s rv h c 79 | v = self.vlinear(m).reshape(N,L,self.N_head,self.c) 80 | 81 | g = torch.sigmoid(self.glinear(m)).reshape(N,L,self.N_head,self.c) 82 | 83 | att=torch.einsum('slhc,tlhc->stlh',q,k) * (self.sq_c) # rq rv h 84 | att=F.softmax(SignedSqrt(att),dim=1) 85 | if attdrop: 86 | if self.training: 87 | att = basic.DropAtt(att,dim=1) 88 | o = torch.einsum('stlh,tlhc->slhc',att,v) * g 89 | m_ = self.olinear(o.reshape(N,L,-1)) 90 | return m_ 91 | 92 | class MSATrans(nn.Module): 93 | def __init__(self,m_dim,c_expand=2): 94 | super(MSATrans,self).__init__() 95 | self.c_expand=4 96 | self.m_dim=m_dim 97 | self.norm=nn.LayerNorm(m_dim) 98 | self.linear1 = basic.Linear(m_dim,m_dim*c_expand) 99 | self.linear2 = basic.Linear(m_dim*c_expand,m_dim) 100 | def forward(self,m): 101 | m = self.norm(m) 102 | m = self.linear1(m) 103 | m = self.linear2(F.relu(m)) 104 | return m 105 | 106 | class MSAOPM(nn.Module): 107 | def __init__(self,m_dim,z_dim,c=12): 108 | super(MSAOPM,self).__init__() 109 | self.m_dim=m_dim 110 | self.c=c 111 | self.norm=nn.LayerNorm(m_dim) 112 | self.linear1=basic.Linear(m_dim,c) 113 | self.linear2=basic.Linear(m_dim,c) 114 | self.linear3=basic.Linear(c*c,z_dim) 115 | def forward(self,m): 116 | N,L,D=m.shape 117 | o=self.norm(m) 118 | a=self.linear2(o) 119 | b=self.linear1(o) 120 | o = torch.einsum('nia,njb->nijab',a,b).mean(dim=0) 121 | o = self.linear3(o.reshape(L,L,-1)) 122 | return o 123 | 124 | 125 | 126 | 127 | 128 | 129 | if __name__ == "__main__": 130 | N=10 131 | L=30 132 | m_dim=16 133 | z_dim=8 134 | m=torch.rand(N,L,m_dim) 135 | z=torch.rand(L,L,z_dim) 136 | msarow=MSARow(m_dim,z_dim) 137 | msacol=MSACol(m_dim) 138 | msatrans=MSATrans(m_dim) 139 | msaopm=MSAOPM(m_dim,z_dim) 140 | y=msaopm(m) 141 | print(y.shape) -------------------------------------------------------------------------------- /cfg_95/Evoformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import basic,EvoPair,EvoMSA 5 | import math,sys 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | 9 | class EvoBlock(nn.Module): 10 | def __init__(self,m_dim,z_dim,docheck=False): 11 | super(EvoBlock,self).__init__() 12 | self.msa_row=EvoMSA.MSARow(m_dim,z_dim) 13 | self.msa_col=EvoMSA.MSACol(m_dim) 14 | self.msa_trans=EvoMSA.MSATrans(m_dim) 15 | 16 | self.msa_opm=EvoMSA.MSAOPM(m_dim,z_dim) 17 | 18 | self.pair_triout=EvoPair.TriOut(z_dim) 19 | self.pair_triin =EvoPair.TriIn(z_dim) 20 | self.pair_tristart=EvoPair.TriAttStart(z_dim) 21 | self.pair_triend =EvoPair.TriAttEnd(z_dim) 22 | self.pair_trans = EvoPair.PairTrans(z_dim) 23 | self.docheck=docheck 24 | if docheck: 25 | print('will do checkpoint') 26 | 27 | def layerfunc_msa_row(self,m,z): 28 | return self.msa_row(m,z) + m 29 | def layerfunc_msa_col(self,m): 30 | return self.msa_col(m) + m 31 | def layerfunc_msa_trans(self,m): 32 | return self.msa_trans(m) + m 33 | def layerfunc_msa_opm(self,m,z): 34 | return self.msa_opm(m) + z 35 | 36 | def layerfunc_pair_triout(self,z): 37 | return self.pair_triout(z) + z 38 | def layerfunc_pair_triin(self,z): 39 | return self.pair_triin(z) + z 40 | def layerfunc_pair_tristart(self,z): 41 | return self.pair_tristart(z) + z 42 | def layerfunc_pair_triend(self,z): 43 | return self.pair_triend(z) + z 44 | def layerfunc_pair_trans(self,z): 45 | return self.pair_trans(z) + z 46 | def forward(self,m,z): 47 | if True: 48 | m = m + self.msa_row(m,z) 49 | m = m + self.msa_col(m) 50 | m = m + self.msa_trans(m) 51 | z = z + self.msa_opm(m) 52 | z = z + self.pair_triout(z) 53 | z = z + self.pair_triin(z) 54 | z = z + self.pair_tristart(z) 55 | z = z + self.pair_triend(z) 56 | z = z + self.pair_trans(z) 57 | return m,z 58 | 59 | 60 | 61 | class Evoformer(nn.Module): 62 | def __init__(self,m_dim,z_dim,docheck=False): 63 | super(Evoformer,self).__init__() 64 | self.layers=[16] 65 | self.docheck=docheck 66 | if docheck: 67 | pass 68 | #print('will do checkpoint') 69 | self.evos=nn.ModuleList([EvoBlock(m_dim,z_dim,True) for i in range(self.layers[0])]) 70 | 71 | def layerfunc(self,layermodule,m,z): 72 | m_,z_=layermodule(m,z) 73 | return m_,z_ 74 | 75 | 76 | 77 | def forward_n(self,m,z,starti,endi): 78 | for i in range(starti,endi): 79 | #print(i) 80 | m,z=self.evos[i](m,z) 81 | return m,z 82 | def forward(self,m,z): 83 | 84 | 85 | m,z = checkpoint(self.forward_n,m,z,0,3) 86 | m,z = checkpoint(self.forward_n,m,z,3,6) 87 | m,z = checkpoint(self.forward_n,m,z,6,10) 88 | m,z = checkpoint(self.forward_n,m,z,10,13) 89 | m,z = checkpoint(self.forward_n,m,z,13,16) 90 | return m,z 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | if __name__ == "__main__": 99 | N=10 100 | L=30 101 | m_dim=16 102 | z_dim=8 103 | m=torch.rand(N,L,m_dim) 104 | z=torch.rand(L,L,z_dim) 105 | model = Evoformer(m_dim,z_dim) 106 | m,z=model(m,z) 107 | print(model.parameters()) 108 | for param in model.parameters(): 109 | print(type(param), param.size()) 110 | print(m.shape,z.shape) -------------------------------------------------------------------------------- /cfg_95/IPA.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import basic 6 | import math 7 | import os 8 | expdir=os.path.dirname(os.path.abspath(__file__)) 9 | lines = open(os.path.join(expdir,'newconfig')).readlines() 10 | attdrop = lines[0].strip().split()[-1] == '1' 11 | denoisee2e = lines[1].strip().split()[-1] == '1' 12 | ss_type = lines[2].strip().split()[-1] 13 | class InvariantPointAttention(nn.Module): 14 | def __init__(self,dim_in,dim_z,N_head=8,c=16,N_query=4,N_p_values=6,) -> None: 15 | super(InvariantPointAttention,self).__init__() 16 | self.dim_in=dim_in 17 | self.dim_z=dim_z 18 | self.N_head =N_head 19 | self.c=c 20 | self.c_squ = 1.0/math.sqrt(c) 21 | self.W_c = math.sqrt(2.0/(9*N_query)) 22 | self.W_L = math.sqrt(1.0/3) 23 | self.N_query=N_query 24 | self.N_p_values=N_p_values 25 | self.liner_nb_q1=basic.LinearNoBias(dim_in,self.c*N_head) 26 | self.liner_nb_k1=basic.LinearNoBias(dim_in,self.c*N_head) 27 | self.liner_nb_v1=basic.LinearNoBias(dim_in,self.c*N_head) 28 | 29 | self.liner_nb_q2=basic.LinearNoBias(dim_in,N_head*N_query*3) 30 | self.liner_nb_k2=basic.LinearNoBias(dim_in,N_head*N_query*3) 31 | 32 | self.liner_nb_v3=basic.LinearNoBias(dim_in,N_head*N_p_values*3) 33 | 34 | self.liner_nb_z=basic.LinearNoBias(dim_z,N_head) 35 | self.lastlinear1=basic.Linear(N_head*dim_z,dim_in) 36 | self.lastlinear2=basic.Linear(N_head*c,dim_in) 37 | self.lastlinear3=basic.Linear(N_head*N_p_values*3,dim_in) 38 | self.gama = nn.ParameterList([nn.Parameter(torch.zeros(N_head))]) 39 | self.cos_f=nn.CosineSimilarity(dim=-1) 40 | 41 | def forward(self,s,z,rot,trans): 42 | L=s.shape[0] 43 | q1=self.liner_nb_q1(s).reshape(L,self.N_head,self.c) # Lq, 44 | k1=self.liner_nb_k1(s).reshape(L,self.N_head,self.c) 45 | v1=self.liner_nb_v1(s).reshape(L,self.N_head,self.c) # lv,h,c 46 | 47 | attmap=torch.einsum('ihc,jhc->ijh',q1,k1) * self.c_squ # Lq,Lk_v,h 48 | bias_z=self.liner_nb_z(z) # L L h 49 | 50 | q2 = self.liner_nb_q2(s).reshape(L,self.N_head,self.N_query,3) 51 | k2 = self.liner_nb_k2(s).reshape(L,self.N_head,self.N_query,3) 52 | 53 | v3 = self.liner_nb_v3(s).reshape(L,self.N_head,self.N_p_values,3) 54 | 55 | q2 = basic.IPA_transform(q2,rot,trans) # Lq,self.N_head,self.N_query,3 56 | k2 = basic.IPA_transform(k2,rot,trans) # Lk,self.N_head,self.N_query,3 57 | 58 | dismap=((q2[:,None,:,:,:] - k2[None,:,:,:,:])**2).sum([3,4]) ## Lq,Lk, self.N_head, 59 | #dismap=dismap - (self.cos_f(q2[:,None,:,:,:] , k2[None,:,:,:,:])).sum(3) 60 | attmap = attmap + bias_z - F.softplus(self.gama[0])[None,None,:]*dismap*self.W_c*0.5 61 | #print(torch.max(attmap*self.W_L),torch.min(attmap)*self.W_L) 62 | #attmap = F.softmax( torch.clamp(attmap*self.W_L,-5,5),dim=1 ) # Lk dim, [Lq,Lk, self.N_head] 63 | 64 | attmap = F.softmax( attmap*self.W_L,dim=1 ) # Lk dim, [Lq,Lk, self.N_head] 65 | if attdrop: 66 | if self.training: 67 | attmap = basic.DropAtt(attmap,dim=1) 68 | o1 = (attmap[:,:,:,None] * z[:,:,None,:]).sum(1) # Lq, N_head, c_z 69 | o2 = torch.einsum('abc,dab->dbc',v1,attmap) # Lq, N_head, c 70 | o3 = basic.IPA_transform(v3,rot,trans) # Lv, h, p* ,3 71 | o3 = basic.IPA_inverse_transform( torch.einsum('vhpt,gvh->ghpt',o3,attmap),rot,trans) #Lv, h, p* ,3 72 | 73 | return self.lastlinear1(o1.reshape(L,-1)) + self.lastlinear2(o2.reshape(L,-1)) + self.lastlinear3(o3.reshape(L,-1)) 74 | 75 | 76 | 77 | 78 | 79 | 80 | if __name__ == "__main__": 81 | dim_in,dim_z = 8,4 82 | L = 10 83 | ipa = InvariantPointAttention(dim_in,dim_z) 84 | s=torch.rand(L,dim_in) 85 | z=torch.rand(L,L,dim_z) 86 | rot=(torch.eye(3)[None,:,:]).repeat(L,1,1) 87 | trans=torch.rand(L,3) 88 | 89 | out=ipa(s,z,rot,trans) 90 | print(out) 91 | print(out.shape) 92 | 93 | 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /cfg_95/RNALM2/EvoMSA.py: -------------------------------------------------------------------------------- 1 | from numpy import select 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from RNALM2 import basic 6 | import math 7 | 8 | def SignedSqrt( x): 9 | x = torch.sqrt(torch.relu(x)) - torch.sqrt(torch.relu(-x)) 10 | return x 11 | class MSARow(nn.Module): 12 | def __init__(self,m_dim,z_dim,N_head=8,c=8): 13 | super(MSARow,self).__init__() 14 | self.N_head = N_head 15 | self.c = c 16 | self.sq_c = 1/math.sqrt(c) 17 | self.norm1=nn.LayerNorm(m_dim) 18 | self.qlinear = basic.LinearNoBias(m_dim,N_head*c) 19 | self.klinear = basic.LinearNoBias(m_dim,N_head*c) 20 | self.vlinear = basic.LinearNoBias(m_dim,N_head*c) 21 | self.norm_z = nn.LayerNorm(z_dim) 22 | self.zlinear = basic.LinearNoBias(z_dim,N_head) 23 | self.glinear = basic.Linear(m_dim,N_head*c) 24 | self.olinear = basic.Linear(N_head*c,m_dim) 25 | 26 | def forward(self,m,z): 27 | # m : N L 32 28 | N,L,D = m.shape 29 | m = self.norm1(m) 30 | q = self.qlinear(m).reshape(N,L,self.N_head,self.c) #s rq h c 31 | k = self.klinear(m).reshape(N,L,self.N_head,self.c) #s rv h c 32 | v = self.vlinear(m).reshape(N,L,self.N_head,self.c) 33 | b = self.zlinear(self.norm_z(z)) 34 | g = torch.sigmoid(self.glinear(m)).reshape(N,L,self.N_head,self.c) 35 | att=torch.einsum('bqhc,bvhc->bqvh',q,k) * (self.sq_c) + b[None,:,:,:] # rq rv h 36 | att=F.softmax(SignedSqrt(att),dim=2) 37 | o = torch.einsum('bqvh,bvhc->bqhc',att,v) * g 38 | m_ = self.olinear(o.reshape(N,L,-1)) 39 | return m_ 40 | 41 | class MSACol(nn.Module): 42 | def __init__(self,m_dim,N_head=8,c=8): 43 | super(MSACol,self).__init__() 44 | self.N_head = N_head 45 | self.c = c 46 | self.sq_c = 1/math.sqrt(c) 47 | self.norm1=nn.LayerNorm(m_dim) 48 | self.qlinear = basic.LinearNoBias(m_dim,N_head*c) 49 | self.klinear = basic.LinearNoBias(m_dim,N_head*c) 50 | self.vlinear = basic.LinearNoBias(m_dim,N_head*c) 51 | 52 | self.glinear = basic.Linear(m_dim,N_head*c) 53 | self.olinear = basic.Linear(N_head*c,m_dim) 54 | 55 | def forward(self,m): 56 | # m : N L 32 57 | N,L,D = m.shape 58 | m = self.norm1(m) 59 | q = self.qlinear(m).reshape(N,L,self.N_head,self.c) #s rq h c 60 | k = self.klinear(m).reshape(N,L,self.N_head,self.c) #s rv h c 61 | v = self.vlinear(m).reshape(N,L,self.N_head,self.c) 62 | 63 | g = torch.sigmoid(self.glinear(m)).reshape(N,L,self.N_head,self.c) 64 | 65 | att=torch.einsum('slhc,tlhc->stlh',q,k) * (self.sq_c) # rq rv h 66 | att=F.softmax(SignedSqrt(att),dim=1) 67 | o = torch.einsum('stlh,tlhc->slhc',att,v) * g 68 | m_ = self.olinear(o.reshape(N,L,-1)) 69 | return m_ 70 | 71 | class MSATrans(nn.Module): 72 | def __init__(self,m_dim,c_expand=2): 73 | super(MSATrans,self).__init__() 74 | self.c_expand=4 75 | self.m_dim=m_dim 76 | self.norm=nn.LayerNorm(m_dim) 77 | self.linear1 = basic.Linear(m_dim,m_dim*c_expand) 78 | self.linear2 = basic.Linear(m_dim*c_expand,m_dim) 79 | def forward(self,m): 80 | m = self.norm(m) 81 | m = self.linear1(m) 82 | m = self.linear2(F.relu(m)) 83 | return m 84 | 85 | class MSAOPM(nn.Module): 86 | def __init__(self,m_dim,z_dim,c=12): 87 | super(MSAOPM,self).__init__() 88 | self.m_dim=m_dim 89 | self.c=c 90 | self.norm=nn.LayerNorm(m_dim) 91 | self.linear1=basic.Linear(m_dim,c) 92 | self.linear2=basic.Linear(m_dim,c) 93 | self.linear3=basic.Linear(c*c,z_dim) 94 | def forward(self,m): 95 | N,L,D=m.shape 96 | o=self.norm(m) 97 | a=self.linear2(o) 98 | b=self.linear1(o) 99 | o = torch.einsum('nia,njb->nijab',a,b).mean(dim=0) 100 | o = self.linear3(o.reshape(L,L,-1)) 101 | return o 102 | 103 | 104 | 105 | 106 | 107 | 108 | if __name__ == "__main__": 109 | N=10 110 | L=30 111 | m_dim=16 112 | z_dim=8 113 | m=torch.rand(N,L,m_dim) 114 | z=torch.rand(L,L,z_dim) 115 | msarow=MSARow(m_dim,z_dim) 116 | msacol=MSACol(m_dim) 117 | msatrans=MSATrans(m_dim) 118 | msaopm=MSAOPM(m_dim,z_dim) 119 | y=msaopm(m) 120 | print(y.shape) -------------------------------------------------------------------------------- /cfg_95/RNALM2/EvoPair.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from RNALM2 import basic 5 | import math 6 | 7 | def SignedSqrt( x): 8 | x = torch.sqrt(torch.relu(x)) - torch.sqrt(torch.relu(-x)) 9 | return x 10 | class TriOut(nn.Module): 11 | def __init__(self,z_dim,c=32): 12 | super(TriOut,self).__init__() 13 | self.z_dim = z_dim 14 | self.norm =nn.LayerNorm(z_dim) 15 | self.onorm =nn.LayerNorm(c) 16 | self.alinear=basic.Linear(z_dim,c) 17 | self.blinear=basic.Linear(z_dim,c) 18 | self.aglinear=basic.Linear(z_dim,c) 19 | self.bglinear=basic.Linear(z_dim,c) 20 | self.glinear =basic.Linear(z_dim,z_dim) 21 | self.olinear=basic.Linear(c,z_dim) 22 | 23 | def forward(self,z): 24 | z = self.norm(z) 25 | a = self.alinear(z) * torch.sigmoid(self.aglinear(z)) 26 | b = self.blinear(z) * torch.sigmoid(self.bglinear(z)) 27 | o = torch.einsum('ilc,jlc->ijc',a,b) 28 | o = self.onorm(o) 29 | o = self.olinear(o) 30 | o = o * torch.sigmoid(self.glinear(z)) 31 | return o 32 | 33 | class TriIn(nn.Module): 34 | def __init__(self,z_dim,c=32): 35 | super(TriIn,self).__init__() 36 | self.z_dim = z_dim 37 | self.norm =nn.LayerNorm(z_dim) 38 | self.onorm =nn.LayerNorm(c) 39 | self.alinear=basic.Linear(z_dim,c) 40 | self.blinear=basic.Linear(z_dim,c) 41 | self.aglinear=basic.Linear(z_dim,c) 42 | self.bglinear=basic.Linear(z_dim,c) 43 | self.glinear =basic.Linear(z_dim,z_dim) 44 | self.olinear=basic.Linear(c,z_dim) 45 | 46 | def forward(self,z): 47 | z = self.norm(z) 48 | a = self.alinear(z) * torch.sigmoid(self.aglinear(z)) 49 | b = self.blinear(z) * torch.sigmoid(self.bglinear(z)) 50 | o = torch.einsum('lic,ljc->ijc',a,b) 51 | o = self.onorm(o) 52 | o = self.olinear(o) 53 | o = o * torch.sigmoid(self.glinear(z)) 54 | return o 55 | 56 | 57 | class TriAttStart(nn.Module): 58 | def __init__(self,z_dim,N_head=4,c=8): 59 | super(TriAttStart,self).__init__() 60 | self.z_dim = z_dim 61 | self.N_head = N_head 62 | self.c = c 63 | self.sq_c = 1/math.sqrt(c) 64 | self.norm=nn.LayerNorm(z_dim) 65 | self.qlinear=basic.Linear(z_dim,c*N_head) 66 | self.klinear=basic.Linear(z_dim,c*N_head) 67 | self.vlinear=basic.Linear(z_dim,c*N_head) 68 | self.blinear=basic.Linear(z_dim,N_head) 69 | self.glinear=basic.Linear(z_dim,c*N_head) 70 | self.olinear=basic.Linear(c*N_head,z_dim) 71 | 72 | def forward(self,z_): 73 | L1,L2,D=z_.shape 74 | z = self.norm(z_) 75 | q = self.qlinear(z).reshape(L1,L2,self.N_head,self.c) 76 | k = self.klinear(z).reshape(L1,L2,self.N_head,self.c) 77 | v = self.vlinear(z).reshape(L1,L2,self.N_head,self.c) 78 | b = self.blinear(z) 79 | att = torch.einsum('blhc,bkhc->blkh',q,k)*self.sq_c + b[None,:,:,:] 80 | att = F.softmax(SignedSqrt(att),dim=2) 81 | o = torch.einsum('blkh,bkhc->blhc',att,v) 82 | o = (torch.sigmoid(self.glinear(z).reshape(L1,L2,self.N_head,self.c)) * o).reshape(L1,L2,-1) 83 | o = self.olinear(o) 84 | return o 85 | 86 | class TriAttEnd(nn.Module): 87 | def __init__(self,z_dim,N_head=4,c=8): 88 | super(TriAttEnd,self).__init__() 89 | self.z_dim = z_dim 90 | self.N_head = N_head 91 | self.c = c 92 | self.sq_c = 1/math.sqrt(c) 93 | self.norm=nn.LayerNorm(z_dim) 94 | self.qlinear=basic.Linear(z_dim,c*N_head) 95 | self.klinear=basic.Linear(z_dim,c*N_head) 96 | self.vlinear=basic.Linear(z_dim,c*N_head) 97 | self.blinear=basic.Linear(z_dim,N_head) 98 | self.glinear=basic.Linear(z_dim,c*N_head) 99 | self.olinear=basic.Linear(c*N_head,z_dim) 100 | 101 | def forward(self,z_): 102 | L1,L2,D=z_.shape 103 | z = self.norm(z_) 104 | q = self.qlinear(z).reshape(L1,L2,self.N_head,self.c) 105 | k = self.klinear(z).reshape(L1,L2,self.N_head,self.c) 106 | v = self.vlinear(z).reshape(L1,L2,self.N_head,self.c) 107 | b = self.blinear(z) 108 | att = torch.einsum('blhc,kbhc->blkh',q,k)*self.sq_c + b[None,:,:,:].permute(0,2,1,3) 109 | att = F.softmax(SignedSqrt(att),dim=2) 110 | o = torch.einsum('blkh,klhc->blhc',att,v) 111 | o = (torch.sigmoid(self.glinear(z).reshape(L1,L2,self.N_head,self.c)) * o).reshape(L1,L2,-1) 112 | o = self.olinear(o) 113 | return o 114 | def forward2(self,z_): 115 | z = z_.permute(1,0,2) 116 | L1,L2,D=z_.shape 117 | z = self.norm(z_) 118 | q = self.qlinear(z).reshape(L1,L2,self.N_head,self.c) 119 | k = self.klinear(z).reshape(L1,L2,self.N_head,self.c) 120 | v = self.vlinear(z).reshape(L1,L2,self.N_head,self.c) 121 | b = self.blinear(z) 122 | att = torch.einsum('blhc,bkhc->blkh',q,k)*self.sq_c + b[None,:,:,:] 123 | att = F.softmax(att,dim=2) 124 | o = torch.einsum('blkh,bkhc->blhc',att,v) 125 | o = (torch.sigmoid(self.glinear(z).reshape(L1,L2,self.N_head,self.c)) * o).reshape(L1,L2,-1) 126 | o = self.olinear(o) 127 | o = o.permute(1,0,2) 128 | return o 129 | class PairTrans(nn.Module): 130 | def __init__(self,z_dim,c_expand=2): 131 | super(PairTrans,self).__init__() 132 | self.z_dim=z_dim 133 | self.c_expand=c_expand 134 | self.norm = nn.LayerNorm(z_dim) 135 | self.linear1=basic.Linear(z_dim,z_dim*c_expand) 136 | self.linear2=basic.Linear(z_dim*c_expand,z_dim) 137 | def forward(self,z): 138 | a = self.linear1(self.norm(z)) 139 | a = self.linear2(F.relu(a)) 140 | return a 141 | 142 | 143 | 144 | 145 | if __name__ == "__main__": 146 | N=10 147 | L=30 148 | m_dim=16 149 | z_dim=8 150 | m=torch.rand(N,L,m_dim) 151 | z=torch.rand(L,L,z_dim) 152 | 153 | tr1=TriAttEnd(z_dim) 154 | tr2=PairTrans(z_dim) 155 | y=tr1(z) 156 | y2=tr1.forward2(z) 157 | y3=tr2(z) 158 | print(y3.shape) 159 | 160 | -------------------------------------------------------------------------------- /cfg_95/RNALM2/Evoformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from RNALM2 import basic,EvoPair,EvoMSA 5 | import math,sys 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | 9 | class EvoBlock(nn.Module): 10 | def __init__(self,m_dim,z_dim,docheck=False): 11 | super(EvoBlock,self).__init__() 12 | N_head = 16 13 | c = 16 14 | self.msa_row=EvoMSA.MSARow(m_dim,z_dim,N_head,c) 15 | self.msa_col=EvoMSA.MSACol(m_dim,N_head,c) 16 | self.msa_trans=EvoMSA.MSATrans(m_dim) 17 | 18 | self.msa_opm=EvoMSA.MSAOPM(m_dim,z_dim) 19 | 20 | self.pair_triout=EvoPair.TriOut(z_dim,72) 21 | self.pair_triin =EvoPair.TriIn(z_dim,72) 22 | self.pair_tristart=EvoPair.TriAttStart(z_dim) 23 | self.pair_triend =EvoPair.TriAttEnd(z_dim) 24 | self.pair_trans = EvoPair.PairTrans(z_dim) 25 | self.docheck=docheck 26 | if docheck: 27 | print('will do checkpoint') 28 | 29 | def layerfunc_msa_row(self,m,z): 30 | return self.msa_row(m,z) + m 31 | def layerfunc_msa_col(self,m): 32 | return self.msa_col(m) + m 33 | def layerfunc_msa_trans(self,m): 34 | return self.msa_trans(m) + m 35 | def layerfunc_msa_opm(self,m,z): 36 | return self.msa_opm(m) + z 37 | 38 | def layerfunc_pair_triout(self,z): 39 | return self.pair_triout(z) + z 40 | def layerfunc_pair_triin(self,z): 41 | return self.pair_triin(z) + z 42 | def layerfunc_pair_tristart(self,z): 43 | return self.pair_tristart(z) + z 44 | def layerfunc_pair_triend(self,z): 45 | return self.pair_triend(z) + z 46 | def layerfunc_pair_trans(self,z): 47 | return self.pair_trans(z) + z 48 | def forward(self,m,z): 49 | if True: 50 | m = m + self.msa_row(m,z) 51 | m = m + self.msa_col(m) 52 | m = m + self.msa_trans(m) 53 | z = z + self.msa_opm(m) 54 | z = z + self.pair_triout(z) 55 | z = z + self.pair_triin(z) 56 | #z = z + self.pair_tristart(z) 57 | #z = z + self.pair_triend(z) 58 | z = z + self.pair_trans(z) 59 | return m,z 60 | else: 61 | m=checkpoint(self.layerfunc_msa_row,m,z) 62 | m=checkpoint(self.layerfunc_msa_col,m) 63 | m=checkpoint(self.layerfunc_msa_trans,m) 64 | z=checkpoint(self.layerfunc_msa_opm,m,z) 65 | 66 | z=checkpoint(self.layerfunc_pair_triout,z) 67 | z=checkpoint(self.layerfunc_pair_triin,z) 68 | z=checkpoint(self.layerfunc_pair_tristart,z) 69 | z=checkpoint(self.layerfunc_pair_triend,z) 70 | z=checkpoint(self.layerfunc_pair_trans,z) 71 | 72 | return m,z 73 | 74 | 75 | class Evoformer(nn.Module): 76 | def __init__(self,m_dim,z_dim,N_elayers=12,docheck=False): 77 | super(Evoformer,self).__init__() 78 | self.layers=[N_elayers] 79 | self.docheck=docheck 80 | if docheck: 81 | pass 82 | #print('will do checkpoint') 83 | self.evos=nn.ModuleList([EvoBlock(m_dim,z_dim,True) for i in range(self.layers[0])]) 84 | 85 | def layerfunc(self,layermodule,m,z): 86 | m_,z_=layermodule(m,z) 87 | return m_,z_ 88 | 89 | 90 | def forward(self,m,z): 91 | 92 | if True: 93 | #print('will do checkpoint in Evoformer') 94 | for i in range(self.layers[0]): 95 | #m,z=checkpoint(self.layerfunc,self.evos[i],m,z) 96 | m,z=self.evos[i](m,z) 97 | return m,z 98 | else: 99 | for i in range(self.layers[0]): 100 | m,z=self.evos[i](m,z) 101 | 102 | return m,z 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | if __name__ == "__main__": 113 | N=10 114 | L=30 115 | m_dim=16 116 | z_dim=8 117 | m=torch.rand(N,L,m_dim) 118 | z=torch.rand(L,L,z_dim) 119 | model = Evoformer(m_dim,z_dim) 120 | m,z=model(m,z) 121 | print(model.parameters()) 122 | for param in model.parameters(): 123 | print(type(param), param.size()) 124 | print(m.shape,z.shape) -------------------------------------------------------------------------------- /cfg_95/RNALM2/Model.py: -------------------------------------------------------------------------------- 1 | from numpy import select 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from RNALM2 import basic,Evoformer 6 | import math,sys 7 | from torch.utils.checkpoint import checkpoint 8 | import numpy as np 9 | 10 | 11 | def one_d(idx_, d, max_len=2056*8): 12 | idx = idx_[None] 13 | K = torch.arange(d//2).to(idx.device) 14 | sin_e = torch.sin(idx[..., None] * math.pi / (max_len**(2*K[None]/d))).to(idx.device) 15 | cos_e = torch.cos(idx[..., None] * math.pi / (max_len**(2*K[None]/d))).to(idx.device) 16 | return torch.cat([sin_e, cos_e], axis=-1)[0] 17 | 18 | 19 | 20 | 21 | 22 | 23 | class RNAembedding(nn.Module): 24 | def __init__(self,cfg): 25 | super(RNAembedding,self).__init__() 26 | self.s_in_dim=cfg['s_in_dim'] 27 | self.z_in_dim=cfg['z_in_dim'] 28 | self.s_dim=cfg['s_dim'] 29 | self.z_dim=cfg['z_dim'] 30 | self.qlinear =basic.Linear(self.s_in_dim+1,self.z_dim) 31 | self.klinear =basic.Linear(self.s_in_dim+1,self.z_dim) 32 | self.slinear =basic.Linear(self.s_in_dim+1,self.s_dim) 33 | self.zlinear =basic.Linear(self.z_in_dim+1,self.z_dim) 34 | 35 | self.poslinears = basic.Linear(64,self.s_dim) 36 | self.poslinearz = basic.Linear(64,self.z_dim) 37 | def forward(self,in_dict): 38 | # msa N L D, seq L D 39 | # mask: maksing, L, 1 means masked 40 | # aa: L x s_in_dim 41 | # ss: L x L x 2 42 | # idx: L (LongTensor) 43 | L = in_dict['aa'].shape[0] 44 | aamask = in_dict['mask'][:,None] 45 | zmask = in_dict['mask'][:,None] + in_dict['mask'][None,:] 46 | zmask[zmask>0.5]=1 47 | zmask = zmask[...,None] 48 | s = torch.cat([aamask,(1-aamask)*in_dict['aa']],dim=-1) 49 | sq=self.qlinear(s) 50 | sk=self.klinear(s) 51 | z=sq[None,:,:]+sk[:,None,:] 52 | seq_idx = in_dict['idx'][None] 53 | relative_pos = seq_idx[:, :, None] - seq_idx[:, None, :] 54 | relative_pos = relative_pos.reshape([1, -1]) 55 | relative_pos =one_d(relative_pos,64) 56 | z = z + self.poslinearz( relative_pos.reshape([1, L, L, -1])[0] ) 57 | 58 | s = self.slinear(s) + self.poslinears( one_d(in_dict['idx'], 64) ) 59 | 60 | return s,z 61 | 62 | 63 | class RNA2nd(nn.Module): 64 | def __init__(self,cfg): 65 | super(RNA2nd,self).__init__() 66 | self.s_in_dim=cfg['s_in_dim'] 67 | self.z_in_dim=cfg['z_in_dim'] 68 | self.s_dim=cfg['s_dim'] 69 | self.z_dim=cfg['z_dim'] 70 | self.N_elayers =cfg['N_elayers'] 71 | self.emb = RNAembedding(cfg) 72 | self.evmodel=Evoformer.Evoformer(self.s_dim,self.z_dim,self.N_elayers) 73 | self.seq_head = basic.Linear(self.s_dim,self.s_in_dim) 74 | self.joint_head = basic.Linear(self.z_dim,self.s_in_dim*self.s_in_dim) 75 | 76 | 77 | 78 | 79 | def embedding(self,in_dict): 80 | s,z = self.emb(in_dict) 81 | s,z = self.evmodel(s[None,...],z) 82 | return s[0],z 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /cfg_95/RNALM2/basic.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import random 5 | class Linear(nn.Module): 6 | def __init__(self,dim_in,dim_out): 7 | super(Linear,self).__init__() 8 | self.linear = nn.Linear(dim_in,dim_out) 9 | def forward(self,x): 10 | x = self.linear(x) 11 | return x 12 | 13 | 14 | class LinearNoBias(nn.Module): 15 | def __init__(self,dim_in,dim_out): 16 | super(LinearNoBias,self).__init__() 17 | self.linear = nn.Linear(dim_in,dim_out,bias=False) 18 | def forward(self,x): 19 | x = self.linear(x) 20 | return x 21 | 22 | 23 | 24 | def transform(k,rotation,translation): 25 | # K L x 3 26 | # rotation 27 | return torch.matmul(k,rotation) + translation 28 | 29 | 30 | def batch_transform(k,rotation,translation): 31 | # k: L 3 32 | # rotation: L 3 x 3 33 | # translation: L 3 34 | return torch.einsum('ba,bad->bd',k,rotation) + translation 35 | 36 | def batch_atom_transform(k,rotation,translation): 37 | # k: L N 3 38 | # rotation: L 3 x 3 39 | # translation: L 3 40 | return torch.einsum('bja,bad->bjd',k,rotation) + translation[:,None,:] 41 | 42 | def IPA_transform(k,rotation,translation): 43 | # k: L d1, d2, 3 44 | # rotation: L 3 x 3 45 | # translation: L 3 46 | return torch.einsum('bija,bad->bijd',k,rotation)+translation[:,None,None,:] 47 | 48 | def IPA_inverse_transform(k,rotation,translation): 49 | # k: L d1, d2, 3 50 | # rotation: L 3 x 3 51 | # translation: L 3 52 | return torch.einsum('bija,bad->bijd',k-translation[:,None,None,:],rotation.transpose(-1,-2)) 53 | 54 | def update_transform(t,tr,rotation,translation): 55 | return torch.einsum('bja,bad->bjd',t,rotation),torch.einsum('ba,bad->bd',tr,rotation) +translation 56 | 57 | 58 | def quat2rot(q,L): 59 | scale= ((q**2).sum(dim=-1,keepdim=True) +1) [:,:,None] 60 | u=torch.empty([L,3,3],device=q.device) 61 | u[:,0,0]=1*1+q[:,0]*q[:,0]-q[:,1]*q[:,1]-q[:,2]*q[:,2] 62 | u[:,0,1]=2*(q[:,0]*q[:,1]-1*q[:,2]) 63 | u[:,0,2]=2*(q[:,0]*q[:,2]+1*q[:,1]) 64 | u[:,1,0]=2*(q[:,0]*q[:,1]+1*q[:,2]) 65 | u[:,1,1]=1*1-q[:,0]*q[:,0]+q[:,1]*q[:,1]-q[:,2]*q[:,2] 66 | u[:,1,2]=2*(q[:,1]*q[:,2]-1*q[:,0]) 67 | u[:,2,0]=2*(q[:,0]*q[:,2]-1*q[:,1]) 68 | u[:,2,1]=2*(q[:,1]*q[:,2]+1*q[:,0]) 69 | u[:,2,2]=1*1-q[:,0]*q[:,0]-q[:,1]*q[:,1]+q[:,2]*q[:,2] 70 | return u/scale 71 | 72 | 73 | def rotation_x(sintheta,costheta,ones,zeros): 74 | # L x 1 75 | return torch.stack([torch.stack([ones, zeros, zeros]), 76 | torch.stack([zeros, costheta, sintheta]), 77 | torch.stack([zeros, -sintheta, costheta])]) 78 | def rotation_y(sintheta,costheta,ones,zeros): 79 | # L x 1 80 | return torch.stack([torch.stack([costheta, zeros, sintheta]), 81 | torch.stack([zeros, ones, zeros]), 82 | torch.stack([-sintheta, zeros, costheta])]) 83 | def rotation_z(sintheta,costheta,ones,zeros): 84 | # L x 1 85 | return torch.stack([torch.stack([costheta, sintheta, zeros]), 86 | torch.stack([-sintheta, costheta, zeros]), 87 | torch.stack([zeros, zeros, ones])]) 88 | def batch_rotation(k,rotation): 89 | # k: L 3 90 | # rotation: L 3 x 3 91 | # translation: L 3 92 | return torch.einsum('ba,bad->bd',k,rotation) 93 | 94 | def compute_cb(bl,sin_angle,cos_angle,sin_torsion,cos_torsion): 95 | L=bl.shape[0] 96 | ones=torch.ones(L,device=bl.device) 97 | zeros=torch.zeros(L,device=bl.device) 98 | cb=torch.stack([bl,zeros,zeros]).permute(1,0) 99 | rotz=rotation_z(sin_angle,cos_angle,ones,zeros).permute(2,0,1) 100 | rotx=rotation_x(sin_torsion,cos_torsion,ones,zeros).permute(2,0,1) 101 | cb=batch_rotation(cb,rotz) 102 | cb=batch_rotation(cb,rotx) 103 | return cb 104 | 105 | def rigidFrom3Points_(x1,x2,x3): 106 | v1=x3-x2 107 | v2=x1-x2 108 | e1=v1/(torch.norm(v1,dim=-1,keepdim=True) + 1e-03) 109 | u2=v2 - e1*(torch.einsum('bn,bn->b',e1,v2)[:,None]) 110 | e2 = u2/(torch.norm(u2,dim=-1,keepdim=True) + 1e-03) 111 | e3=torch.cross(e1,e2,dim=-1) 112 | 113 | return torch.stack([e1,e2,e3],dim=1),x2[:,:] 114 | def rigidFrom3Points(x1,x2,x3):# L 3 115 | the_dim=1 116 | x = torch.stack([x1,x2,x3],dim=the_dim) 117 | x_mean = torch.mean(x,dim=the_dim,keepdim=True) 118 | x = x - x_mean 119 | 120 | 121 | m = x.view(-1, 3, 3) 122 | u, s, v = torch.svd(m) 123 | vt = torch.transpose(v, 1, 2) 124 | det = torch.det(torch.matmul(u, vt)) 125 | det = det.view(-1, 1, 1) 126 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) 127 | r = torch.matmul(u, vt) 128 | return r,x_mean.squeeze() 129 | def Kabsch_rigid(bases,x1,x2,x3): 130 | ''' 131 | return the direction from to_q to from_p 132 | ''' 133 | the_dim=1 134 | to_q = torch.stack([x1,x2,x3],dim=the_dim) 135 | biasq=torch.mean(to_q,dim=the_dim,keepdim=True) 136 | q=to_q-biasq 137 | m = torch.einsum('bnz,bny->bzy',bases,q) 138 | u, s, v = torch.svd(m) 139 | vt = torch.transpose(v, 1, 2) 140 | det = torch.det(torch.matmul(u, vt)) 141 | det = det.view(-1, 1, 1) 142 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) 143 | r = torch.matmul(u, vt) 144 | return r,biasq.squeeze() 145 | def Generate_msa_mask(n,l): 146 | # 1: 15% mask out 147 | randommatrix=torch.rand(n,l) 148 | mask = (randommatrix <0.1).float() 149 | # 2 random a segment 150 | seqlength = int(l*0.1) 151 | sindex=round(random.random()*(l-seqlength)) 152 | endindex=min(l,sindex+seqlength) 153 | mask[:,sindex:endindex]=1 154 | return mask 155 | 156 | 157 | -------------------------------------------------------------------------------- /cfg_95/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeyang/DRfold2/54129443ed151032e3d647a279814c5cebdca630/cfg_95/__init__.py -------------------------------------------------------------------------------- /cfg_95/base.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeyang/DRfold2/54129443ed151032e3d647a279814c5cebdca630/cfg_95/base.npy -------------------------------------------------------------------------------- /cfg_95/basic.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import random 5 | class Linear(nn.Module): 6 | def __init__(self,dim_in,dim_out): 7 | super(Linear,self).__init__() 8 | self.linear = nn.Linear(dim_in,dim_out) 9 | def forward(self,x): 10 | x = self.linear(x) 11 | return x 12 | 13 | 14 | class LinearNoBias(nn.Module): 15 | def __init__(self,dim_in,dim_out): 16 | super(LinearNoBias,self).__init__() 17 | self.linear = nn.Linear(dim_in,dim_out,bias=False) 18 | def forward(self,x): 19 | x = self.linear(x) 20 | return x 21 | 22 | def DropAtt(x,dim,droprate = 0.25): 23 | shapes = x.shape 24 | L = shapes[dim] 25 | num_of_dim = len(shapes) 26 | randmask = torch.rand(L).to(x.device) 27 | themask = randmaskbd',k,rotation) + translation 49 | 50 | def batch_atom_transform(k,rotation,translation): 51 | # k: L N 3 52 | # rotation: L 3 x 3 53 | # translation: L 3 54 | return torch.einsum('bja,bad->bjd',k,rotation) + translation[:,None,:] 55 | 56 | def IPA_transform(k,rotation,translation): 57 | # k: L d1, d2, 3 58 | # rotation: L 3 x 3 59 | # translation: L 3 60 | return torch.einsum('bija,bad->bijd',k,rotation)+translation[:,None,None,:] 61 | 62 | def IPA_inverse_transform(k,rotation,translation): 63 | # k: L d1, d2, 3 64 | # rotation: L 3 x 3 65 | # translation: L 3 66 | return torch.einsum('bija,bad->bijd',k-translation[:,None,None,:],rotation.transpose(-1,-2)) 67 | 68 | def update_transform(t,tr,rotation,translation): 69 | return torch.einsum('bja,bad->bjd',t,rotation),torch.einsum('ba,bad->bd',tr,rotation) +translation 70 | 71 | 72 | def quat2rot(q,L): 73 | scale= ((q**2).sum(dim=-1,keepdim=True) +1) [:,:,None] 74 | u=torch.empty([L,3,3],device=q.device) 75 | u[:,0,0]=1*1+q[:,0]*q[:,0]-q[:,1]*q[:,1]-q[:,2]*q[:,2] 76 | u[:,0,1]=2*(q[:,0]*q[:,1]-1*q[:,2]) 77 | u[:,0,2]=2*(q[:,0]*q[:,2]+1*q[:,1]) 78 | u[:,1,0]=2*(q[:,0]*q[:,1]+1*q[:,2]) 79 | u[:,1,1]=1*1-q[:,0]*q[:,0]+q[:,1]*q[:,1]-q[:,2]*q[:,2] 80 | u[:,1,2]=2*(q[:,1]*q[:,2]-1*q[:,0]) 81 | u[:,2,0]=2*(q[:,0]*q[:,2]-1*q[:,1]) 82 | u[:,2,1]=2*(q[:,1]*q[:,2]+1*q[:,0]) 83 | u[:,2,2]=1*1-q[:,0]*q[:,0]-q[:,1]*q[:,1]+q[:,2]*q[:,2] 84 | return u/scale 85 | 86 | 87 | def rotation_x(sintheta,costheta,ones,zeros): 88 | # L x 1 89 | return torch.stack([torch.stack([ones, zeros, zeros]), 90 | torch.stack([zeros, costheta, sintheta]), 91 | torch.stack([zeros, -sintheta, costheta])]) 92 | def rotation_y(sintheta,costheta,ones,zeros): 93 | # L x 1 94 | return torch.stack([torch.stack([costheta, zeros, sintheta]), 95 | torch.stack([zeros, ones, zeros]), 96 | torch.stack([-sintheta, zeros, costheta])]) 97 | def rotation_z(sintheta,costheta,ones,zeros): 98 | # L x 1 99 | return torch.stack([torch.stack([costheta, sintheta, zeros]), 100 | torch.stack([-sintheta, costheta, zeros]), 101 | torch.stack([zeros, zeros, ones])]) 102 | def batch_rotation(k,rotation): 103 | # k: L 3 104 | # rotation: L 3 x 3 105 | # translation: L 3 106 | return torch.einsum('ba,bad->bd',k,rotation) 107 | 108 | def compute_cb(bl,sin_angle,cos_angle,sin_torsion,cos_torsion): 109 | L=bl.shape[0] 110 | ones=torch.ones(L,device=bl.device) 111 | zeros=torch.zeros(L,device=bl.device) 112 | cb=torch.stack([bl,zeros,zeros]).permute(1,0) 113 | rotz=rotation_z(sin_angle,cos_angle,ones,zeros).permute(2,0,1) 114 | rotx=rotation_x(sin_torsion,cos_torsion,ones,zeros).permute(2,0,1) 115 | cb=batch_rotation(cb,rotz) 116 | cb=batch_rotation(cb,rotx) 117 | return cb 118 | 119 | def rigidFrom3Points_(x1,x2,x3): 120 | v1=x3-x2 121 | v2=x1-x2 122 | e1=v1/(torch.norm(v1,dim=-1,keepdim=True) + 1e-03) 123 | u2=v2 - e1*(torch.einsum('bn,bn->b',e1,v2)[:,None]) 124 | e2 = u2/(torch.norm(u2,dim=-1,keepdim=True) + 1e-03) 125 | e3=torch.cross(e1,e2,dim=-1) 126 | 127 | return torch.stack([e1,e2,e3],dim=1),x2[:,:] 128 | def rigidFrom3Points(x1,x2,x3):# L 3 129 | the_dim=1 130 | x = torch.stack([x1,x2,x3],dim=the_dim) 131 | x_mean = torch.mean(x,dim=the_dim,keepdim=True) 132 | x = x - x_mean 133 | 134 | 135 | m = x.view(-1, 3, 3) 136 | u, s, v = torch.svd(m) 137 | vt = torch.transpose(v, 1, 2) 138 | det = torch.det(torch.matmul(u, vt)) 139 | det = det.view(-1, 1, 1) 140 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) 141 | r = torch.matmul(u, vt) 142 | return r,x_mean.squeeze() 143 | def Kabsch_rigid(bases,x1,x2,x3): 144 | ''' 145 | return the direction from to_q to from_p 146 | ''' 147 | the_dim=1 148 | to_q = torch.stack([x1,x2,x3],dim=the_dim) 149 | biasq=torch.mean(to_q,dim=the_dim,keepdim=True) 150 | q=to_q-biasq 151 | m = torch.einsum('bnz,bny->bzy',bases,q) 152 | u, s, v = torch.svd(m) 153 | vt = torch.transpose(v, 1, 2) 154 | det = torch.det(torch.matmul(u, vt)) 155 | det = det.view(-1, 1, 1) 156 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) 157 | r = torch.matmul(u, vt) 158 | return r,biasq.squeeze() 159 | def Generate_msa_mask(n,l): 160 | # 1: 15% mask out 161 | randommatrix=torch.rand(n,l) 162 | mask = (randommatrix <0.1).float() 163 | # 2 random a segment 164 | seqlength = int(l*0.1) 165 | sindex=round(random.random()*(l-seqlength)) 166 | endindex=min(l,sindex+seqlength) 167 | mask[:,sindex:endindex]=1 168 | return mask 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /cfg_95/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os,math 3 | import numpy as np 4 | from torch.utils.data import Dataset, DataLoader 5 | from random import sample 6 | from numpy import float32 7 | import random 8 | from scipy.spatial.distance import cdist 9 | from subprocess import Popen, PIPE, STDOUT 10 | expdir=os.path.dirname(os.path.abspath(__file__)) 11 | 12 | code_standard = { 13 | 'A':'A','G':'G','C':'C','U':'U','a':'A','g':'G','c':'C','u':'U','T':'U','t':'U' 14 | } 15 | expdir=os.path.dirname(os.path.abspath(__file__)) 16 | parentdir = os.path.dirname(expdir) 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | def parse_seq(inseq): 27 | seqnpy=np.zeros(len(inseq)) 28 | seq1=np.array(list(inseq)) 29 | seqnpy[seq1=='A']=1 30 | seqnpy[seq1=='G']=2 31 | seqnpy[seq1=='C']=3 32 | seqnpy[seq1=='U']=4 33 | seqnpy[seq1=='T']=4 34 | return seqnpy 35 | 36 | 37 | 38 | 39 | def Get_base(seq,basenpy_standard): 40 | basenpy = np.zeros([len(seq),3,3]) 41 | seqnpy = np.array(list(seq)) 42 | basenpy[seqnpy=='A']=basenpy_standard[0] 43 | basenpy[seqnpy=='a']=basenpy_standard[0] 44 | 45 | basenpy[seqnpy=='G']=basenpy_standard[1] 46 | basenpy[seqnpy=='g']=basenpy_standard[1] 47 | 48 | basenpy[seqnpy=='C']=basenpy_standard[2] 49 | basenpy[seqnpy=='c']=basenpy_standard[2] 50 | 51 | basenpy[seqnpy=='U']=basenpy_standard[3] 52 | basenpy[seqnpy=='u']=basenpy_standard[3] 53 | 54 | basenpy[seqnpy=='T']=basenpy_standard[3] 55 | basenpy[seqnpy=='t']=basenpy_standard[3] 56 | return basenpy 57 | 58 | 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /cfg_95/newconfig: -------------------------------------------------------------------------------- 1 | attdrop: 1 2 | denoisee2e: 1 3 | ss_type: attention -------------------------------------------------------------------------------- /cfg_95/test_modeldir.py: -------------------------------------------------------------------------------- 1 | import random 2 | random.seed(0) 3 | import numpy as np 4 | np.random.seed(0) 5 | import os,sys,re,random 6 | from numpy import select 7 | import torch 8 | torch.manual_seed(0) 9 | torch.backends.cudnn.deterministic = True 10 | torch.backends.cudnn.benchmark = False 11 | expdir=os.path.dirname(os.path.abspath(__file__)) 12 | 13 | 14 | 15 | import torch.optim as opt 16 | from torch.nn import functional as F 17 | import data,util 18 | import EvoMSA2XYZ,basic 19 | import math 20 | import pickle 21 | Batch_size=3 22 | Num_cycle=3 23 | TEST_STEP=1000 24 | VISION_STEP=50 25 | device = sys.argv[1] 26 | 27 | 28 | # expdir=os.path.dirname(os.path.abspath(__file__)) 29 | # expround=expdir.split('_')[-1] 30 | # model_path=os.path.join(expdir,'others','models') 31 | # if not os.path.isdir(model_path): 32 | # try: 33 | # os.makedirs(model_path) 34 | # except: 35 | # pass 36 | # testdir=os.path.join(expdir,'others','preds') 37 | # if not os.path.isdir(testdir): 38 | # try: 39 | # os.makedirs(testdir) 40 | # except: 41 | # pass 42 | 43 | 44 | 45 | basenpy_standard= np.load( os.path.join(os.path.dirname(os.path.abspath(__file__)),'base.npy' ) ) 46 | def data_collect(pdb_seq): 47 | aa_type = data.parse_seq(pdb_seq) 48 | base = data.Get_base(pdb_seq,basenpy_standard) 49 | seq_idx = np.arange(len(pdb_seq)) + 1 50 | 51 | msa=aa_type[None,:] 52 | msa=torch.from_numpy(msa).to(device) 53 | msa=torch.cat([msa,msa],0) 54 | msa=F.one_hot(msa.long(),6).float() 55 | 56 | base_x = torch.from_numpy(base).float().to(device) 57 | seq_idx = torch.from_numpy(seq_idx).long().to(device) 58 | return msa,base_x,seq_idx 59 | predxs,plddts = model.pred(msa,seq_idx,ss,base_x,sample_1['alpha_0']) 60 | 61 | 62 | 63 | def classifier(infasta,out_prefix,model_dir): 64 | with torch.no_grad(): 65 | lines = open(infasta).readlines()[1:] 66 | seqs = [aline.strip() for aline in lines] 67 | seq = ''.join(seqs) 68 | msa,base_x,seq_idx = data_collect(seq) 69 | # seq_idx = np.genfromtxt(idxfile).astype(int) 70 | # seq_idx = torch.from_numpy(seq_idx).long().to(device) 71 | 72 | msa_dim=6+1 73 | m_dim,s_dim,z_dim = 64,64,64 74 | N_ensemble,N_cycle=3,8 75 | model=EvoMSA2XYZ.MSA2XYZ(msa_dim-1,msa_dim,N_ensemble,N_cycle,m_dim,s_dim,z_dim) 76 | model.to(device) 77 | model.eval() 78 | models = os.listdir( model_dir ) 79 | models = [amodel for amodel in models if 'model' in amodel and 'opt' not in amodel] 80 | 81 | 82 | models.sort() 83 | 84 | for amodel in models: 85 | 86 | #saved_model=os.path.join(expdir,'others','models',amodel) 87 | saved_model=os.path.join(model_dir,amodel) 88 | model.load_state_dict(torch.load(saved_model,map_location='cpu'),strict=True) 89 | ret = model.pred(msa,seq_idx,None,base_x,np.array(list(seq))) 90 | 91 | util.outpdb(ret['coor'],seq_idx,seq,out_prefix+f'{amodel}.pdb') 92 | #ret = {'plddt':ret['plddt']} 93 | pickle.dump(ret,open(out_prefix+f'{amodel}.ret','wb')) 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | if __name__ == '__main__': 102 | infasta,out_prefix,model_dir = sys.argv[2],sys.argv[3],sys.argv[4] 103 | classifier(infasta,out_prefix,model_dir) -------------------------------------------------------------------------------- /cfg_95/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from subprocess import Popen, PIPE, STDOUT 4 | import os,sys 5 | 6 | 7 | 8 | def outpdb(coor,seq_idx,seq,savefile,start=0,end=10000,energystr=''): 9 | #rama=torch.cat([rama.view(self.L,2),self.betas],dim=-1) 10 | L = coor.shape[0] 11 | 12 | Atom_name=[' P '," C4'",' N1 '] 13 | Other_Atom_name = [" O5'"," C5'"," C3'"," O3'"," C1'"] 14 | other_last_name = ['O',"C","C","O","C"] 15 | 16 | 17 | last_name=['P','C','N'] 18 | wstr=[f'REMARK {str(energystr)}'] 19 | templet='%6s%5d %4s %3s %1s%4d %8.3f%8.3f%8.3f%6.2f%6.2f %2s%2s' 20 | count=1 21 | for i in range(L): 22 | if seq[i] in ['a','g','A','G']: 23 | Atom_name = [' P '," C4'",' N9 '] 24 | #atoms = ['P','C4'] 25 | 26 | elif seq[i] in ['c','u','C','U']: 27 | Atom_name = [' P '," C4'",' N1 '] 28 | for j in range(coor.shape[1]): 29 | outs=('ATOM ',count,Atom_name[j],seq[i],'A',seq_idx[i],coor[i][j][0],coor[i][j][1],coor[i][j][2],0,0,last_name[j],'') 30 | #outs=('ATOM ',count,Atom_name[j],'ALA','A',i+1,coor_np[i][j][0],coor_np[i][j][1],coor_np[i][j][2],1.0,90,last_name[j],'') 31 | #print(outs) 32 | if i>=start-1 and i < end: 33 | wstr.append(templet % outs) 34 | 35 | # for j in range(other_np.shape[1]): 36 | # outs=('ATOM ',count,Other_Atom_name[j],self.seq[i],'A',i+1,other_np[i][j][0],other_np[i][j][1],other_np[i][j][2],0,0,other_last_name[j],'') 37 | # #outs=('ATOM ',count,Atom_name[j],'ALA','A',i+1,coor_np[i][j][0],coor_np[i][j][1],coor_np[i][j][2],1.0,90,last_name[j],'') 38 | # #print(outs) 39 | # if i>=start-1 and i < end: 40 | # wstr.append(templet % outs) 41 | count+=1 42 | wstr.append('TER') 43 | wstr='\n'.join(wstr) 44 | wfile=open(savefile,'w') 45 | wfile.write(wstr) 46 | wfile.close() 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /cfg_96/Evoformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import basic,EvoPair,EvoMSA 5 | import math,sys 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | 9 | class EvoBlock(nn.Module): 10 | def __init__(self,m_dim,z_dim,docheck=False): 11 | super(EvoBlock,self).__init__() 12 | self.msa_row=EvoMSA.MSARow(m_dim,z_dim) 13 | self.msa_col=EvoMSA.MSACol(m_dim) 14 | self.msa_trans=EvoMSA.MSATrans(m_dim) 15 | 16 | self.msa_opm=EvoMSA.MSAOPM(m_dim,z_dim) 17 | 18 | self.pair_triout=EvoPair.TriOut(z_dim) 19 | self.pair_triin =EvoPair.TriIn(z_dim) 20 | self.pair_tristart=EvoPair.TriAttStart(z_dim) 21 | self.pair_triend =EvoPair.TriAttEnd(z_dim) 22 | self.pair_trans = EvoPair.PairTrans(z_dim) 23 | self.docheck=docheck 24 | if docheck: 25 | print('will do checkpoint') 26 | 27 | def layerfunc_msa_row(self,m,z): 28 | return self.msa_row(m,z) + m 29 | def layerfunc_msa_col(self,m): 30 | return self.msa_col(m) + m 31 | def layerfunc_msa_trans(self,m): 32 | return self.msa_trans(m) + m 33 | def layerfunc_msa_opm(self,m,z): 34 | return self.msa_opm(m) + z 35 | 36 | def layerfunc_pair_triout(self,z): 37 | return self.pair_triout(z) + z 38 | def layerfunc_pair_triin(self,z): 39 | return self.pair_triin(z) + z 40 | def layerfunc_pair_tristart(self,z): 41 | return self.pair_tristart(z) + z 42 | def layerfunc_pair_triend(self,z): 43 | return self.pair_triend(z) + z 44 | def layerfunc_pair_trans(self,z): 45 | return self.pair_trans(z) + z 46 | def forward(self,m,z): 47 | if True: 48 | m = m + self.msa_row(m,z) 49 | m = m + self.msa_col(m) 50 | m = m + self.msa_trans(m) 51 | z = z + self.msa_opm(m) 52 | z = z + self.pair_triout(z) 53 | z = z + self.pair_triin(z) 54 | z = z + self.pair_tristart(z) 55 | z = z + self.pair_triend(z) 56 | z = z + self.pair_trans(z) 57 | return m,z 58 | else: 59 | m=checkpoint(self.layerfunc_msa_row,m,z) 60 | m=checkpoint(self.layerfunc_msa_col,m) 61 | m=checkpoint(self.layerfunc_msa_trans,m) 62 | z=checkpoint(self.layerfunc_msa_opm,m,z) 63 | 64 | z=checkpoint(self.layerfunc_pair_triout,z) 65 | z=checkpoint(self.layerfunc_pair_triin,z) 66 | z=checkpoint(self.layerfunc_pair_tristart,z) 67 | z=checkpoint(self.layerfunc_pair_triend,z) 68 | z=checkpoint(self.layerfunc_pair_trans,z) 69 | 70 | return m,z 71 | 72 | 73 | class Evoformer(nn.Module): 74 | def __init__(self,m_dim,z_dim,docheck=False): 75 | super(Evoformer,self).__init__() 76 | self.layers=[16] 77 | self.docheck=docheck 78 | if docheck: 79 | pass 80 | #print('will do checkpoint') 81 | self.evos=nn.ModuleList([EvoBlock(m_dim,z_dim,True) for i in range(self.layers[0])]) 82 | 83 | def layerfunc(self,layermodule,m,z): 84 | m_,z_=layermodule(m,z) 85 | return m_,z_ 86 | 87 | 88 | # def forward(self,m,z): 89 | 90 | # if True: 91 | # #print('will do checkpoint in Evoformer') 92 | # for i in range(self.layers[0]): 93 | # m,z=checkpoint(self.layerfunc,self.evos[i],m,z) 94 | # #m,z=self.evos[i](m,z) 95 | # return m,z 96 | # else: 97 | # for i in range(self.layers[0]): 98 | # m,z=self.evos[i](m,z) 99 | 100 | # return m,z 101 | def forward_n(self,m,z,starti,endi): 102 | for i in range(starti,endi): 103 | #print(i) 104 | m,z=self.evos[i](m,z) 105 | return m,z 106 | def forward(self,m,z): 107 | 108 | # if True: 109 | # #print('will do checkpoint in Evoformer') 110 | # for i in range(self.layers[0]): 111 | # #m,z=checkpoint(self.layerfunc,self.evos[i],m,z) 112 | # m,z=self.evos[i](m,z) 113 | # return m,z 114 | m,z = checkpoint(self.forward_n,m,z,0,3) 115 | m,z = checkpoint(self.forward_n,m,z,3,6) 116 | m,z = checkpoint(self.forward_n,m,z,6,10) 117 | m,z = checkpoint(self.forward_n,m,z,10,13) 118 | m,z = checkpoint(self.forward_n,m,z,13,16) 119 | return m,z 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | if __name__ == "__main__": 129 | N=10 130 | L=30 131 | m_dim=16 132 | z_dim=8 133 | m=torch.rand(N,L,m_dim) 134 | z=torch.rand(L,L,z_dim) 135 | model = Evoformer(m_dim,z_dim) 136 | m,z=model(m,z) 137 | print(model.parameters()) 138 | for param in model.parameters(): 139 | print(type(param), param.size()) 140 | print(m.shape,z.shape) -------------------------------------------------------------------------------- /cfg_96/IPA.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import basic 6 | import math 7 | import os 8 | expdir=os.path.dirname(os.path.abspath(__file__)) 9 | lines = open(os.path.join(expdir,'newconfig')).readlines() 10 | attdrop = lines[0].strip().split()[-1] == '1' 11 | denoisee2e = lines[1].strip().split()[-1] == '1' 12 | ss_type = lines[2].strip().split()[-1] 13 | class InvariantPointAttention(nn.Module): 14 | def __init__(self,dim_in,dim_z,N_head=8,c=16,N_query=4,N_p_values=6,) -> None: 15 | super(InvariantPointAttention,self).__init__() 16 | self.dim_in=dim_in 17 | self.dim_z=dim_z 18 | self.N_head =N_head 19 | self.c=c 20 | self.c_squ = 1.0/math.sqrt(c) 21 | self.W_c = math.sqrt(2.0/(9*N_query)) 22 | self.W_L = math.sqrt(1.0/3) 23 | self.N_query=N_query 24 | self.N_p_values=N_p_values 25 | self.liner_nb_q1=basic.LinearNoBias(dim_in,self.c*N_head) 26 | self.liner_nb_k1=basic.LinearNoBias(dim_in,self.c*N_head) 27 | self.liner_nb_v1=basic.LinearNoBias(dim_in,self.c*N_head) 28 | 29 | self.liner_nb_q2=basic.LinearNoBias(dim_in,N_head*N_query*3) 30 | self.liner_nb_k2=basic.LinearNoBias(dim_in,N_head*N_query*3) 31 | 32 | self.liner_nb_v3=basic.LinearNoBias(dim_in,N_head*N_p_values*3) 33 | 34 | self.liner_nb_z=basic.LinearNoBias(dim_z,N_head) 35 | self.lastlinear1=basic.Linear(N_head*dim_z,dim_in) 36 | self.lastlinear2=basic.Linear(N_head*c,dim_in) 37 | self.lastlinear3=basic.Linear(N_head*N_p_values*3,dim_in) 38 | self.gama = nn.ParameterList([nn.Parameter(torch.zeros(N_head))]) 39 | self.cos_f=nn.CosineSimilarity(dim=-1) 40 | 41 | def forward(self,s,z,rot,trans): 42 | L=s.shape[0] 43 | q1=self.liner_nb_q1(s).reshape(L,self.N_head,self.c) # Lq, 44 | k1=self.liner_nb_k1(s).reshape(L,self.N_head,self.c) 45 | v1=self.liner_nb_v1(s).reshape(L,self.N_head,self.c) # lv,h,c 46 | 47 | attmap=torch.einsum('ihc,jhc->ijh',q1,k1) * self.c_squ # Lq,Lk_v,h 48 | bias_z=self.liner_nb_z(z) # L L h 49 | 50 | q2 = self.liner_nb_q2(s).reshape(L,self.N_head,self.N_query,3) 51 | k2 = self.liner_nb_k2(s).reshape(L,self.N_head,self.N_query,3) 52 | 53 | v3 = self.liner_nb_v3(s).reshape(L,self.N_head,self.N_p_values,3) 54 | 55 | q2 = basic.IPA_transform(q2,rot,trans) # Lq,self.N_head,self.N_query,3 56 | k2 = basic.IPA_transform(k2,rot,trans) # Lk,self.N_head,self.N_query,3 57 | 58 | dismap=((q2[:,None,:,:,:] - k2[None,:,:,:,:])**2).sum([3,4]) ## Lq,Lk, self.N_head, 59 | #dismap=dismap - (self.cos_f(q2[:,None,:,:,:] , k2[None,:,:,:,:])).sum(3) 60 | attmap = attmap + bias_z - F.softplus(self.gama[0])[None,None,:]*dismap*self.W_c*0.5 61 | #print(torch.max(attmap*self.W_L),torch.min(attmap)*self.W_L) 62 | #attmap = F.softmax( torch.clamp(attmap*self.W_L,-5,5),dim=1 ) # Lk dim, [Lq,Lk, self.N_head] 63 | 64 | attmap = F.softmax( attmap*self.W_L,dim=1 ) # Lk dim, [Lq,Lk, self.N_head] 65 | if attdrop: 66 | if self.training: 67 | attmap = basic.DropAtt(attmap,dim=1) 68 | o1 = (attmap[:,:,:,None] * z[:,:,None,:]).sum(1) # Lq, N_head, c_z 69 | o2 = torch.einsum('abc,dab->dbc',v1,attmap) # Lq, N_head, c 70 | o3 = basic.IPA_transform(v3,rot,trans) # Lv, h, p* ,3 71 | o3 = basic.IPA_inverse_transform( torch.einsum('vhpt,gvh->ghpt',o3,attmap),rot,trans) #Lv, h, p* ,3 72 | 73 | return self.lastlinear1(o1.reshape(L,-1)) + self.lastlinear2(o2.reshape(L,-1)) + self.lastlinear3(o3.reshape(L,-1)) 74 | 75 | 76 | 77 | 78 | 79 | 80 | if __name__ == "__main__": 81 | dim_in,dim_z = 8,4 82 | L = 10 83 | ipa = InvariantPointAttention(dim_in,dim_z) 84 | s=torch.rand(L,dim_in) 85 | z=torch.rand(L,L,dim_z) 86 | rot=(torch.eye(3)[None,:,:]).repeat(L,1,1) 87 | trans=torch.rand(L,3) 88 | 89 | out=ipa(s,z,rot,trans) 90 | print(out) 91 | print(out.shape) 92 | 93 | 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /cfg_96/RNALM2/EvoMSA.py: -------------------------------------------------------------------------------- 1 | from numpy import select 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from RNALM2 import basic 6 | import math 7 | 8 | def SignedSqrt( x): 9 | x = torch.sqrt(torch.relu(x)) - torch.sqrt(torch.relu(-x)) 10 | return x 11 | class MSARow(nn.Module): 12 | def __init__(self,m_dim,z_dim,N_head=8,c=8): 13 | super(MSARow,self).__init__() 14 | self.N_head = N_head 15 | self.c = c 16 | self.sq_c = 1/math.sqrt(c) 17 | self.norm1=nn.LayerNorm(m_dim) 18 | self.qlinear = basic.LinearNoBias(m_dim,N_head*c) 19 | self.klinear = basic.LinearNoBias(m_dim,N_head*c) 20 | self.vlinear = basic.LinearNoBias(m_dim,N_head*c) 21 | self.norm_z = nn.LayerNorm(z_dim) 22 | self.zlinear = basic.LinearNoBias(z_dim,N_head) 23 | self.glinear = basic.Linear(m_dim,N_head*c) 24 | self.olinear = basic.Linear(N_head*c,m_dim) 25 | 26 | def forward(self,m,z): 27 | # m : N L 32 28 | N,L,D = m.shape 29 | m = self.norm1(m) 30 | q = self.qlinear(m).reshape(N,L,self.N_head,self.c) #s rq h c 31 | k = self.klinear(m).reshape(N,L,self.N_head,self.c) #s rv h c 32 | v = self.vlinear(m).reshape(N,L,self.N_head,self.c) 33 | b = self.zlinear(self.norm_z(z)) 34 | g = torch.sigmoid(self.glinear(m)).reshape(N,L,self.N_head,self.c) 35 | att=torch.einsum('bqhc,bvhc->bqvh',q,k) * (self.sq_c) + b[None,:,:,:] # rq rv h 36 | att=F.softmax(SignedSqrt(att),dim=2) 37 | o = torch.einsum('bqvh,bvhc->bqhc',att,v) * g 38 | m_ = self.olinear(o.reshape(N,L,-1)) 39 | return m_ 40 | 41 | class MSACol(nn.Module): 42 | def __init__(self,m_dim,N_head=8,c=8): 43 | super(MSACol,self).__init__() 44 | self.N_head = N_head 45 | self.c = c 46 | self.sq_c = 1/math.sqrt(c) 47 | self.norm1=nn.LayerNorm(m_dim) 48 | self.qlinear = basic.LinearNoBias(m_dim,N_head*c) 49 | self.klinear = basic.LinearNoBias(m_dim,N_head*c) 50 | self.vlinear = basic.LinearNoBias(m_dim,N_head*c) 51 | 52 | self.glinear = basic.Linear(m_dim,N_head*c) 53 | self.olinear = basic.Linear(N_head*c,m_dim) 54 | 55 | def forward(self,m): 56 | # m : N L 32 57 | N,L,D = m.shape 58 | m = self.norm1(m) 59 | q = self.qlinear(m).reshape(N,L,self.N_head,self.c) #s rq h c 60 | k = self.klinear(m).reshape(N,L,self.N_head,self.c) #s rv h c 61 | v = self.vlinear(m).reshape(N,L,self.N_head,self.c) 62 | 63 | g = torch.sigmoid(self.glinear(m)).reshape(N,L,self.N_head,self.c) 64 | 65 | att=torch.einsum('slhc,tlhc->stlh',q,k) * (self.sq_c) # rq rv h 66 | att=F.softmax(SignedSqrt(att),dim=1) 67 | o = torch.einsum('stlh,tlhc->slhc',att,v) * g 68 | m_ = self.olinear(o.reshape(N,L,-1)) 69 | return m_ 70 | 71 | class MSATrans(nn.Module): 72 | def __init__(self,m_dim,c_expand=2): 73 | super(MSATrans,self).__init__() 74 | self.c_expand=4 75 | self.m_dim=m_dim 76 | self.norm=nn.LayerNorm(m_dim) 77 | self.linear1 = basic.Linear(m_dim,m_dim*c_expand) 78 | self.linear2 = basic.Linear(m_dim*c_expand,m_dim) 79 | def forward(self,m): 80 | m = self.norm(m) 81 | m = self.linear1(m) 82 | m = self.linear2(F.relu(m)) 83 | return m 84 | 85 | class MSAOPM(nn.Module): 86 | def __init__(self,m_dim,z_dim,c=12): 87 | super(MSAOPM,self).__init__() 88 | self.m_dim=m_dim 89 | self.c=c 90 | self.norm=nn.LayerNorm(m_dim) 91 | self.linear1=basic.Linear(m_dim,c) 92 | self.linear2=basic.Linear(m_dim,c) 93 | self.linear3=basic.Linear(c*c,z_dim) 94 | def forward(self,m): 95 | N,L,D=m.shape 96 | o=self.norm(m) 97 | a=self.linear2(o) 98 | b=self.linear1(o) 99 | o = torch.einsum('nia,njb->nijab',a,b).mean(dim=0) 100 | o = self.linear3(o.reshape(L,L,-1)) 101 | return o 102 | 103 | 104 | 105 | 106 | 107 | 108 | if __name__ == "__main__": 109 | N=10 110 | L=30 111 | m_dim=16 112 | z_dim=8 113 | m=torch.rand(N,L,m_dim) 114 | z=torch.rand(L,L,z_dim) 115 | msarow=MSARow(m_dim,z_dim) 116 | msacol=MSACol(m_dim) 117 | msatrans=MSATrans(m_dim) 118 | msaopm=MSAOPM(m_dim,z_dim) 119 | y=msaopm(m) 120 | print(y.shape) -------------------------------------------------------------------------------- /cfg_96/RNALM2/EvoPair.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from RNALM2 import basic 5 | import math 6 | 7 | def SignedSqrt( x): 8 | x = torch.sqrt(torch.relu(x)) - torch.sqrt(torch.relu(-x)) 9 | return x 10 | class TriOut(nn.Module): 11 | def __init__(self,z_dim,c=32): 12 | super(TriOut,self).__init__() 13 | self.z_dim = z_dim 14 | self.norm =nn.LayerNorm(z_dim) 15 | self.onorm =nn.LayerNorm(c) 16 | self.alinear=basic.Linear(z_dim,c) 17 | self.blinear=basic.Linear(z_dim,c) 18 | self.aglinear=basic.Linear(z_dim,c) 19 | self.bglinear=basic.Linear(z_dim,c) 20 | self.glinear =basic.Linear(z_dim,z_dim) 21 | self.olinear=basic.Linear(c,z_dim) 22 | 23 | def forward(self,z): 24 | z = self.norm(z) 25 | a = self.alinear(z) * torch.sigmoid(self.aglinear(z)) 26 | b = self.blinear(z) * torch.sigmoid(self.bglinear(z)) 27 | o = torch.einsum('ilc,jlc->ijc',a,b) 28 | o = self.onorm(o) 29 | o = self.olinear(o) 30 | o = o * torch.sigmoid(self.glinear(z)) 31 | return o 32 | 33 | class TriIn(nn.Module): 34 | def __init__(self,z_dim,c=32): 35 | super(TriIn,self).__init__() 36 | self.z_dim = z_dim 37 | self.norm =nn.LayerNorm(z_dim) 38 | self.onorm =nn.LayerNorm(c) 39 | self.alinear=basic.Linear(z_dim,c) 40 | self.blinear=basic.Linear(z_dim,c) 41 | self.aglinear=basic.Linear(z_dim,c) 42 | self.bglinear=basic.Linear(z_dim,c) 43 | self.glinear =basic.Linear(z_dim,z_dim) 44 | self.olinear=basic.Linear(c,z_dim) 45 | 46 | def forward(self,z): 47 | z = self.norm(z) 48 | a = self.alinear(z) * torch.sigmoid(self.aglinear(z)) 49 | b = self.blinear(z) * torch.sigmoid(self.bglinear(z)) 50 | o = torch.einsum('lic,ljc->ijc',a,b) 51 | o = self.onorm(o) 52 | o = self.olinear(o) 53 | o = o * torch.sigmoid(self.glinear(z)) 54 | return o 55 | 56 | 57 | class TriAttStart(nn.Module): 58 | def __init__(self,z_dim,N_head=4,c=8): 59 | super(TriAttStart,self).__init__() 60 | self.z_dim = z_dim 61 | self.N_head = N_head 62 | self.c = c 63 | self.sq_c = 1/math.sqrt(c) 64 | self.norm=nn.LayerNorm(z_dim) 65 | self.qlinear=basic.Linear(z_dim,c*N_head) 66 | self.klinear=basic.Linear(z_dim,c*N_head) 67 | self.vlinear=basic.Linear(z_dim,c*N_head) 68 | self.blinear=basic.Linear(z_dim,N_head) 69 | self.glinear=basic.Linear(z_dim,c*N_head) 70 | self.olinear=basic.Linear(c*N_head,z_dim) 71 | 72 | def forward(self,z_): 73 | L1,L2,D=z_.shape 74 | z = self.norm(z_) 75 | q = self.qlinear(z).reshape(L1,L2,self.N_head,self.c) 76 | k = self.klinear(z).reshape(L1,L2,self.N_head,self.c) 77 | v = self.vlinear(z).reshape(L1,L2,self.N_head,self.c) 78 | b = self.blinear(z) 79 | att = torch.einsum('blhc,bkhc->blkh',q,k)*self.sq_c + b[None,:,:,:] 80 | att = F.softmax(SignedSqrt(att),dim=2) 81 | o = torch.einsum('blkh,bkhc->blhc',att,v) 82 | o = (torch.sigmoid(self.glinear(z).reshape(L1,L2,self.N_head,self.c)) * o).reshape(L1,L2,-1) 83 | o = self.olinear(o) 84 | return o 85 | 86 | class TriAttEnd(nn.Module): 87 | def __init__(self,z_dim,N_head=4,c=8): 88 | super(TriAttEnd,self).__init__() 89 | self.z_dim = z_dim 90 | self.N_head = N_head 91 | self.c = c 92 | self.sq_c = 1/math.sqrt(c) 93 | self.norm=nn.LayerNorm(z_dim) 94 | self.qlinear=basic.Linear(z_dim,c*N_head) 95 | self.klinear=basic.Linear(z_dim,c*N_head) 96 | self.vlinear=basic.Linear(z_dim,c*N_head) 97 | self.blinear=basic.Linear(z_dim,N_head) 98 | self.glinear=basic.Linear(z_dim,c*N_head) 99 | self.olinear=basic.Linear(c*N_head,z_dim) 100 | 101 | def forward(self,z_): 102 | L1,L2,D=z_.shape 103 | z = self.norm(z_) 104 | q = self.qlinear(z).reshape(L1,L2,self.N_head,self.c) 105 | k = self.klinear(z).reshape(L1,L2,self.N_head,self.c) 106 | v = self.vlinear(z).reshape(L1,L2,self.N_head,self.c) 107 | b = self.blinear(z) 108 | att = torch.einsum('blhc,kbhc->blkh',q,k)*self.sq_c + b[None,:,:,:].permute(0,2,1,3) 109 | att = F.softmax(SignedSqrt(att),dim=2) 110 | o = torch.einsum('blkh,klhc->blhc',att,v) 111 | o = (torch.sigmoid(self.glinear(z).reshape(L1,L2,self.N_head,self.c)) * o).reshape(L1,L2,-1) 112 | o = self.olinear(o) 113 | return o 114 | def forward2(self,z_): 115 | z = z_.permute(1,0,2) 116 | L1,L2,D=z_.shape 117 | z = self.norm(z_) 118 | q = self.qlinear(z).reshape(L1,L2,self.N_head,self.c) 119 | k = self.klinear(z).reshape(L1,L2,self.N_head,self.c) 120 | v = self.vlinear(z).reshape(L1,L2,self.N_head,self.c) 121 | b = self.blinear(z) 122 | att = torch.einsum('blhc,bkhc->blkh',q,k)*self.sq_c + b[None,:,:,:] 123 | att = F.softmax(att,dim=2) 124 | o = torch.einsum('blkh,bkhc->blhc',att,v) 125 | o = (torch.sigmoid(self.glinear(z).reshape(L1,L2,self.N_head,self.c)) * o).reshape(L1,L2,-1) 126 | o = self.olinear(o) 127 | o = o.permute(1,0,2) 128 | return o 129 | class PairTrans(nn.Module): 130 | def __init__(self,z_dim,c_expand=2): 131 | super(PairTrans,self).__init__() 132 | self.z_dim=z_dim 133 | self.c_expand=c_expand 134 | self.norm = nn.LayerNorm(z_dim) 135 | self.linear1=basic.Linear(z_dim,z_dim*c_expand) 136 | self.linear2=basic.Linear(z_dim*c_expand,z_dim) 137 | def forward(self,z): 138 | a = self.linear1(self.norm(z)) 139 | a = self.linear2(F.relu(a)) 140 | return a 141 | 142 | 143 | 144 | 145 | if __name__ == "__main__": 146 | N=10 147 | L=30 148 | m_dim=16 149 | z_dim=8 150 | m=torch.rand(N,L,m_dim) 151 | z=torch.rand(L,L,z_dim) 152 | 153 | tr1=TriAttEnd(z_dim) 154 | tr2=PairTrans(z_dim) 155 | y=tr1(z) 156 | y2=tr1.forward2(z) 157 | y3=tr2(z) 158 | print(y3.shape) 159 | 160 | -------------------------------------------------------------------------------- /cfg_96/RNALM2/Evoformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from RNALM2 import basic,EvoPair,EvoMSA 5 | import math,sys 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | 9 | class EvoBlock(nn.Module): 10 | def __init__(self,m_dim,z_dim,docheck=False): 11 | super(EvoBlock,self).__init__() 12 | N_head = 16 13 | c = 16 14 | self.msa_row=EvoMSA.MSARow(m_dim,z_dim,N_head,c) 15 | self.msa_col=EvoMSA.MSACol(m_dim,N_head,c) 16 | self.msa_trans=EvoMSA.MSATrans(m_dim) 17 | 18 | self.msa_opm=EvoMSA.MSAOPM(m_dim,z_dim) 19 | 20 | self.pair_triout=EvoPair.TriOut(z_dim,72) 21 | self.pair_triin =EvoPair.TriIn(z_dim,72) 22 | self.pair_tristart=EvoPair.TriAttStart(z_dim) 23 | self.pair_triend =EvoPair.TriAttEnd(z_dim) 24 | self.pair_trans = EvoPair.PairTrans(z_dim) 25 | self.docheck=docheck 26 | if docheck: 27 | print('will do checkpoint') 28 | 29 | def layerfunc_msa_row(self,m,z): 30 | return self.msa_row(m,z) + m 31 | def layerfunc_msa_col(self,m): 32 | return self.msa_col(m) + m 33 | def layerfunc_msa_trans(self,m): 34 | return self.msa_trans(m) + m 35 | def layerfunc_msa_opm(self,m,z): 36 | return self.msa_opm(m) + z 37 | 38 | def layerfunc_pair_triout(self,z): 39 | return self.pair_triout(z) + z 40 | def layerfunc_pair_triin(self,z): 41 | return self.pair_triin(z) + z 42 | def layerfunc_pair_tristart(self,z): 43 | return self.pair_tristart(z) + z 44 | def layerfunc_pair_triend(self,z): 45 | return self.pair_triend(z) + z 46 | def layerfunc_pair_trans(self,z): 47 | return self.pair_trans(z) + z 48 | def forward(self,m,z): 49 | if True: 50 | m = m + self.msa_row(m,z) 51 | m = m + self.msa_col(m) 52 | m = m + self.msa_trans(m) 53 | z = z + self.msa_opm(m) 54 | z = z + self.pair_triout(z) 55 | z = z + self.pair_triin(z) 56 | #z = z + self.pair_tristart(z) 57 | #z = z + self.pair_triend(z) 58 | z = z + self.pair_trans(z) 59 | return m,z 60 | else: 61 | m=checkpoint(self.layerfunc_msa_row,m,z) 62 | m=checkpoint(self.layerfunc_msa_col,m) 63 | m=checkpoint(self.layerfunc_msa_trans,m) 64 | z=checkpoint(self.layerfunc_msa_opm,m,z) 65 | 66 | z=checkpoint(self.layerfunc_pair_triout,z) 67 | z=checkpoint(self.layerfunc_pair_triin,z) 68 | z=checkpoint(self.layerfunc_pair_tristart,z) 69 | z=checkpoint(self.layerfunc_pair_triend,z) 70 | z=checkpoint(self.layerfunc_pair_trans,z) 71 | 72 | return m,z 73 | 74 | 75 | class Evoformer(nn.Module): 76 | def __init__(self,m_dim,z_dim,N_elayers=12,docheck=False): 77 | super(Evoformer,self).__init__() 78 | self.layers=[N_elayers] 79 | self.docheck=docheck 80 | if docheck: 81 | pass 82 | #print('will do checkpoint') 83 | self.evos=nn.ModuleList([EvoBlock(m_dim,z_dim,True) for i in range(self.layers[0])]) 84 | 85 | def layerfunc(self,layermodule,m,z): 86 | m_,z_=layermodule(m,z) 87 | return m_,z_ 88 | 89 | 90 | def forward(self,m,z): 91 | 92 | if True: 93 | #print('will do checkpoint in Evoformer') 94 | for i in range(self.layers[0]): 95 | #m,z=checkpoint(self.layerfunc,self.evos[i],m,z) 96 | m,z=self.evos[i](m,z) 97 | return m,z 98 | else: 99 | for i in range(self.layers[0]): 100 | m,z=self.evos[i](m,z) 101 | 102 | return m,z 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | if __name__ == "__main__": 113 | N=10 114 | L=30 115 | m_dim=16 116 | z_dim=8 117 | m=torch.rand(N,L,m_dim) 118 | z=torch.rand(L,L,z_dim) 119 | model = Evoformer(m_dim,z_dim) 120 | m,z=model(m,z) 121 | print(model.parameters()) 122 | for param in model.parameters(): 123 | print(type(param), param.size()) 124 | print(m.shape,z.shape) -------------------------------------------------------------------------------- /cfg_96/RNALM2/Model.py: -------------------------------------------------------------------------------- 1 | from numpy import select 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from RNALM2 import basic,Evoformer 6 | import math,sys 7 | from torch.utils.checkpoint import checkpoint 8 | import numpy as np 9 | 10 | 11 | def one_d(idx_, d, max_len=2056*8): 12 | idx = idx_[None] 13 | K = torch.arange(d//2).to(idx.device) 14 | sin_e = torch.sin(idx[..., None] * math.pi / (max_len**(2*K[None]/d))).to(idx.device) 15 | cos_e = torch.cos(idx[..., None] * math.pi / (max_len**(2*K[None]/d))).to(idx.device) 16 | return torch.cat([sin_e, cos_e], axis=-1)[0] 17 | 18 | 19 | 20 | 21 | 22 | 23 | class RNAembedding(nn.Module): 24 | def __init__(self,cfg): 25 | super(RNAembedding,self).__init__() 26 | self.s_in_dim=cfg['s_in_dim'] 27 | self.z_in_dim=cfg['z_in_dim'] 28 | self.s_dim=cfg['s_dim'] 29 | self.z_dim=cfg['z_dim'] 30 | self.qlinear =basic.Linear(self.s_in_dim+1,self.z_dim) 31 | self.klinear =basic.Linear(self.s_in_dim+1,self.z_dim) 32 | self.slinear =basic.Linear(self.s_in_dim+1,self.s_dim) 33 | self.zlinear =basic.Linear(self.z_in_dim+1,self.z_dim) 34 | 35 | self.poslinears = basic.Linear(64,self.s_dim) 36 | self.poslinearz = basic.Linear(64,self.z_dim) 37 | def forward(self,in_dict): 38 | # msa N L D, seq L D 39 | # mask: maksing, L, 1 means masked 40 | # aa: L x s_in_dim 41 | # ss: L x L x 2 42 | # idx: L (LongTensor) 43 | L = in_dict['aa'].shape[0] 44 | aamask = in_dict['mask'][:,None] 45 | zmask = in_dict['mask'][:,None] + in_dict['mask'][None,:] 46 | zmask[zmask>0.5]=1 47 | zmask = zmask[...,None] 48 | s = torch.cat([aamask,(1-aamask)*in_dict['aa']],dim=-1) 49 | sq=self.qlinear(s) 50 | sk=self.klinear(s) 51 | z=sq[None,:,:]+sk[:,None,:] 52 | seq_idx = in_dict['idx'][None] 53 | relative_pos = seq_idx[:, :, None] - seq_idx[:, None, :] 54 | relative_pos = relative_pos.reshape([1, -1]) 55 | relative_pos =one_d(relative_pos,64) 56 | z = z + self.poslinearz( relative_pos.reshape([1, L, L, -1])[0] ) 57 | 58 | s = self.slinear(s) + self.poslinears( one_d(in_dict['idx'], 64) ) 59 | 60 | return s,z 61 | 62 | 63 | class RNA2nd(nn.Module): 64 | def __init__(self,cfg): 65 | super(RNA2nd,self).__init__() 66 | self.s_in_dim=cfg['s_in_dim'] 67 | self.z_in_dim=cfg['z_in_dim'] 68 | self.s_dim=cfg['s_dim'] 69 | self.z_dim=cfg['z_dim'] 70 | self.N_elayers =cfg['N_elayers'] 71 | self.emb = RNAembedding(cfg) 72 | self.evmodel=Evoformer.Evoformer(self.s_dim,self.z_dim,self.N_elayers) 73 | self.seq_head = basic.Linear(self.s_dim,self.s_in_dim) 74 | self.joint_head = basic.Linear(self.z_dim,self.s_in_dim*self.s_in_dim) 75 | 76 | 77 | 78 | 79 | def embedding(self,in_dict): 80 | s,z = self.emb(in_dict) 81 | s,z = self.evmodel(s[None,...],z) 82 | return s[0],z 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /cfg_96/RNALM2/basic.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import random 5 | class Linear(nn.Module): 6 | def __init__(self,dim_in,dim_out): 7 | super(Linear,self).__init__() 8 | self.linear = nn.Linear(dim_in,dim_out) 9 | def forward(self,x): 10 | x = self.linear(x) 11 | return x 12 | 13 | 14 | class LinearNoBias(nn.Module): 15 | def __init__(self,dim_in,dim_out): 16 | super(LinearNoBias,self).__init__() 17 | self.linear = nn.Linear(dim_in,dim_out,bias=False) 18 | def forward(self,x): 19 | x = self.linear(x) 20 | return x 21 | 22 | 23 | 24 | def transform(k,rotation,translation): 25 | # K L x 3 26 | # rotation 27 | return torch.matmul(k,rotation) + translation 28 | 29 | 30 | def batch_transform(k,rotation,translation): 31 | # k: L 3 32 | # rotation: L 3 x 3 33 | # translation: L 3 34 | return torch.einsum('ba,bad->bd',k,rotation) + translation 35 | 36 | def batch_atom_transform(k,rotation,translation): 37 | # k: L N 3 38 | # rotation: L 3 x 3 39 | # translation: L 3 40 | return torch.einsum('bja,bad->bjd',k,rotation) + translation[:,None,:] 41 | 42 | def IPA_transform(k,rotation,translation): 43 | # k: L d1, d2, 3 44 | # rotation: L 3 x 3 45 | # translation: L 3 46 | return torch.einsum('bija,bad->bijd',k,rotation)+translation[:,None,None,:] 47 | 48 | def IPA_inverse_transform(k,rotation,translation): 49 | # k: L d1, d2, 3 50 | # rotation: L 3 x 3 51 | # translation: L 3 52 | return torch.einsum('bija,bad->bijd',k-translation[:,None,None,:],rotation.transpose(-1,-2)) 53 | 54 | def update_transform(t,tr,rotation,translation): 55 | return torch.einsum('bja,bad->bjd',t,rotation),torch.einsum('ba,bad->bd',tr,rotation) +translation 56 | 57 | 58 | def quat2rot(q,L): 59 | scale= ((q**2).sum(dim=-1,keepdim=True) +1) [:,:,None] 60 | u=torch.empty([L,3,3],device=q.device) 61 | u[:,0,0]=1*1+q[:,0]*q[:,0]-q[:,1]*q[:,1]-q[:,2]*q[:,2] 62 | u[:,0,1]=2*(q[:,0]*q[:,1]-1*q[:,2]) 63 | u[:,0,2]=2*(q[:,0]*q[:,2]+1*q[:,1]) 64 | u[:,1,0]=2*(q[:,0]*q[:,1]+1*q[:,2]) 65 | u[:,1,1]=1*1-q[:,0]*q[:,0]+q[:,1]*q[:,1]-q[:,2]*q[:,2] 66 | u[:,1,2]=2*(q[:,1]*q[:,2]-1*q[:,0]) 67 | u[:,2,0]=2*(q[:,0]*q[:,2]-1*q[:,1]) 68 | u[:,2,1]=2*(q[:,1]*q[:,2]+1*q[:,0]) 69 | u[:,2,2]=1*1-q[:,0]*q[:,0]-q[:,1]*q[:,1]+q[:,2]*q[:,2] 70 | return u/scale 71 | 72 | 73 | def rotation_x(sintheta,costheta,ones,zeros): 74 | # L x 1 75 | return torch.stack([torch.stack([ones, zeros, zeros]), 76 | torch.stack([zeros, costheta, sintheta]), 77 | torch.stack([zeros, -sintheta, costheta])]) 78 | def rotation_y(sintheta,costheta,ones,zeros): 79 | # L x 1 80 | return torch.stack([torch.stack([costheta, zeros, sintheta]), 81 | torch.stack([zeros, ones, zeros]), 82 | torch.stack([-sintheta, zeros, costheta])]) 83 | def rotation_z(sintheta,costheta,ones,zeros): 84 | # L x 1 85 | return torch.stack([torch.stack([costheta, sintheta, zeros]), 86 | torch.stack([-sintheta, costheta, zeros]), 87 | torch.stack([zeros, zeros, ones])]) 88 | def batch_rotation(k,rotation): 89 | # k: L 3 90 | # rotation: L 3 x 3 91 | # translation: L 3 92 | return torch.einsum('ba,bad->bd',k,rotation) 93 | 94 | def compute_cb(bl,sin_angle,cos_angle,sin_torsion,cos_torsion): 95 | L=bl.shape[0] 96 | ones=torch.ones(L,device=bl.device) 97 | zeros=torch.zeros(L,device=bl.device) 98 | cb=torch.stack([bl,zeros,zeros]).permute(1,0) 99 | rotz=rotation_z(sin_angle,cos_angle,ones,zeros).permute(2,0,1) 100 | rotx=rotation_x(sin_torsion,cos_torsion,ones,zeros).permute(2,0,1) 101 | cb=batch_rotation(cb,rotz) 102 | cb=batch_rotation(cb,rotx) 103 | return cb 104 | 105 | def rigidFrom3Points_(x1,x2,x3): 106 | v1=x3-x2 107 | v2=x1-x2 108 | e1=v1/(torch.norm(v1,dim=-1,keepdim=True) + 1e-03) 109 | u2=v2 - e1*(torch.einsum('bn,bn->b',e1,v2)[:,None]) 110 | e2 = u2/(torch.norm(u2,dim=-1,keepdim=True) + 1e-03) 111 | e3=torch.cross(e1,e2,dim=-1) 112 | 113 | return torch.stack([e1,e2,e3],dim=1),x2[:,:] 114 | def rigidFrom3Points(x1,x2,x3):# L 3 115 | the_dim=1 116 | x = torch.stack([x1,x2,x3],dim=the_dim) 117 | x_mean = torch.mean(x,dim=the_dim,keepdim=True) 118 | x = x - x_mean 119 | 120 | 121 | m = x.view(-1, 3, 3) 122 | u, s, v = torch.svd(m) 123 | vt = torch.transpose(v, 1, 2) 124 | det = torch.det(torch.matmul(u, vt)) 125 | det = det.view(-1, 1, 1) 126 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) 127 | r = torch.matmul(u, vt) 128 | return r,x_mean.squeeze() 129 | def Kabsch_rigid(bases,x1,x2,x3): 130 | ''' 131 | return the direction from to_q to from_p 132 | ''' 133 | the_dim=1 134 | to_q = torch.stack([x1,x2,x3],dim=the_dim) 135 | biasq=torch.mean(to_q,dim=the_dim,keepdim=True) 136 | q=to_q-biasq 137 | m = torch.einsum('bnz,bny->bzy',bases,q) 138 | u, s, v = torch.svd(m) 139 | vt = torch.transpose(v, 1, 2) 140 | det = torch.det(torch.matmul(u, vt)) 141 | det = det.view(-1, 1, 1) 142 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) 143 | r = torch.matmul(u, vt) 144 | return r,biasq.squeeze() 145 | def Generate_msa_mask(n,l): 146 | # 1: 15% mask out 147 | randommatrix=torch.rand(n,l) 148 | mask = (randommatrix <0.1).float() 149 | # 2 random a segment 150 | seqlength = int(l*0.1) 151 | sindex=round(random.random()*(l-seqlength)) 152 | endindex=min(l,sindex+seqlength) 153 | mask[:,sindex:endindex]=1 154 | return mask 155 | 156 | 157 | -------------------------------------------------------------------------------- /cfg_96/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeyang/DRfold2/54129443ed151032e3d647a279814c5cebdca630/cfg_96/__init__.py -------------------------------------------------------------------------------- /cfg_96/base.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeyang/DRfold2/54129443ed151032e3d647a279814c5cebdca630/cfg_96/base.npy -------------------------------------------------------------------------------- /cfg_96/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os,math 3 | import numpy as np 4 | from torch.utils.data import Dataset, DataLoader 5 | from random import sample 6 | from numpy import float32 7 | import random 8 | from scipy.spatial.distance import cdist 9 | from subprocess import Popen, PIPE, STDOUT 10 | expdir=os.path.dirname(os.path.abspath(__file__)) 11 | #lines = open(os.path.join(expdir,'newconfig')).readlines() 12 | #attdrop = lines[0].strip().split()[-1] == '1' 13 | # denoisee2e = lines[1].strip().split()[-1] == '1' 14 | # ss_type = lines[2].strip().split()[-1] 15 | code_standard = { 16 | 'A':'A','G':'G','C':'C','U':'U','a':'A','g':'G','c':'C','u':'U','T':'U','t':'U' 17 | } 18 | expdir=os.path.dirname(os.path.abspath(__file__)) 19 | parentdir = os.path.dirname(expdir) 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | def parse_seq(inseq): 32 | seqnpy=np.zeros(len(inseq)) 33 | seq1=np.array(list(inseq)) 34 | seqnpy[seq1=='A']=1 35 | seqnpy[seq1=='G']=2 36 | seqnpy[seq1=='C']=3 37 | seqnpy[seq1=='U']=4 38 | seqnpy[seq1=='T']=4 39 | return seqnpy 40 | 41 | 42 | def Get_base(seq,basenpy_standard): 43 | basenpy = np.zeros([len(seq),3,3]) 44 | seqnpy = np.array(list(seq)) 45 | basenpy[seqnpy=='A']=basenpy_standard[0] 46 | basenpy[seqnpy=='a']=basenpy_standard[0] 47 | 48 | basenpy[seqnpy=='G']=basenpy_standard[1] 49 | basenpy[seqnpy=='g']=basenpy_standard[1] 50 | 51 | basenpy[seqnpy=='C']=basenpy_standard[2] 52 | basenpy[seqnpy=='c']=basenpy_standard[2] 53 | 54 | basenpy[seqnpy=='U']=basenpy_standard[3] 55 | basenpy[seqnpy=='u']=basenpy_standard[3] 56 | 57 | basenpy[seqnpy=='T']=basenpy_standard[3] 58 | basenpy[seqnpy=='t']=basenpy_standard[3] 59 | return basenpy 60 | -------------------------------------------------------------------------------- /cfg_96/newconfig: -------------------------------------------------------------------------------- 1 | attdrop: 1 2 | denoisee2e: 1 3 | ss_type: attention -------------------------------------------------------------------------------- /cfg_96/test_modeldir.py: -------------------------------------------------------------------------------- 1 | import random 2 | random.seed(0) 3 | import numpy as np 4 | np.random.seed(0) 5 | import os,sys,re,random 6 | from numpy import select 7 | import torch 8 | torch.manual_seed(0) 9 | torch.backends.cudnn.deterministic = True 10 | torch.backends.cudnn.benchmark = False 11 | expdir=os.path.dirname(os.path.abspath(__file__)) 12 | 13 | # from pathlib import Path 14 | # path = Path(expdir) 15 | # parepath = path.parent.absolute() 16 | 17 | 18 | import torch.optim as opt 19 | from torch.nn import functional as F 20 | import data,util 21 | import EvoMSA2XYZ,basic 22 | import math 23 | import pickle 24 | Batch_size=3 25 | Num_cycle=3 26 | TEST_STEP=1000 27 | VISION_STEP=50 28 | device = sys.argv[1] 29 | 30 | 31 | # expdir=os.path.dirname(os.path.abspath(__file__)) 32 | # expround=expdir.split('_')[-1] 33 | # model_path=os.path.join(expdir,'others','models') 34 | # if not os.path.isdir(model_path): 35 | # try: 36 | # os.makedirs(model_path) 37 | # except: 38 | # pass 39 | # testdir=os.path.join(expdir,'others','preds') 40 | # if not os.path.isdir(testdir): 41 | # try: 42 | # os.makedirs(testdir) 43 | # except: 44 | # pass 45 | 46 | 47 | 48 | basenpy_standard= np.load( os.path.join(os.path.dirname(os.path.abspath(__file__)),'base.npy' ) ) 49 | def data_collect(pdb_seq): 50 | aa_type = data.parse_seq(pdb_seq) 51 | base = data.Get_base(pdb_seq,basenpy_standard) 52 | seq_idx = np.arange(len(pdb_seq)) + 1 53 | 54 | msa=aa_type[None,:] 55 | msa=torch.from_numpy(msa).to(device) 56 | msa=torch.cat([msa,msa],0) 57 | msa=F.one_hot(msa.long(),6).float() 58 | 59 | base_x = torch.from_numpy(base).float().to(device) 60 | seq_idx = torch.from_numpy(seq_idx).long().to(device) 61 | return msa,base_x,seq_idx 62 | 63 | 64 | 65 | def classifier(infasta,out_prefix,model_dir): 66 | with torch.no_grad(): 67 | lines = open(infasta).readlines()[1:] 68 | seqs = [aline.strip() for aline in lines] 69 | seq = ''.join(seqs) 70 | msa,base_x,seq_idx = data_collect(seq) 71 | # seq_idx = np.genfromtxt(idxfile).astype(int) 72 | # seq_idx = torch.from_numpy(seq_idx).long().to(device) 73 | 74 | msa_dim=6+1 75 | m_dim,s_dim,z_dim = 64,64,64 76 | N_ensemble,N_cycle=3,8 77 | model=EvoMSA2XYZ.MSA2XYZ(msa_dim-1,msa_dim,N_ensemble,N_cycle,m_dim,s_dim,z_dim) 78 | model.to(device) 79 | model.eval() 80 | models = os.listdir( model_dir ) 81 | models = [amodel for amodel in models if 'model' in amodel and 'opt' not in amodel] 82 | models.sort() 83 | 84 | for amodel in models: 85 | #saved_model=os.path.join(expdir,'others','models',amodel) 86 | saved_model=os.path.join(model_dir,amodel) 87 | model.load_state_dict(torch.load(saved_model,map_location='cpu'),strict=True) 88 | ret = model.pred(msa,seq_idx,None,base_x,np.array(list(seq))) 89 | 90 | util.outpdb(ret['coor'],seq_idx,seq,out_prefix+f'{amodel}.pdb') 91 | #ret = {'plddt':ret['plddt']} 92 | pickle.dump(ret,open(out_prefix+f'{amodel}.ret','wb')) 93 | # for akey in ret: 94 | # print(akey,ret[akey].shape) 95 | 96 | 97 | 98 | 99 | 100 | 101 | if __name__ == '__main__': 102 | infasta,out_prefix,model_dir = sys.argv[2],sys.argv[3],sys.argv[4] 103 | classifier(infasta,out_prefix,model_dir) -------------------------------------------------------------------------------- /cfg_96/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from subprocess import Popen, PIPE, STDOUT 4 | import os,sys 5 | 6 | 7 | 8 | def outpdb(coor,seq_idx,seq,savefile,start=0,end=10000,energystr=''): 9 | #rama=torch.cat([rama.view(self.L,2),self.betas],dim=-1) 10 | L = coor.shape[0] 11 | 12 | Atom_name=[' P '," C4'",' N1 '] 13 | Other_Atom_name = [" O5'"," C5'"," C3'"," O3'"," C1'"] 14 | other_last_name = ['O',"C","C","O","C"] 15 | 16 | 17 | last_name=['P','C','N'] 18 | wstr=[f'REMARK {str(energystr)}'] 19 | templet='%6s%5d %4s %3s %1s%4d %8.3f%8.3f%8.3f%6.2f%6.2f %2s%2s' 20 | count=1 21 | for i in range(L): 22 | if seq[i] in ['a','g','A','G']: 23 | Atom_name = [' P '," C4'",' N9 '] 24 | #atoms = ['P','C4'] 25 | 26 | elif seq[i] in ['c','u','C','U']: 27 | Atom_name = [' P '," C4'",' N1 '] 28 | for j in range(coor.shape[1]): 29 | outs=('ATOM ',count,Atom_name[j],seq[i],'A',seq_idx[i],coor[i][j][0],coor[i][j][1],coor[i][j][2],0,0,last_name[j],'') 30 | #outs=('ATOM ',count,Atom_name[j],'ALA','A',i+1,coor_np[i][j][0],coor_np[i][j][1],coor_np[i][j][2],1.0,90,last_name[j],'') 31 | #print(outs) 32 | if i>=start-1 and i < end: 33 | wstr.append(templet % outs) 34 | 35 | # for j in range(other_np.shape[1]): 36 | # outs=('ATOM ',count,Other_Atom_name[j],self.seq[i],'A',i+1,other_np[i][j][0],other_np[i][j][1],other_np[i][j][2],0,0,other_last_name[j],'') 37 | # #outs=('ATOM ',count,Atom_name[j],'ALA','A',i+1,coor_np[i][j][0],coor_np[i][j][1],coor_np[i][j][2],1.0,90,last_name[j],'') 38 | # #print(outs) 39 | # if i>=start-1 and i < end: 40 | # wstr.append(templet % outs) 41 | count+=1 42 | wstr.append('TER') 43 | wstr='\n'.join(wstr) 44 | wfile=open(savefile,'w') 45 | wfile.write(wstr) 46 | wfile.close() 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /cfg_97/Evoformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import basic,EvoPair,EvoMSA 5 | import math,sys 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | 9 | class EvoBlock(nn.Module): 10 | def __init__(self,m_dim,z_dim,docheck=False): 11 | super(EvoBlock,self).__init__() 12 | self.msa_row=EvoMSA.MSARow(m_dim,z_dim) 13 | self.msa_col=EvoMSA.MSACol(m_dim) 14 | self.msa_trans=EvoMSA.MSATrans(m_dim) 15 | 16 | self.msa_opm=EvoMSA.MSAOPM(m_dim,z_dim) 17 | 18 | self.pair_triout=EvoPair.TriOut(z_dim) 19 | self.pair_triin =EvoPair.TriIn(z_dim) 20 | self.pair_tristart=EvoPair.TriAttStart(z_dim) 21 | self.pair_triend =EvoPair.TriAttEnd(z_dim) 22 | self.pair_trans = EvoPair.PairTrans(z_dim) 23 | self.docheck=docheck 24 | if docheck: 25 | print('will do checkpoint') 26 | 27 | def layerfunc_msa_row(self,m,z): 28 | return self.msa_row(m,z) + m 29 | def layerfunc_msa_col(self,m): 30 | return self.msa_col(m) + m 31 | def layerfunc_msa_trans(self,m): 32 | return self.msa_trans(m) + m 33 | def layerfunc_msa_opm(self,m,z): 34 | return self.msa_opm(m) + z 35 | 36 | def layerfunc_pair_triout(self,z): 37 | return self.pair_triout(z) + z 38 | def layerfunc_pair_triin(self,z): 39 | return self.pair_triin(z) + z 40 | def layerfunc_pair_tristart(self,z): 41 | return self.pair_tristart(z) + z 42 | def layerfunc_pair_triend(self,z): 43 | return self.pair_triend(z) + z 44 | def layerfunc_pair_trans(self,z): 45 | return self.pair_trans(z) + z 46 | def forward(self,m,z): 47 | if True: 48 | m = m + self.msa_row(m,z) 49 | m = m + self.msa_col(m) 50 | m = m + self.msa_trans(m) 51 | z = z + self.msa_opm(m) 52 | z = z + self.pair_triout(z) 53 | z = z + self.pair_triin(z) 54 | z = z + self.pair_tristart(z) 55 | z = z + self.pair_triend(z) 56 | z = z + self.pair_trans(z) 57 | return m,z 58 | else: 59 | m=checkpoint(self.layerfunc_msa_row,m,z) 60 | m=checkpoint(self.layerfunc_msa_col,m) 61 | m=checkpoint(self.layerfunc_msa_trans,m) 62 | z=checkpoint(self.layerfunc_msa_opm,m,z) 63 | 64 | z=checkpoint(self.layerfunc_pair_triout,z) 65 | z=checkpoint(self.layerfunc_pair_triin,z) 66 | z=checkpoint(self.layerfunc_pair_tristart,z) 67 | z=checkpoint(self.layerfunc_pair_triend,z) 68 | z=checkpoint(self.layerfunc_pair_trans,z) 69 | 70 | return m,z 71 | 72 | 73 | class Evoformer(nn.Module): 74 | def __init__(self,m_dim,z_dim,docheck=False): 75 | super(Evoformer,self).__init__() 76 | self.layers=[16] 77 | self.docheck=docheck 78 | if docheck: 79 | pass 80 | #print('will do checkpoint') 81 | self.evos=nn.ModuleList([EvoBlock(m_dim,z_dim,True) for i in range(self.layers[0])]) 82 | 83 | def layerfunc(self,layermodule,m,z): 84 | m_,z_=layermodule(m,z) 85 | return m_,z_ 86 | 87 | 88 | # def forward(self,m,z): 89 | 90 | # if True: 91 | # #print('will do checkpoint in Evoformer') 92 | # for i in range(self.layers[0]): 93 | # m,z=checkpoint(self.layerfunc,self.evos[i],m,z) 94 | # #m,z=self.evos[i](m,z) 95 | # return m,z 96 | # else: 97 | # for i in range(self.layers[0]): 98 | # m,z=self.evos[i](m,z) 99 | 100 | # return m,z 101 | def forward_n(self,m,z,starti,endi): 102 | for i in range(starti,endi): 103 | #print(i) 104 | m,z=self.evos[i](m,z) 105 | return m,z 106 | def forward(self,m,z): 107 | 108 | # if True: 109 | # #print('will do checkpoint in Evoformer') 110 | # for i in range(self.layers[0]): 111 | # #m,z=checkpoint(self.layerfunc,self.evos[i],m,z) 112 | # m,z=self.evos[i](m,z) 113 | # return m,z 114 | m,z = checkpoint(self.forward_n,m,z,0,3) 115 | m,z = checkpoint(self.forward_n,m,z,3,6) 116 | m,z = checkpoint(self.forward_n,m,z,6,10) 117 | m,z = checkpoint(self.forward_n,m,z,10,13) 118 | m,z = checkpoint(self.forward_n,m,z,13,16) 119 | return m,z 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | if __name__ == "__main__": 129 | N=10 130 | L=30 131 | m_dim=16 132 | z_dim=8 133 | m=torch.rand(N,L,m_dim) 134 | z=torch.rand(L,L,z_dim) 135 | model = Evoformer(m_dim,z_dim) 136 | m,z=model(m,z) 137 | print(model.parameters()) 138 | for param in model.parameters(): 139 | print(type(param), param.size()) 140 | print(m.shape,z.shape) -------------------------------------------------------------------------------- /cfg_97/IPA.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import basic 6 | import math 7 | import os 8 | expdir=os.path.dirname(os.path.abspath(__file__)) 9 | lines = open(os.path.join(expdir,'newconfig')).readlines() 10 | attdrop = lines[0].strip().split()[-1] == '1' 11 | denoisee2e = lines[1].strip().split()[-1] == '1' 12 | ss_type = lines[2].strip().split()[-1] 13 | class InvariantPointAttention(nn.Module): 14 | def __init__(self,dim_in,dim_z,N_head=8,c=16,N_query=4,N_p_values=6,) -> None: 15 | super(InvariantPointAttention,self).__init__() 16 | self.dim_in=dim_in 17 | self.dim_z=dim_z 18 | self.N_head =N_head 19 | self.c=c 20 | self.c_squ = 1.0/math.sqrt(c) 21 | self.W_c = math.sqrt(2.0/(9*N_query)) 22 | self.W_L = math.sqrt(1.0/3) 23 | self.N_query=N_query 24 | self.N_p_values=N_p_values 25 | self.liner_nb_q1=basic.LinearNoBias(dim_in,self.c*N_head) 26 | self.liner_nb_k1=basic.LinearNoBias(dim_in,self.c*N_head) 27 | self.liner_nb_v1=basic.LinearNoBias(dim_in,self.c*N_head) 28 | 29 | self.liner_nb_q2=basic.LinearNoBias(dim_in,N_head*N_query*3) 30 | self.liner_nb_k2=basic.LinearNoBias(dim_in,N_head*N_query*3) 31 | 32 | self.liner_nb_v3=basic.LinearNoBias(dim_in,N_head*N_p_values*3) 33 | 34 | self.liner_nb_z=basic.LinearNoBias(dim_z,N_head) 35 | self.lastlinear1=basic.Linear(N_head*dim_z,dim_in) 36 | self.lastlinear2=basic.Linear(N_head*c,dim_in) 37 | self.lastlinear3=basic.Linear(N_head*N_p_values*3,dim_in) 38 | self.gama = nn.ParameterList([nn.Parameter(torch.zeros(N_head))]) 39 | self.cos_f=nn.CosineSimilarity(dim=-1) 40 | 41 | def forward(self,s,z,rot,trans): 42 | L=s.shape[0] 43 | q1=self.liner_nb_q1(s).reshape(L,self.N_head,self.c) # Lq, 44 | k1=self.liner_nb_k1(s).reshape(L,self.N_head,self.c) 45 | v1=self.liner_nb_v1(s).reshape(L,self.N_head,self.c) # lv,h,c 46 | 47 | attmap=torch.einsum('ihc,jhc->ijh',q1,k1) * self.c_squ # Lq,Lk_v,h 48 | bias_z=self.liner_nb_z(z) # L L h 49 | 50 | q2 = self.liner_nb_q2(s).reshape(L,self.N_head,self.N_query,3) 51 | k2 = self.liner_nb_k2(s).reshape(L,self.N_head,self.N_query,3) 52 | 53 | v3 = self.liner_nb_v3(s).reshape(L,self.N_head,self.N_p_values,3) 54 | 55 | q2 = basic.IPA_transform(q2,rot,trans) # Lq,self.N_head,self.N_query,3 56 | k2 = basic.IPA_transform(k2,rot,trans) # Lk,self.N_head,self.N_query,3 57 | 58 | dismap=((q2[:,None,:,:,:] - k2[None,:,:,:,:])**2).sum([3,4]) ## Lq,Lk, self.N_head, 59 | #dismap=dismap - (self.cos_f(q2[:,None,:,:,:] , k2[None,:,:,:,:])).sum(3) 60 | attmap = attmap + bias_z - F.softplus(self.gama[0])[None,None,:]*dismap*self.W_c*0.5 61 | #print(torch.max(attmap*self.W_L),torch.min(attmap)*self.W_L) 62 | #attmap = F.softmax( torch.clamp(attmap*self.W_L,-5,5),dim=1 ) # Lk dim, [Lq,Lk, self.N_head] 63 | 64 | attmap = F.softmax( attmap*self.W_L,dim=1 ) # Lk dim, [Lq,Lk, self.N_head] 65 | if attdrop: 66 | if self.training: 67 | attmap = basic.DropAtt(attmap,dim=1) 68 | o1 = (attmap[:,:,:,None] * z[:,:,None,:]).sum(1) # Lq, N_head, c_z 69 | o2 = torch.einsum('abc,dab->dbc',v1,attmap) # Lq, N_head, c 70 | o3 = basic.IPA_transform(v3,rot,trans) # Lv, h, p* ,3 71 | o3 = basic.IPA_inverse_transform( torch.einsum('vhpt,gvh->ghpt',o3,attmap),rot,trans) #Lv, h, p* ,3 72 | 73 | return self.lastlinear1(o1.reshape(L,-1)) + self.lastlinear2(o2.reshape(L,-1)) + self.lastlinear3(o3.reshape(L,-1)) 74 | 75 | 76 | 77 | 78 | 79 | 80 | if __name__ == "__main__": 81 | dim_in,dim_z = 8,4 82 | L = 10 83 | ipa = InvariantPointAttention(dim_in,dim_z) 84 | s=torch.rand(L,dim_in) 85 | z=torch.rand(L,L,dim_z) 86 | rot=(torch.eye(3)[None,:,:]).repeat(L,1,1) 87 | trans=torch.rand(L,3) 88 | 89 | out=ipa(s,z,rot,trans) 90 | print(out) 91 | print(out.shape) 92 | 93 | 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /cfg_97/RNALM2/EvoMSA.py: -------------------------------------------------------------------------------- 1 | from numpy import select 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from RNALM2 import basic 6 | import math 7 | 8 | def SignedSqrt( x): 9 | x = torch.sqrt(torch.relu(x)) - torch.sqrt(torch.relu(-x)) 10 | return x 11 | class MSARow(nn.Module): 12 | def __init__(self,m_dim,z_dim,N_head=8,c=8): 13 | super(MSARow,self).__init__() 14 | self.N_head = N_head 15 | self.c = c 16 | self.sq_c = 1/math.sqrt(c) 17 | self.norm1=nn.LayerNorm(m_dim) 18 | self.qlinear = basic.LinearNoBias(m_dim,N_head*c) 19 | self.klinear = basic.LinearNoBias(m_dim,N_head*c) 20 | self.vlinear = basic.LinearNoBias(m_dim,N_head*c) 21 | self.norm_z = nn.LayerNorm(z_dim) 22 | self.zlinear = basic.LinearNoBias(z_dim,N_head) 23 | self.glinear = basic.Linear(m_dim,N_head*c) 24 | self.olinear = basic.Linear(N_head*c,m_dim) 25 | 26 | def forward(self,m,z): 27 | # m : N L 32 28 | N,L,D = m.shape 29 | m = self.norm1(m) 30 | q = self.qlinear(m).reshape(N,L,self.N_head,self.c) #s rq h c 31 | k = self.klinear(m).reshape(N,L,self.N_head,self.c) #s rv h c 32 | v = self.vlinear(m).reshape(N,L,self.N_head,self.c) 33 | b = self.zlinear(self.norm_z(z)) 34 | g = torch.sigmoid(self.glinear(m)).reshape(N,L,self.N_head,self.c) 35 | att=torch.einsum('bqhc,bvhc->bqvh',q,k) * (self.sq_c) + b[None,:,:,:] # rq rv h 36 | att=F.softmax(SignedSqrt(att),dim=2) 37 | o = torch.einsum('bqvh,bvhc->bqhc',att,v) * g 38 | m_ = self.olinear(o.reshape(N,L,-1)) 39 | return m_ 40 | 41 | class MSACol(nn.Module): 42 | def __init__(self,m_dim,N_head=8,c=8): 43 | super(MSACol,self).__init__() 44 | self.N_head = N_head 45 | self.c = c 46 | self.sq_c = 1/math.sqrt(c) 47 | self.norm1=nn.LayerNorm(m_dim) 48 | self.qlinear = basic.LinearNoBias(m_dim,N_head*c) 49 | self.klinear = basic.LinearNoBias(m_dim,N_head*c) 50 | self.vlinear = basic.LinearNoBias(m_dim,N_head*c) 51 | 52 | self.glinear = basic.Linear(m_dim,N_head*c) 53 | self.olinear = basic.Linear(N_head*c,m_dim) 54 | 55 | def forward(self,m): 56 | # m : N L 32 57 | N,L,D = m.shape 58 | m = self.norm1(m) 59 | q = self.qlinear(m).reshape(N,L,self.N_head,self.c) #s rq h c 60 | k = self.klinear(m).reshape(N,L,self.N_head,self.c) #s rv h c 61 | v = self.vlinear(m).reshape(N,L,self.N_head,self.c) 62 | 63 | g = torch.sigmoid(self.glinear(m)).reshape(N,L,self.N_head,self.c) 64 | 65 | att=torch.einsum('slhc,tlhc->stlh',q,k) * (self.sq_c) # rq rv h 66 | att=F.softmax(SignedSqrt(att),dim=1) 67 | o = torch.einsum('stlh,tlhc->slhc',att,v) * g 68 | m_ = self.olinear(o.reshape(N,L,-1)) 69 | return m_ 70 | 71 | class MSATrans(nn.Module): 72 | def __init__(self,m_dim,c_expand=2): 73 | super(MSATrans,self).__init__() 74 | self.c_expand=4 75 | self.m_dim=m_dim 76 | self.norm=nn.LayerNorm(m_dim) 77 | self.linear1 = basic.Linear(m_dim,m_dim*c_expand) 78 | self.linear2 = basic.Linear(m_dim*c_expand,m_dim) 79 | def forward(self,m): 80 | m = self.norm(m) 81 | m = self.linear1(m) 82 | m = self.linear2(F.relu(m)) 83 | return m 84 | 85 | class MSAOPM(nn.Module): 86 | def __init__(self,m_dim,z_dim,c=12): 87 | super(MSAOPM,self).__init__() 88 | self.m_dim=m_dim 89 | self.c=c 90 | self.norm=nn.LayerNorm(m_dim) 91 | self.linear1=basic.Linear(m_dim,c) 92 | self.linear2=basic.Linear(m_dim,c) 93 | self.linear3=basic.Linear(c*c,z_dim) 94 | def forward(self,m): 95 | N,L,D=m.shape 96 | o=self.norm(m) 97 | a=self.linear2(o) 98 | b=self.linear1(o) 99 | o = torch.einsum('nia,njb->nijab',a,b).mean(dim=0) 100 | o = self.linear3(o.reshape(L,L,-1)) 101 | return o 102 | 103 | 104 | 105 | 106 | 107 | 108 | if __name__ == "__main__": 109 | N=10 110 | L=30 111 | m_dim=16 112 | z_dim=8 113 | m=torch.rand(N,L,m_dim) 114 | z=torch.rand(L,L,z_dim) 115 | msarow=MSARow(m_dim,z_dim) 116 | msacol=MSACol(m_dim) 117 | msatrans=MSATrans(m_dim) 118 | msaopm=MSAOPM(m_dim,z_dim) 119 | y=msaopm(m) 120 | print(y.shape) -------------------------------------------------------------------------------- /cfg_97/RNALM2/EvoPair.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from RNALM2 import basic 5 | import math 6 | 7 | def SignedSqrt( x): 8 | x = torch.sqrt(torch.relu(x)) - torch.sqrt(torch.relu(-x)) 9 | return x 10 | class TriOut(nn.Module): 11 | def __init__(self,z_dim,c=32): 12 | super(TriOut,self).__init__() 13 | self.z_dim = z_dim 14 | self.norm =nn.LayerNorm(z_dim) 15 | self.onorm =nn.LayerNorm(c) 16 | self.alinear=basic.Linear(z_dim,c) 17 | self.blinear=basic.Linear(z_dim,c) 18 | self.aglinear=basic.Linear(z_dim,c) 19 | self.bglinear=basic.Linear(z_dim,c) 20 | self.glinear =basic.Linear(z_dim,z_dim) 21 | self.olinear=basic.Linear(c,z_dim) 22 | 23 | def forward(self,z): 24 | z = self.norm(z) 25 | a = self.alinear(z) * torch.sigmoid(self.aglinear(z)) 26 | b = self.blinear(z) * torch.sigmoid(self.bglinear(z)) 27 | o = torch.einsum('ilc,jlc->ijc',a,b) 28 | o = self.onorm(o) 29 | o = self.olinear(o) 30 | o = o * torch.sigmoid(self.glinear(z)) 31 | return o 32 | 33 | class TriIn(nn.Module): 34 | def __init__(self,z_dim,c=32): 35 | super(TriIn,self).__init__() 36 | self.z_dim = z_dim 37 | self.norm =nn.LayerNorm(z_dim) 38 | self.onorm =nn.LayerNorm(c) 39 | self.alinear=basic.Linear(z_dim,c) 40 | self.blinear=basic.Linear(z_dim,c) 41 | self.aglinear=basic.Linear(z_dim,c) 42 | self.bglinear=basic.Linear(z_dim,c) 43 | self.glinear =basic.Linear(z_dim,z_dim) 44 | self.olinear=basic.Linear(c,z_dim) 45 | 46 | def forward(self,z): 47 | z = self.norm(z) 48 | a = self.alinear(z) * torch.sigmoid(self.aglinear(z)) 49 | b = self.blinear(z) * torch.sigmoid(self.bglinear(z)) 50 | o = torch.einsum('lic,ljc->ijc',a,b) 51 | o = self.onorm(o) 52 | o = self.olinear(o) 53 | o = o * torch.sigmoid(self.glinear(z)) 54 | return o 55 | 56 | 57 | class TriAttStart(nn.Module): 58 | def __init__(self,z_dim,N_head=4,c=8): 59 | super(TriAttStart,self).__init__() 60 | self.z_dim = z_dim 61 | self.N_head = N_head 62 | self.c = c 63 | self.sq_c = 1/math.sqrt(c) 64 | self.norm=nn.LayerNorm(z_dim) 65 | self.qlinear=basic.Linear(z_dim,c*N_head) 66 | self.klinear=basic.Linear(z_dim,c*N_head) 67 | self.vlinear=basic.Linear(z_dim,c*N_head) 68 | self.blinear=basic.Linear(z_dim,N_head) 69 | self.glinear=basic.Linear(z_dim,c*N_head) 70 | self.olinear=basic.Linear(c*N_head,z_dim) 71 | 72 | def forward(self,z_): 73 | L1,L2,D=z_.shape 74 | z = self.norm(z_) 75 | q = self.qlinear(z).reshape(L1,L2,self.N_head,self.c) 76 | k = self.klinear(z).reshape(L1,L2,self.N_head,self.c) 77 | v = self.vlinear(z).reshape(L1,L2,self.N_head,self.c) 78 | b = self.blinear(z) 79 | att = torch.einsum('blhc,bkhc->blkh',q,k)*self.sq_c + b[None,:,:,:] 80 | att = F.softmax(SignedSqrt(att),dim=2) 81 | o = torch.einsum('blkh,bkhc->blhc',att,v) 82 | o = (torch.sigmoid(self.glinear(z).reshape(L1,L2,self.N_head,self.c)) * o).reshape(L1,L2,-1) 83 | o = self.olinear(o) 84 | return o 85 | 86 | class TriAttEnd(nn.Module): 87 | def __init__(self,z_dim,N_head=4,c=8): 88 | super(TriAttEnd,self).__init__() 89 | self.z_dim = z_dim 90 | self.N_head = N_head 91 | self.c = c 92 | self.sq_c = 1/math.sqrt(c) 93 | self.norm=nn.LayerNorm(z_dim) 94 | self.qlinear=basic.Linear(z_dim,c*N_head) 95 | self.klinear=basic.Linear(z_dim,c*N_head) 96 | self.vlinear=basic.Linear(z_dim,c*N_head) 97 | self.blinear=basic.Linear(z_dim,N_head) 98 | self.glinear=basic.Linear(z_dim,c*N_head) 99 | self.olinear=basic.Linear(c*N_head,z_dim) 100 | 101 | def forward(self,z_): 102 | L1,L2,D=z_.shape 103 | z = self.norm(z_) 104 | q = self.qlinear(z).reshape(L1,L2,self.N_head,self.c) 105 | k = self.klinear(z).reshape(L1,L2,self.N_head,self.c) 106 | v = self.vlinear(z).reshape(L1,L2,self.N_head,self.c) 107 | b = self.blinear(z) 108 | att = torch.einsum('blhc,kbhc->blkh',q,k)*self.sq_c + b[None,:,:,:].permute(0,2,1,3) 109 | att = F.softmax(SignedSqrt(att),dim=2) 110 | o = torch.einsum('blkh,klhc->blhc',att,v) 111 | o = (torch.sigmoid(self.glinear(z).reshape(L1,L2,self.N_head,self.c)) * o).reshape(L1,L2,-1) 112 | o = self.olinear(o) 113 | return o 114 | def forward2(self,z_): 115 | z = z_.permute(1,0,2) 116 | L1,L2,D=z_.shape 117 | z = self.norm(z_) 118 | q = self.qlinear(z).reshape(L1,L2,self.N_head,self.c) 119 | k = self.klinear(z).reshape(L1,L2,self.N_head,self.c) 120 | v = self.vlinear(z).reshape(L1,L2,self.N_head,self.c) 121 | b = self.blinear(z) 122 | att = torch.einsum('blhc,bkhc->blkh',q,k)*self.sq_c + b[None,:,:,:] 123 | att = F.softmax(att,dim=2) 124 | o = torch.einsum('blkh,bkhc->blhc',att,v) 125 | o = (torch.sigmoid(self.glinear(z).reshape(L1,L2,self.N_head,self.c)) * o).reshape(L1,L2,-1) 126 | o = self.olinear(o) 127 | o = o.permute(1,0,2) 128 | return o 129 | class PairTrans(nn.Module): 130 | def __init__(self,z_dim,c_expand=2): 131 | super(PairTrans,self).__init__() 132 | self.z_dim=z_dim 133 | self.c_expand=c_expand 134 | self.norm = nn.LayerNorm(z_dim) 135 | self.linear1=basic.Linear(z_dim,z_dim*c_expand) 136 | self.linear2=basic.Linear(z_dim*c_expand,z_dim) 137 | def forward(self,z): 138 | a = self.linear1(self.norm(z)) 139 | a = self.linear2(F.relu(a)) 140 | return a 141 | 142 | 143 | 144 | 145 | if __name__ == "__main__": 146 | N=10 147 | L=30 148 | m_dim=16 149 | z_dim=8 150 | m=torch.rand(N,L,m_dim) 151 | z=torch.rand(L,L,z_dim) 152 | 153 | tr1=TriAttEnd(z_dim) 154 | tr2=PairTrans(z_dim) 155 | y=tr1(z) 156 | y2=tr1.forward2(z) 157 | y3=tr2(z) 158 | print(y3.shape) 159 | 160 | -------------------------------------------------------------------------------- /cfg_97/RNALM2/Evoformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from RNALM2 import basic,EvoPair,EvoMSA 5 | import math,sys 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | 9 | class EvoBlock(nn.Module): 10 | def __init__(self,m_dim,z_dim,docheck=False): 11 | super(EvoBlock,self).__init__() 12 | N_head = 16 13 | c = 16 14 | self.msa_row=EvoMSA.MSARow(m_dim,z_dim,N_head,c) 15 | self.msa_col=EvoMSA.MSACol(m_dim,N_head,c) 16 | self.msa_trans=EvoMSA.MSATrans(m_dim) 17 | 18 | self.msa_opm=EvoMSA.MSAOPM(m_dim,z_dim) 19 | 20 | self.pair_triout=EvoPair.TriOut(z_dim,72) 21 | self.pair_triin =EvoPair.TriIn(z_dim,72) 22 | self.pair_tristart=EvoPair.TriAttStart(z_dim) 23 | self.pair_triend =EvoPair.TriAttEnd(z_dim) 24 | self.pair_trans = EvoPair.PairTrans(z_dim) 25 | self.docheck=docheck 26 | if docheck: 27 | print('will do checkpoint') 28 | 29 | def layerfunc_msa_row(self,m,z): 30 | return self.msa_row(m,z) + m 31 | def layerfunc_msa_col(self,m): 32 | return self.msa_col(m) + m 33 | def layerfunc_msa_trans(self,m): 34 | return self.msa_trans(m) + m 35 | def layerfunc_msa_opm(self,m,z): 36 | return self.msa_opm(m) + z 37 | 38 | def layerfunc_pair_triout(self,z): 39 | return self.pair_triout(z) + z 40 | def layerfunc_pair_triin(self,z): 41 | return self.pair_triin(z) + z 42 | def layerfunc_pair_tristart(self,z): 43 | return self.pair_tristart(z) + z 44 | def layerfunc_pair_triend(self,z): 45 | return self.pair_triend(z) + z 46 | def layerfunc_pair_trans(self,z): 47 | return self.pair_trans(z) + z 48 | def forward(self,m,z): 49 | if True: 50 | m = m + self.msa_row(m,z) 51 | m = m + self.msa_col(m) 52 | m = m + self.msa_trans(m) 53 | z = z + self.msa_opm(m) 54 | z = z + self.pair_triout(z) 55 | z = z + self.pair_triin(z) 56 | #z = z + self.pair_tristart(z) 57 | #z = z + self.pair_triend(z) 58 | z = z + self.pair_trans(z) 59 | return m,z 60 | else: 61 | m=checkpoint(self.layerfunc_msa_row,m,z) 62 | m=checkpoint(self.layerfunc_msa_col,m) 63 | m=checkpoint(self.layerfunc_msa_trans,m) 64 | z=checkpoint(self.layerfunc_msa_opm,m,z) 65 | 66 | z=checkpoint(self.layerfunc_pair_triout,z) 67 | z=checkpoint(self.layerfunc_pair_triin,z) 68 | z=checkpoint(self.layerfunc_pair_tristart,z) 69 | z=checkpoint(self.layerfunc_pair_triend,z) 70 | z=checkpoint(self.layerfunc_pair_trans,z) 71 | 72 | return m,z 73 | 74 | 75 | class Evoformer(nn.Module): 76 | def __init__(self,m_dim,z_dim,N_elayers=12,docheck=False): 77 | super(Evoformer,self).__init__() 78 | self.layers=[N_elayers] 79 | self.docheck=docheck 80 | if docheck: 81 | pass 82 | #print('will do checkpoint') 83 | self.evos=nn.ModuleList([EvoBlock(m_dim,z_dim,True) for i in range(self.layers[0])]) 84 | 85 | def layerfunc(self,layermodule,m,z): 86 | m_,z_=layermodule(m,z) 87 | return m_,z_ 88 | 89 | 90 | def forward(self,m,z): 91 | 92 | if True: 93 | #print('will do checkpoint in Evoformer') 94 | for i in range(self.layers[0]): 95 | #m,z=checkpoint(self.layerfunc,self.evos[i],m,z) 96 | m,z=self.evos[i](m,z) 97 | return m,z 98 | else: 99 | for i in range(self.layers[0]): 100 | m,z=self.evos[i](m,z) 101 | 102 | return m,z 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | if __name__ == "__main__": 113 | N=10 114 | L=30 115 | m_dim=16 116 | z_dim=8 117 | m=torch.rand(N,L,m_dim) 118 | z=torch.rand(L,L,z_dim) 119 | model = Evoformer(m_dim,z_dim) 120 | m,z=model(m,z) 121 | print(model.parameters()) 122 | for param in model.parameters(): 123 | print(type(param), param.size()) 124 | print(m.shape,z.shape) -------------------------------------------------------------------------------- /cfg_97/RNALM2/Model.py: -------------------------------------------------------------------------------- 1 | from numpy import select 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from RNALM2 import basic,Evoformer 6 | import math,sys 7 | from torch.utils.checkpoint import checkpoint 8 | import numpy as np 9 | 10 | 11 | def one_d(idx_, d, max_len=2056*8): 12 | idx = idx_[None] 13 | K = torch.arange(d//2).to(idx.device) 14 | sin_e = torch.sin(idx[..., None] * math.pi / (max_len**(2*K[None]/d))).to(idx.device) 15 | cos_e = torch.cos(idx[..., None] * math.pi / (max_len**(2*K[None]/d))).to(idx.device) 16 | return torch.cat([sin_e, cos_e], axis=-1)[0] 17 | 18 | 19 | 20 | 21 | 22 | 23 | class RNAembedding(nn.Module): 24 | def __init__(self,cfg): 25 | super(RNAembedding,self).__init__() 26 | self.s_in_dim=cfg['s_in_dim'] 27 | self.z_in_dim=cfg['z_in_dim'] 28 | self.s_dim=cfg['s_dim'] 29 | self.z_dim=cfg['z_dim'] 30 | self.qlinear =basic.Linear(self.s_in_dim+1,self.z_dim) 31 | self.klinear =basic.Linear(self.s_in_dim+1,self.z_dim) 32 | self.slinear =basic.Linear(self.s_in_dim+1,self.s_dim) 33 | self.zlinear =basic.Linear(self.z_in_dim+1,self.z_dim) 34 | 35 | self.poslinears = basic.Linear(64,self.s_dim) 36 | self.poslinearz = basic.Linear(64,self.z_dim) 37 | def forward(self,in_dict): 38 | # msa N L D, seq L D 39 | # mask: maksing, L, 1 means masked 40 | # aa: L x s_in_dim 41 | # ss: L x L x 2 42 | # idx: L (LongTensor) 43 | L = in_dict['aa'].shape[0] 44 | aamask = in_dict['mask'][:,None] 45 | zmask = in_dict['mask'][:,None] + in_dict['mask'][None,:] 46 | zmask[zmask>0.5]=1 47 | zmask = zmask[...,None] 48 | s = torch.cat([aamask,(1-aamask)*in_dict['aa']],dim=-1) 49 | sq=self.qlinear(s) 50 | sk=self.klinear(s) 51 | z=sq[None,:,:]+sk[:,None,:] 52 | seq_idx = in_dict['idx'][None] 53 | relative_pos = seq_idx[:, :, None] - seq_idx[:, None, :] 54 | relative_pos = relative_pos.reshape([1, -1]) 55 | relative_pos =one_d(relative_pos,64) 56 | z = z + self.poslinearz( relative_pos.reshape([1, L, L, -1])[0] ) 57 | 58 | s = self.slinear(s) + self.poslinears( one_d(in_dict['idx'], 64) ) 59 | 60 | return s,z 61 | 62 | 63 | class RNA2nd(nn.Module): 64 | def __init__(self,cfg): 65 | super(RNA2nd,self).__init__() 66 | self.s_in_dim=cfg['s_in_dim'] 67 | self.z_in_dim=cfg['z_in_dim'] 68 | self.s_dim=cfg['s_dim'] 69 | self.z_dim=cfg['z_dim'] 70 | self.N_elayers =cfg['N_elayers'] 71 | self.emb = RNAembedding(cfg) 72 | self.evmodel=Evoformer.Evoformer(self.s_dim,self.z_dim,self.N_elayers) 73 | self.seq_head = basic.Linear(self.s_dim,self.s_in_dim) 74 | self.joint_head = basic.Linear(self.z_dim,self.s_in_dim*self.s_in_dim) 75 | 76 | 77 | 78 | 79 | def embedding(self,in_dict): 80 | s,z = self.emb(in_dict) 81 | s,z = self.evmodel(s[None,...],z) 82 | return s[0],z 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /cfg_97/RNALM2/basic.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import random 5 | class Linear(nn.Module): 6 | def __init__(self,dim_in,dim_out): 7 | super(Linear,self).__init__() 8 | self.linear = nn.Linear(dim_in,dim_out) 9 | def forward(self,x): 10 | x = self.linear(x) 11 | return x 12 | 13 | 14 | class LinearNoBias(nn.Module): 15 | def __init__(self,dim_in,dim_out): 16 | super(LinearNoBias,self).__init__() 17 | self.linear = nn.Linear(dim_in,dim_out,bias=False) 18 | def forward(self,x): 19 | x = self.linear(x) 20 | return x 21 | 22 | 23 | 24 | def transform(k,rotation,translation): 25 | # K L x 3 26 | # rotation 27 | return torch.matmul(k,rotation) + translation 28 | 29 | 30 | def batch_transform(k,rotation,translation): 31 | # k: L 3 32 | # rotation: L 3 x 3 33 | # translation: L 3 34 | return torch.einsum('ba,bad->bd',k,rotation) + translation 35 | 36 | def batch_atom_transform(k,rotation,translation): 37 | # k: L N 3 38 | # rotation: L 3 x 3 39 | # translation: L 3 40 | return torch.einsum('bja,bad->bjd',k,rotation) + translation[:,None,:] 41 | 42 | def IPA_transform(k,rotation,translation): 43 | # k: L d1, d2, 3 44 | # rotation: L 3 x 3 45 | # translation: L 3 46 | return torch.einsum('bija,bad->bijd',k,rotation)+translation[:,None,None,:] 47 | 48 | def IPA_inverse_transform(k,rotation,translation): 49 | # k: L d1, d2, 3 50 | # rotation: L 3 x 3 51 | # translation: L 3 52 | return torch.einsum('bija,bad->bijd',k-translation[:,None,None,:],rotation.transpose(-1,-2)) 53 | 54 | def update_transform(t,tr,rotation,translation): 55 | return torch.einsum('bja,bad->bjd',t,rotation),torch.einsum('ba,bad->bd',tr,rotation) +translation 56 | 57 | 58 | def quat2rot(q,L): 59 | scale= ((q**2).sum(dim=-1,keepdim=True) +1) [:,:,None] 60 | u=torch.empty([L,3,3],device=q.device) 61 | u[:,0,0]=1*1+q[:,0]*q[:,0]-q[:,1]*q[:,1]-q[:,2]*q[:,2] 62 | u[:,0,1]=2*(q[:,0]*q[:,1]-1*q[:,2]) 63 | u[:,0,2]=2*(q[:,0]*q[:,2]+1*q[:,1]) 64 | u[:,1,0]=2*(q[:,0]*q[:,1]+1*q[:,2]) 65 | u[:,1,1]=1*1-q[:,0]*q[:,0]+q[:,1]*q[:,1]-q[:,2]*q[:,2] 66 | u[:,1,2]=2*(q[:,1]*q[:,2]-1*q[:,0]) 67 | u[:,2,0]=2*(q[:,0]*q[:,2]-1*q[:,1]) 68 | u[:,2,1]=2*(q[:,1]*q[:,2]+1*q[:,0]) 69 | u[:,2,2]=1*1-q[:,0]*q[:,0]-q[:,1]*q[:,1]+q[:,2]*q[:,2] 70 | return u/scale 71 | 72 | 73 | def rotation_x(sintheta,costheta,ones,zeros): 74 | # L x 1 75 | return torch.stack([torch.stack([ones, zeros, zeros]), 76 | torch.stack([zeros, costheta, sintheta]), 77 | torch.stack([zeros, -sintheta, costheta])]) 78 | def rotation_y(sintheta,costheta,ones,zeros): 79 | # L x 1 80 | return torch.stack([torch.stack([costheta, zeros, sintheta]), 81 | torch.stack([zeros, ones, zeros]), 82 | torch.stack([-sintheta, zeros, costheta])]) 83 | def rotation_z(sintheta,costheta,ones,zeros): 84 | # L x 1 85 | return torch.stack([torch.stack([costheta, sintheta, zeros]), 86 | torch.stack([-sintheta, costheta, zeros]), 87 | torch.stack([zeros, zeros, ones])]) 88 | def batch_rotation(k,rotation): 89 | # k: L 3 90 | # rotation: L 3 x 3 91 | # translation: L 3 92 | return torch.einsum('ba,bad->bd',k,rotation) 93 | 94 | def compute_cb(bl,sin_angle,cos_angle,sin_torsion,cos_torsion): 95 | L=bl.shape[0] 96 | ones=torch.ones(L,device=bl.device) 97 | zeros=torch.zeros(L,device=bl.device) 98 | cb=torch.stack([bl,zeros,zeros]).permute(1,0) 99 | rotz=rotation_z(sin_angle,cos_angle,ones,zeros).permute(2,0,1) 100 | rotx=rotation_x(sin_torsion,cos_torsion,ones,zeros).permute(2,0,1) 101 | cb=batch_rotation(cb,rotz) 102 | cb=batch_rotation(cb,rotx) 103 | return cb 104 | 105 | def rigidFrom3Points_(x1,x2,x3): 106 | v1=x3-x2 107 | v2=x1-x2 108 | e1=v1/(torch.norm(v1,dim=-1,keepdim=True) + 1e-03) 109 | u2=v2 - e1*(torch.einsum('bn,bn->b',e1,v2)[:,None]) 110 | e2 = u2/(torch.norm(u2,dim=-1,keepdim=True) + 1e-03) 111 | e3=torch.cross(e1,e2,dim=-1) 112 | 113 | return torch.stack([e1,e2,e3],dim=1),x2[:,:] 114 | def rigidFrom3Points(x1,x2,x3):# L 3 115 | the_dim=1 116 | x = torch.stack([x1,x2,x3],dim=the_dim) 117 | x_mean = torch.mean(x,dim=the_dim,keepdim=True) 118 | x = x - x_mean 119 | 120 | 121 | m = x.view(-1, 3, 3) 122 | u, s, v = torch.svd(m) 123 | vt = torch.transpose(v, 1, 2) 124 | det = torch.det(torch.matmul(u, vt)) 125 | det = det.view(-1, 1, 1) 126 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) 127 | r = torch.matmul(u, vt) 128 | return r,x_mean.squeeze() 129 | def Kabsch_rigid(bases,x1,x2,x3): 130 | ''' 131 | return the direction from to_q to from_p 132 | ''' 133 | the_dim=1 134 | to_q = torch.stack([x1,x2,x3],dim=the_dim) 135 | biasq=torch.mean(to_q,dim=the_dim,keepdim=True) 136 | q=to_q-biasq 137 | m = torch.einsum('bnz,bny->bzy',bases,q) 138 | u, s, v = torch.svd(m) 139 | vt = torch.transpose(v, 1, 2) 140 | det = torch.det(torch.matmul(u, vt)) 141 | det = det.view(-1, 1, 1) 142 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) 143 | r = torch.matmul(u, vt) 144 | return r,biasq.squeeze() 145 | def Generate_msa_mask(n,l): 146 | # 1: 15% mask out 147 | randommatrix=torch.rand(n,l) 148 | mask = (randommatrix <0.1).float() 149 | # 2 random a segment 150 | seqlength = int(l*0.1) 151 | sindex=round(random.random()*(l-seqlength)) 152 | endindex=min(l,sindex+seqlength) 153 | mask[:,sindex:endindex]=1 154 | return mask 155 | 156 | 157 | -------------------------------------------------------------------------------- /cfg_97/base.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeyang/DRfold2/54129443ed151032e3d647a279814c5cebdca630/cfg_97/base.npy -------------------------------------------------------------------------------- /cfg_97/basic.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import random 5 | class Linear(nn.Module): 6 | def __init__(self,dim_in,dim_out): 7 | super(Linear,self).__init__() 8 | self.linear = nn.Linear(dim_in,dim_out) 9 | def forward(self,x): 10 | x = self.linear(x) 11 | return x 12 | 13 | 14 | class LinearNoBias(nn.Module): 15 | def __init__(self,dim_in,dim_out): 16 | super(LinearNoBias,self).__init__() 17 | self.linear = nn.Linear(dim_in,dim_out,bias=False) 18 | def forward(self,x): 19 | x = self.linear(x) 20 | return x 21 | 22 | def DropAtt(x,dim,droprate = 0.25): 23 | shapes = x.shape 24 | L = shapes[dim] 25 | num_of_dim = len(shapes) 26 | randmask = torch.rand(L).to(x.device) 27 | themask = randmaskbd',k,rotation) + translation 49 | 50 | def batch_atom_transform(k,rotation,translation): 51 | # k: L N 3 52 | # rotation: L 3 x 3 53 | # translation: L 3 54 | return torch.einsum('bja,bad->bjd',k,rotation) + translation[:,None,:] 55 | 56 | def IPA_transform(k,rotation,translation): 57 | # k: L d1, d2, 3 58 | # rotation: L 3 x 3 59 | # translation: L 3 60 | return torch.einsum('bija,bad->bijd',k,rotation)+translation[:,None,None,:] 61 | 62 | def IPA_inverse_transform(k,rotation,translation): 63 | # k: L d1, d2, 3 64 | # rotation: L 3 x 3 65 | # translation: L 3 66 | return torch.einsum('bija,bad->bijd',k-translation[:,None,None,:],rotation.transpose(-1,-2)) 67 | 68 | def update_transform(t,tr,rotation,translation): 69 | return torch.einsum('bja,bad->bjd',t,rotation),torch.einsum('ba,bad->bd',tr,rotation) +translation 70 | 71 | 72 | def quat2rot(q,L): 73 | scale= ((q**2).sum(dim=-1,keepdim=True) +1) [:,:,None] 74 | u=torch.empty([L,3,3],device=q.device) 75 | u[:,0,0]=1*1+q[:,0]*q[:,0]-q[:,1]*q[:,1]-q[:,2]*q[:,2] 76 | u[:,0,1]=2*(q[:,0]*q[:,1]-1*q[:,2]) 77 | u[:,0,2]=2*(q[:,0]*q[:,2]+1*q[:,1]) 78 | u[:,1,0]=2*(q[:,0]*q[:,1]+1*q[:,2]) 79 | u[:,1,1]=1*1-q[:,0]*q[:,0]+q[:,1]*q[:,1]-q[:,2]*q[:,2] 80 | u[:,1,2]=2*(q[:,1]*q[:,2]-1*q[:,0]) 81 | u[:,2,0]=2*(q[:,0]*q[:,2]-1*q[:,1]) 82 | u[:,2,1]=2*(q[:,1]*q[:,2]+1*q[:,0]) 83 | u[:,2,2]=1*1-q[:,0]*q[:,0]-q[:,1]*q[:,1]+q[:,2]*q[:,2] 84 | return u/scale 85 | 86 | 87 | def rotation_x(sintheta,costheta,ones,zeros): 88 | # L x 1 89 | return torch.stack([torch.stack([ones, zeros, zeros]), 90 | torch.stack([zeros, costheta, sintheta]), 91 | torch.stack([zeros, -sintheta, costheta])]) 92 | def rotation_y(sintheta,costheta,ones,zeros): 93 | # L x 1 94 | return torch.stack([torch.stack([costheta, zeros, sintheta]), 95 | torch.stack([zeros, ones, zeros]), 96 | torch.stack([-sintheta, zeros, costheta])]) 97 | def rotation_z(sintheta,costheta,ones,zeros): 98 | # L x 1 99 | return torch.stack([torch.stack([costheta, sintheta, zeros]), 100 | torch.stack([-sintheta, costheta, zeros]), 101 | torch.stack([zeros, zeros, ones])]) 102 | def batch_rotation(k,rotation): 103 | # k: L 3 104 | # rotation: L 3 x 3 105 | # translation: L 3 106 | return torch.einsum('ba,bad->bd',k,rotation) 107 | 108 | def compute_cb(bl,sin_angle,cos_angle,sin_torsion,cos_torsion): 109 | L=bl.shape[0] 110 | ones=torch.ones(L,device=bl.device) 111 | zeros=torch.zeros(L,device=bl.device) 112 | cb=torch.stack([bl,zeros,zeros]).permute(1,0) 113 | rotz=rotation_z(sin_angle,cos_angle,ones,zeros).permute(2,0,1) 114 | rotx=rotation_x(sin_torsion,cos_torsion,ones,zeros).permute(2,0,1) 115 | cb=batch_rotation(cb,rotz) 116 | cb=batch_rotation(cb,rotx) 117 | return cb 118 | 119 | def rigidFrom3Points_(x1,x2,x3): 120 | v1=x3-x2 121 | v2=x1-x2 122 | e1=v1/(torch.norm(v1,dim=-1,keepdim=True) + 1e-03) 123 | u2=v2 - e1*(torch.einsum('bn,bn->b',e1,v2)[:,None]) 124 | e2 = u2/(torch.norm(u2,dim=-1,keepdim=True) + 1e-03) 125 | e3=torch.cross(e1,e2,dim=-1) 126 | 127 | return torch.stack([e1,e2,e3],dim=1),x2[:,:] 128 | def rigidFrom3Points(x1,x2,x3):# L 3 129 | the_dim=1 130 | x = torch.stack([x1,x2,x3],dim=the_dim) 131 | x_mean = torch.mean(x,dim=the_dim,keepdim=True) 132 | x = x - x_mean 133 | 134 | 135 | m = x.view(-1, 3, 3) 136 | u, s, v = torch.svd(m) 137 | vt = torch.transpose(v, 1, 2) 138 | det = torch.det(torch.matmul(u, vt)) 139 | det = det.view(-1, 1, 1) 140 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) 141 | r = torch.matmul(u, vt) 142 | return r,x_mean.squeeze() 143 | def Kabsch_rigid(bases,x1,x2,x3): 144 | ''' 145 | return the direction from to_q to from_p 146 | ''' 147 | the_dim=1 148 | to_q = torch.stack([x1,x2,x3],dim=the_dim) 149 | biasq=torch.mean(to_q,dim=the_dim,keepdim=True) 150 | q=to_q-biasq 151 | m = torch.einsum('bnz,bny->bzy',bases,q) 152 | u, s, v = torch.svd(m) 153 | vt = torch.transpose(v, 1, 2) 154 | det = torch.det(torch.matmul(u, vt)) 155 | det = det.view(-1, 1, 1) 156 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) 157 | r = torch.matmul(u, vt) 158 | return r,biasq.squeeze() 159 | def Generate_msa_mask(n,l): 160 | # 1: 15% mask out 161 | randommatrix=torch.rand(n,l) 162 | mask = (randommatrix <0.1).float() 163 | # 2 random a segment 164 | seqlength = int(l*0.1) 165 | sindex=round(random.random()*(l-seqlength)) 166 | endindex=min(l,sindex+seqlength) 167 | mask[:,sindex:endindex]=1 168 | return mask 169 | 170 | 171 | -------------------------------------------------------------------------------- /cfg_97/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os,math 3 | import numpy as np 4 | from torch.utils.data import Dataset, DataLoader 5 | from random import sample 6 | from numpy import float32 7 | import random 8 | from scipy.spatial.distance import cdist 9 | from subprocess import Popen, PIPE, STDOUT 10 | expdir=os.path.dirname(os.path.abspath(__file__)) 11 | #lines = open(os.path.join(expdir,'newconfig')).readlines() 12 | #attdrop = lines[0].strip().split()[-1] == '1' 13 | # denoisee2e = lines[1].strip().split()[-1] == '1' 14 | # ss_type = lines[2].strip().split()[-1] 15 | code_standard = { 16 | 'A':'A','G':'G','C':'C','U':'U','a':'A','g':'G','c':'C','u':'U','T':'U','t':'U' 17 | } 18 | expdir=os.path.dirname(os.path.abspath(__file__)) 19 | parentdir = os.path.dirname(expdir) 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | def parse_seq(inseq): 30 | seqnpy=np.zeros(len(inseq)) 31 | seq1=np.array(list(inseq)) 32 | seqnpy[seq1=='A']=1 33 | seqnpy[seq1=='G']=2 34 | seqnpy[seq1=='C']=3 35 | seqnpy[seq1=='U']=4 36 | seqnpy[seq1=='T']=4 37 | return seqnpy 38 | 39 | 40 | def Get_base(seq,basenpy_standard): 41 | basenpy = np.zeros([len(seq),3,3]) 42 | seqnpy = np.array(list(seq)) 43 | basenpy[seqnpy=='A']=basenpy_standard[0] 44 | basenpy[seqnpy=='a']=basenpy_standard[0] 45 | 46 | basenpy[seqnpy=='G']=basenpy_standard[1] 47 | basenpy[seqnpy=='g']=basenpy_standard[1] 48 | 49 | basenpy[seqnpy=='C']=basenpy_standard[2] 50 | basenpy[seqnpy=='c']=basenpy_standard[2] 51 | 52 | basenpy[seqnpy=='U']=basenpy_standard[3] 53 | basenpy[seqnpy=='u']=basenpy_standard[3] 54 | 55 | basenpy[seqnpy=='T']=basenpy_standard[3] 56 | basenpy[seqnpy=='t']=basenpy_standard[3] 57 | return basenpy 58 | -------------------------------------------------------------------------------- /cfg_97/newconfig: -------------------------------------------------------------------------------- 1 | attdrop: 1 2 | denoisee2e: 1 3 | ss_type: attention -------------------------------------------------------------------------------- /cfg_97/test_modeldir.py: -------------------------------------------------------------------------------- 1 | import random 2 | random.seed(0) 3 | import numpy as np 4 | np.random.seed(0) 5 | import os,sys,re,random 6 | from numpy import select 7 | import torch 8 | torch.manual_seed(0) 9 | torch.backends.cudnn.deterministic = True 10 | torch.backends.cudnn.benchmark = False 11 | expdir=os.path.dirname(os.path.abspath(__file__)) 12 | 13 | # from pathlib import Path 14 | # path = Path(expdir) 15 | # parepath = path.parent.absolute() 16 | 17 | 18 | import torch.optim as opt 19 | from torch.nn import functional as F 20 | import data,util 21 | import EvoMSA2XYZ,basic 22 | import math 23 | import pickle 24 | Batch_size=3 25 | Num_cycle=3 26 | TEST_STEP=1000 27 | VISION_STEP=50 28 | device = sys.argv[1] 29 | 30 | 31 | expdir=os.path.dirname(os.path.abspath(__file__)) 32 | expround=expdir.split('_')[-1] 33 | model_path=os.path.join(expdir,'others','models') 34 | 35 | testdir=os.path.join(expdir,'others','preds') 36 | 37 | 38 | 39 | 40 | basenpy_standard= np.load( os.path.join(os.path.dirname(os.path.abspath(__file__)),'base.npy' ) ) 41 | def data_collect(pdb_seq): 42 | aa_type = data.parse_seq(pdb_seq) 43 | base = data.Get_base(pdb_seq,basenpy_standard) 44 | seq_idx = np.arange(len(pdb_seq)) + 1 45 | 46 | msa=aa_type[None,:] 47 | msa=torch.from_numpy(msa).to(device) 48 | msa=torch.cat([msa,msa],0) 49 | msa=F.one_hot(msa.long(),6).float() 50 | 51 | base_x = torch.from_numpy(base).float().to(device) 52 | seq_idx = torch.from_numpy(seq_idx).long().to(device) 53 | return msa,base_x,seq_idx 54 | predxs,plddts = model.pred(msa,seq_idx,ss,base_x,sample_1['alpha_0']) 55 | 56 | 57 | 58 | def classifier(infasta,out_prefix,model_dir): 59 | with torch.no_grad(): 60 | lines = open(infasta).readlines()[1:] 61 | seqs = [aline.strip() for aline in lines] 62 | seq = ''.join(seqs) 63 | msa,base_x,seq_idx = data_collect(seq) 64 | # seq_idx = np.genfromtxt(idxfile).astype(int) 65 | # seq_idx = torch.from_numpy(seq_idx).long().to(device) 66 | 67 | msa_dim=6+1 68 | m_dim,s_dim,z_dim = 64,64,64 69 | N_ensemble,N_cycle=3,8 70 | model=EvoMSA2XYZ.MSA2XYZ(msa_dim-1,msa_dim,N_ensemble,N_cycle,m_dim,s_dim,z_dim) 71 | model.to(device) 72 | model.eval() 73 | models = os.listdir( model_dir ) 74 | models = [amodel for amodel in models if 'model' in amodel and 'opt' not in amodel] 75 | 76 | models.sort() 77 | # models = models[5:] 78 | 79 | for amodel in models: 80 | #saved_model=os.path.join(expdir,'others','models',amodel) 81 | saved_model=os.path.join(model_dir,amodel) 82 | model.load_state_dict(torch.load(saved_model,map_location='cpu'),strict=True) 83 | ret = model.pred(msa,seq_idx,None,base_x,np.array(list(seq))) 84 | 85 | util.outpdb(ret['coor'],seq_idx,seq,out_prefix+f'{amodel}.pdb') 86 | #ret = {'plddt':ret['plddt']} 87 | pickle.dump(ret,open(out_prefix+f'{amodel}.ret','wb')) 88 | # for akey in ret: 89 | # print(akey,ret[akey].shape) 90 | 91 | 92 | 93 | 94 | 95 | 96 | if __name__ == '__main__': 97 | infasta,out_prefix,model_dir = sys.argv[2],sys.argv[3],sys.argv[4] 98 | classifier(infasta,out_prefix,model_dir) -------------------------------------------------------------------------------- /cfg_97/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from subprocess import Popen, PIPE, STDOUT 4 | import os,sys 5 | 6 | 7 | 8 | def outpdb(coor,seq_idx,seq,savefile,start=0,end=10000,energystr=''): 9 | #rama=torch.cat([rama.view(self.L,2),self.betas],dim=-1) 10 | L = coor.shape[0] 11 | 12 | Atom_name=[' P '," C4'",' N1 '] 13 | Other_Atom_name = [" O5'"," C5'"," C3'"," O3'"," C1'"] 14 | other_last_name = ['O',"C","C","O","C"] 15 | 16 | 17 | last_name=['P','C','N'] 18 | wstr=[f'REMARK {str(energystr)}'] 19 | templet='%6s%5d %4s %3s %1s%4d %8.3f%8.3f%8.3f%6.2f%6.2f %2s%2s' 20 | count=1 21 | for i in range(L): 22 | if seq[i] in ['a','g','A','G']: 23 | Atom_name = [' P '," C4'",' N9 '] 24 | #atoms = ['P','C4'] 25 | 26 | elif seq[i] in ['c','u','C','U']: 27 | Atom_name = [' P '," C4'",' N1 '] 28 | for j in range(coor.shape[1]): 29 | outs=('ATOM ',count,Atom_name[j],seq[i],'A',seq_idx[i],coor[i][j][0],coor[i][j][1],coor[i][j][2],0,0,last_name[j],'') 30 | #outs=('ATOM ',count,Atom_name[j],'ALA','A',i+1,coor_np[i][j][0],coor_np[i][j][1],coor_np[i][j][2],1.0,90,last_name[j],'') 31 | #print(outs) 32 | if i>=start-1 and i < end: 33 | wstr.append(templet % outs) 34 | 35 | # for j in range(other_np.shape[1]): 36 | # outs=('ATOM ',count,Other_Atom_name[j],self.seq[i],'A',i+1,other_np[i][j][0],other_np[i][j][1],other_np[i][j][2],0,0,other_last_name[j],'') 37 | # #outs=('ATOM ',count,Atom_name[j],'ALA','A',i+1,coor_np[i][j][0],coor_np[i][j][1],coor_np[i][j][2],1.0,90,last_name[j],'') 38 | # #print(outs) 39 | # if i>=start-1 and i < end: 40 | # wstr.append(templet % outs) 41 | count+=1 42 | wstr.append('TER') 43 | wstr='\n'.join(wstr) 44 | wfile=open(savefile,'w') 45 | wfile.write(wstr) 46 | wfile.close() 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /cfg_99/Evoformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import basic,EvoPair,EvoMSA 5 | import math,sys 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | 9 | class EvoBlock(nn.Module): 10 | def __init__(self,m_dim,z_dim,docheck=False): 11 | super(EvoBlock,self).__init__() 12 | self.msa_row=EvoMSA.MSARow(m_dim,z_dim) 13 | self.msa_col=EvoMSA.MSACol(m_dim) 14 | self.msa_trans=EvoMSA.MSATrans(m_dim) 15 | 16 | self.msa_opm=EvoMSA.MSAOPM(m_dim,z_dim) 17 | 18 | self.pair_triout=EvoPair.TriOut(z_dim) 19 | self.pair_triin =EvoPair.TriIn(z_dim) 20 | self.pair_tristart=EvoPair.TriAttStart(z_dim) 21 | self.pair_triend =EvoPair.TriAttEnd(z_dim) 22 | self.pair_trans = EvoPair.PairTrans(z_dim) 23 | self.docheck=docheck 24 | if docheck: 25 | print('will do checkpoint') 26 | 27 | def layerfunc_msa_row(self,m,z): 28 | return self.msa_row(m,z) + m 29 | def layerfunc_msa_col(self,m): 30 | return self.msa_col(m) + m 31 | def layerfunc_msa_trans(self,m): 32 | return self.msa_trans(m) + m 33 | def layerfunc_msa_opm(self,m,z): 34 | return self.msa_opm(m) + z 35 | 36 | def layerfunc_pair_triout(self,z): 37 | return self.pair_triout(z) + z 38 | def layerfunc_pair_triin(self,z): 39 | return self.pair_triin(z) + z 40 | def layerfunc_pair_tristart(self,z): 41 | return self.pair_tristart(z) + z 42 | def layerfunc_pair_triend(self,z): 43 | return self.pair_triend(z) + z 44 | def layerfunc_pair_trans(self,z): 45 | return self.pair_trans(z) + z 46 | def forward(self,m,z): 47 | if True: 48 | m = m + self.msa_row(m,z) 49 | m = m + self.msa_col(m) 50 | m = m + self.msa_trans(m) 51 | z = z + self.msa_opm(m) 52 | z = z + self.pair_triout(z) 53 | z = z + self.pair_triin(z) 54 | z = z + self.pair_tristart(z) 55 | z = z + self.pair_triend(z) 56 | z = z + self.pair_trans(z) 57 | return m,z 58 | else: 59 | m=checkpoint(self.layerfunc_msa_row,m,z) 60 | m=checkpoint(self.layerfunc_msa_col,m) 61 | m=checkpoint(self.layerfunc_msa_trans,m) 62 | z=checkpoint(self.layerfunc_msa_opm,m,z) 63 | 64 | z=checkpoint(self.layerfunc_pair_triout,z) 65 | z=checkpoint(self.layerfunc_pair_triin,z) 66 | z=checkpoint(self.layerfunc_pair_tristart,z) 67 | z=checkpoint(self.layerfunc_pair_triend,z) 68 | z=checkpoint(self.layerfunc_pair_trans,z) 69 | 70 | return m,z 71 | 72 | 73 | class Evoformer(nn.Module): 74 | def __init__(self,m_dim,z_dim,docheck=False): 75 | super(Evoformer,self).__init__() 76 | self.layers=[16] 77 | self.docheck=docheck 78 | if docheck: 79 | pass 80 | #print('will do checkpoint') 81 | self.evos=nn.ModuleList([EvoBlock(m_dim,z_dim,True) for i in range(self.layers[0])]) 82 | 83 | def layerfunc(self,layermodule,m,z): 84 | m_,z_=layermodule(m,z) 85 | return m_,z_ 86 | 87 | 88 | # def forward(self,m,z): 89 | 90 | # if True: 91 | # #print('will do checkpoint in Evoformer') 92 | # for i in range(self.layers[0]): 93 | # m,z=checkpoint(self.layerfunc,self.evos[i],m,z) 94 | # #m,z=self.evos[i](m,z) 95 | # return m,z 96 | # else: 97 | # for i in range(self.layers[0]): 98 | # m,z=self.evos[i](m,z) 99 | 100 | # return m,z 101 | def forward_n(self,m,z,starti,endi): 102 | for i in range(starti,endi): 103 | #print(i) 104 | m,z=self.evos[i](m,z) 105 | return m,z 106 | def forward(self,m,z): 107 | 108 | # if True: 109 | # #print('will do checkpoint in Evoformer') 110 | # for i in range(self.layers[0]): 111 | # #m,z=checkpoint(self.layerfunc,self.evos[i],m,z) 112 | # m,z=self.evos[i](m,z) 113 | # return m,z 114 | m,z = checkpoint(self.forward_n,m,z,0,3) 115 | m,z = checkpoint(self.forward_n,m,z,3,6) 116 | m,z = checkpoint(self.forward_n,m,z,6,10) 117 | m,z = checkpoint(self.forward_n,m,z,10,13) 118 | m,z = checkpoint(self.forward_n,m,z,13,16) 119 | return m,z 120 | 121 | -------------------------------------------------------------------------------- /cfg_99/IPA.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import basic 6 | import math 7 | import os 8 | expdir=os.path.dirname(os.path.abspath(__file__)) 9 | lines = open(os.path.join(expdir,'newconfig')).readlines() 10 | attdrop = lines[0].strip().split()[-1] == '1' 11 | denoisee2e = lines[1].strip().split()[-1] == '1' 12 | ss_type = lines[2].strip().split()[-1] 13 | class InvariantPointAttention(nn.Module): 14 | def __init__(self,dim_in,dim_z,N_head=8,c=16,N_query=4,N_p_values=6,) -> None: 15 | super(InvariantPointAttention,self).__init__() 16 | self.dim_in=dim_in 17 | self.dim_z=dim_z 18 | self.N_head =N_head 19 | self.c=c 20 | self.c_squ = 1.0/math.sqrt(c) 21 | self.W_c = math.sqrt(2.0/(9*N_query)) 22 | self.W_L = math.sqrt(1.0/3) 23 | self.N_query=N_query 24 | self.N_p_values=N_p_values 25 | self.liner_nb_q1=basic.LinearNoBias(dim_in,self.c*N_head) 26 | self.liner_nb_k1=basic.LinearNoBias(dim_in,self.c*N_head) 27 | self.liner_nb_v1=basic.LinearNoBias(dim_in,self.c*N_head) 28 | 29 | self.liner_nb_q2=basic.LinearNoBias(dim_in,N_head*N_query*3) 30 | self.liner_nb_k2=basic.LinearNoBias(dim_in,N_head*N_query*3) 31 | 32 | self.liner_nb_v3=basic.LinearNoBias(dim_in,N_head*N_p_values*3) 33 | 34 | self.liner_nb_z=basic.LinearNoBias(dim_z,N_head) 35 | self.lastlinear1=basic.Linear(N_head*dim_z,dim_in) 36 | self.lastlinear2=basic.Linear(N_head*c,dim_in) 37 | self.lastlinear3=basic.Linear(N_head*N_p_values*3,dim_in) 38 | self.gama = nn.ParameterList([nn.Parameter(torch.zeros(N_head))]) 39 | self.cos_f=nn.CosineSimilarity(dim=-1) 40 | 41 | def forward(self,s,z,rot,trans): 42 | L=s.shape[0] 43 | q1=self.liner_nb_q1(s).reshape(L,self.N_head,self.c) # Lq, 44 | k1=self.liner_nb_k1(s).reshape(L,self.N_head,self.c) 45 | v1=self.liner_nb_v1(s).reshape(L,self.N_head,self.c) # lv,h,c 46 | 47 | attmap=torch.einsum('ihc,jhc->ijh',q1,k1) * self.c_squ # Lq,Lk_v,h 48 | bias_z=self.liner_nb_z(z) # L L h 49 | 50 | q2 = self.liner_nb_q2(s).reshape(L,self.N_head,self.N_query,3) 51 | k2 = self.liner_nb_k2(s).reshape(L,self.N_head,self.N_query,3) 52 | 53 | v3 = self.liner_nb_v3(s).reshape(L,self.N_head,self.N_p_values,3) 54 | 55 | q2 = basic.IPA_transform(q2,rot,trans) # Lq,self.N_head,self.N_query,3 56 | k2 = basic.IPA_transform(k2,rot,trans) # Lk,self.N_head,self.N_query,3 57 | 58 | dismap=((q2[:,None,:,:,:] - k2[None,:,:,:,:])**2).sum([3,4]) ## Lq,Lk, self.N_head, 59 | #dismap=dismap - (self.cos_f(q2[:,None,:,:,:] , k2[None,:,:,:,:])).sum(3) 60 | attmap = attmap + bias_z - F.softplus(self.gama[0])[None,None,:]*dismap*self.W_c*0.5 61 | #print(torch.max(attmap*self.W_L),torch.min(attmap)*self.W_L) 62 | #attmap = F.softmax( torch.clamp(attmap*self.W_L,-5,5),dim=1 ) # Lk dim, [Lq,Lk, self.N_head] 63 | 64 | attmap = F.softmax( attmap*self.W_L,dim=1 ) # Lk dim, [Lq,Lk, self.N_head] 65 | if attdrop: 66 | if self.training: 67 | attmap = basic.DropAtt(attmap,dim=1) 68 | o1 = (attmap[:,:,:,None] * z[:,:,None,:]).sum(1) # Lq, N_head, c_z 69 | o2 = torch.einsum('abc,dab->dbc',v1,attmap) # Lq, N_head, c 70 | o3 = basic.IPA_transform(v3,rot,trans) # Lv, h, p* ,3 71 | o3 = basic.IPA_inverse_transform( torch.einsum('vhpt,gvh->ghpt',o3,attmap),rot,trans) #Lv, h, p* ,3 72 | 73 | return self.lastlinear1(o1.reshape(L,-1)) + self.lastlinear2(o2.reshape(L,-1)) + self.lastlinear3(o3.reshape(L,-1)) 74 | 75 | 76 | 77 | 78 | 79 | 80 | if __name__ == "__main__": 81 | dim_in,dim_z = 8,4 82 | L = 10 83 | ipa = InvariantPointAttention(dim_in,dim_z) 84 | s=torch.rand(L,dim_in) 85 | z=torch.rand(L,L,dim_z) 86 | rot=(torch.eye(3)[None,:,:]).repeat(L,1,1) 87 | trans=torch.rand(L,3) 88 | 89 | out=ipa(s,z,rot,trans) 90 | print(out) 91 | print(out.shape) 92 | 93 | 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /cfg_99/RNALM2/EvoMSA.py: -------------------------------------------------------------------------------- 1 | from numpy import select 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from RNALM2 import basic 6 | import math 7 | 8 | def SignedSqrt( x): 9 | x = torch.sqrt(torch.relu(x)) - torch.sqrt(torch.relu(-x)) 10 | return x 11 | class MSARow(nn.Module): 12 | def __init__(self,m_dim,z_dim,N_head=8,c=8): 13 | super(MSARow,self).__init__() 14 | self.N_head = N_head 15 | self.c = c 16 | self.sq_c = 1/math.sqrt(c) 17 | self.norm1=nn.LayerNorm(m_dim) 18 | self.qlinear = basic.LinearNoBias(m_dim,N_head*c) 19 | self.klinear = basic.LinearNoBias(m_dim,N_head*c) 20 | self.vlinear = basic.LinearNoBias(m_dim,N_head*c) 21 | self.norm_z = nn.LayerNorm(z_dim) 22 | self.zlinear = basic.LinearNoBias(z_dim,N_head) 23 | self.glinear = basic.Linear(m_dim,N_head*c) 24 | self.olinear = basic.Linear(N_head*c,m_dim) 25 | 26 | def forward(self,m,z): 27 | # m : N L 32 28 | N,L,D = m.shape 29 | m = self.norm1(m) 30 | q = self.qlinear(m).reshape(N,L,self.N_head,self.c) #s rq h c 31 | k = self.klinear(m).reshape(N,L,self.N_head,self.c) #s rv h c 32 | v = self.vlinear(m).reshape(N,L,self.N_head,self.c) 33 | b = self.zlinear(self.norm_z(z)) 34 | g = torch.sigmoid(self.glinear(m)).reshape(N,L,self.N_head,self.c) 35 | att=torch.einsum('bqhc,bvhc->bqvh',q,k) * (self.sq_c) + b[None,:,:,:] # rq rv h 36 | att=F.softmax(SignedSqrt(att),dim=2) 37 | o = torch.einsum('bqvh,bvhc->bqhc',att,v) * g 38 | m_ = self.olinear(o.reshape(N,L,-1)) 39 | return m_ 40 | 41 | class MSACol(nn.Module): 42 | def __init__(self,m_dim,N_head=8,c=8): 43 | super(MSACol,self).__init__() 44 | self.N_head = N_head 45 | self.c = c 46 | self.sq_c = 1/math.sqrt(c) 47 | self.norm1=nn.LayerNorm(m_dim) 48 | self.qlinear = basic.LinearNoBias(m_dim,N_head*c) 49 | self.klinear = basic.LinearNoBias(m_dim,N_head*c) 50 | self.vlinear = basic.LinearNoBias(m_dim,N_head*c) 51 | 52 | self.glinear = basic.Linear(m_dim,N_head*c) 53 | self.olinear = basic.Linear(N_head*c,m_dim) 54 | 55 | def forward(self,m): 56 | # m : N L 32 57 | N,L,D = m.shape 58 | m = self.norm1(m) 59 | q = self.qlinear(m).reshape(N,L,self.N_head,self.c) #s rq h c 60 | k = self.klinear(m).reshape(N,L,self.N_head,self.c) #s rv h c 61 | v = self.vlinear(m).reshape(N,L,self.N_head,self.c) 62 | 63 | g = torch.sigmoid(self.glinear(m)).reshape(N,L,self.N_head,self.c) 64 | 65 | att=torch.einsum('slhc,tlhc->stlh',q,k) * (self.sq_c) # rq rv h 66 | att=F.softmax(SignedSqrt(att),dim=1) 67 | o = torch.einsum('stlh,tlhc->slhc',att,v) * g 68 | m_ = self.olinear(o.reshape(N,L,-1)) 69 | return m_ 70 | 71 | class MSATrans(nn.Module): 72 | def __init__(self,m_dim,c_expand=2): 73 | super(MSATrans,self).__init__() 74 | self.c_expand=4 75 | self.m_dim=m_dim 76 | self.norm=nn.LayerNorm(m_dim) 77 | self.linear1 = basic.Linear(m_dim,m_dim*c_expand) 78 | self.linear2 = basic.Linear(m_dim*c_expand,m_dim) 79 | def forward(self,m): 80 | m = self.norm(m) 81 | m = self.linear1(m) 82 | m = self.linear2(F.relu(m)) 83 | return m 84 | 85 | class MSAOPM(nn.Module): 86 | def __init__(self,m_dim,z_dim,c=12): 87 | super(MSAOPM,self).__init__() 88 | self.m_dim=m_dim 89 | self.c=c 90 | self.norm=nn.LayerNorm(m_dim) 91 | self.linear1=basic.Linear(m_dim,c) 92 | self.linear2=basic.Linear(m_dim,c) 93 | self.linear3=basic.Linear(c*c,z_dim) 94 | def forward(self,m): 95 | N,L,D=m.shape 96 | o=self.norm(m) 97 | a=self.linear2(o) 98 | b=self.linear1(o) 99 | o = torch.einsum('nia,njb->nijab',a,b).mean(dim=0) 100 | o = self.linear3(o.reshape(L,L,-1)) 101 | return o 102 | 103 | 104 | 105 | 106 | 107 | 108 | if __name__ == "__main__": 109 | N=10 110 | L=30 111 | m_dim=16 112 | z_dim=8 113 | m=torch.rand(N,L,m_dim) 114 | z=torch.rand(L,L,z_dim) 115 | msarow=MSARow(m_dim,z_dim) 116 | msacol=MSACol(m_dim) 117 | msatrans=MSATrans(m_dim) 118 | msaopm=MSAOPM(m_dim,z_dim) 119 | y=msaopm(m) 120 | print(y.shape) -------------------------------------------------------------------------------- /cfg_99/RNALM2/EvoPair.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from RNALM2 import basic 5 | import math 6 | 7 | def SignedSqrt( x): 8 | x = torch.sqrt(torch.relu(x)) - torch.sqrt(torch.relu(-x)) 9 | return x 10 | class TriOut(nn.Module): 11 | def __init__(self,z_dim,c=32): 12 | super(TriOut,self).__init__() 13 | self.z_dim = z_dim 14 | self.norm =nn.LayerNorm(z_dim) 15 | self.onorm =nn.LayerNorm(c) 16 | self.alinear=basic.Linear(z_dim,c) 17 | self.blinear=basic.Linear(z_dim,c) 18 | self.aglinear=basic.Linear(z_dim,c) 19 | self.bglinear=basic.Linear(z_dim,c) 20 | self.glinear =basic.Linear(z_dim,z_dim) 21 | self.olinear=basic.Linear(c,z_dim) 22 | 23 | def forward(self,z): 24 | z = self.norm(z) 25 | a = self.alinear(z) * torch.sigmoid(self.aglinear(z)) 26 | b = self.blinear(z) * torch.sigmoid(self.bglinear(z)) 27 | o = torch.einsum('ilc,jlc->ijc',a,b) 28 | o = self.onorm(o) 29 | o = self.olinear(o) 30 | o = o * torch.sigmoid(self.glinear(z)) 31 | return o 32 | 33 | class TriIn(nn.Module): 34 | def __init__(self,z_dim,c=32): 35 | super(TriIn,self).__init__() 36 | self.z_dim = z_dim 37 | self.norm =nn.LayerNorm(z_dim) 38 | self.onorm =nn.LayerNorm(c) 39 | self.alinear=basic.Linear(z_dim,c) 40 | self.blinear=basic.Linear(z_dim,c) 41 | self.aglinear=basic.Linear(z_dim,c) 42 | self.bglinear=basic.Linear(z_dim,c) 43 | self.glinear =basic.Linear(z_dim,z_dim) 44 | self.olinear=basic.Linear(c,z_dim) 45 | 46 | def forward(self,z): 47 | z = self.norm(z) 48 | a = self.alinear(z) * torch.sigmoid(self.aglinear(z)) 49 | b = self.blinear(z) * torch.sigmoid(self.bglinear(z)) 50 | o = torch.einsum('lic,ljc->ijc',a,b) 51 | o = self.onorm(o) 52 | o = self.olinear(o) 53 | o = o * torch.sigmoid(self.glinear(z)) 54 | return o 55 | 56 | 57 | class TriAttStart(nn.Module): 58 | def __init__(self,z_dim,N_head=4,c=8): 59 | super(TriAttStart,self).__init__() 60 | self.z_dim = z_dim 61 | self.N_head = N_head 62 | self.c = c 63 | self.sq_c = 1/math.sqrt(c) 64 | self.norm=nn.LayerNorm(z_dim) 65 | self.qlinear=basic.Linear(z_dim,c*N_head) 66 | self.klinear=basic.Linear(z_dim,c*N_head) 67 | self.vlinear=basic.Linear(z_dim,c*N_head) 68 | self.blinear=basic.Linear(z_dim,N_head) 69 | self.glinear=basic.Linear(z_dim,c*N_head) 70 | self.olinear=basic.Linear(c*N_head,z_dim) 71 | 72 | def forward(self,z_): 73 | L1,L2,D=z_.shape 74 | z = self.norm(z_) 75 | q = self.qlinear(z).reshape(L1,L2,self.N_head,self.c) 76 | k = self.klinear(z).reshape(L1,L2,self.N_head,self.c) 77 | v = self.vlinear(z).reshape(L1,L2,self.N_head,self.c) 78 | b = self.blinear(z) 79 | att = torch.einsum('blhc,bkhc->blkh',q,k)*self.sq_c + b[None,:,:,:] 80 | att = F.softmax(SignedSqrt(att),dim=2) 81 | o = torch.einsum('blkh,bkhc->blhc',att,v) 82 | o = (torch.sigmoid(self.glinear(z).reshape(L1,L2,self.N_head,self.c)) * o).reshape(L1,L2,-1) 83 | o = self.olinear(o) 84 | return o 85 | 86 | class TriAttEnd(nn.Module): 87 | def __init__(self,z_dim,N_head=4,c=8): 88 | super(TriAttEnd,self).__init__() 89 | self.z_dim = z_dim 90 | self.N_head = N_head 91 | self.c = c 92 | self.sq_c = 1/math.sqrt(c) 93 | self.norm=nn.LayerNorm(z_dim) 94 | self.qlinear=basic.Linear(z_dim,c*N_head) 95 | self.klinear=basic.Linear(z_dim,c*N_head) 96 | self.vlinear=basic.Linear(z_dim,c*N_head) 97 | self.blinear=basic.Linear(z_dim,N_head) 98 | self.glinear=basic.Linear(z_dim,c*N_head) 99 | self.olinear=basic.Linear(c*N_head,z_dim) 100 | 101 | def forward(self,z_): 102 | L1,L2,D=z_.shape 103 | z = self.norm(z_) 104 | q = self.qlinear(z).reshape(L1,L2,self.N_head,self.c) 105 | k = self.klinear(z).reshape(L1,L2,self.N_head,self.c) 106 | v = self.vlinear(z).reshape(L1,L2,self.N_head,self.c) 107 | b = self.blinear(z) 108 | att = torch.einsum('blhc,kbhc->blkh',q,k)*self.sq_c + b[None,:,:,:].permute(0,2,1,3) 109 | att = F.softmax(SignedSqrt(att),dim=2) 110 | o = torch.einsum('blkh,klhc->blhc',att,v) 111 | o = (torch.sigmoid(self.glinear(z).reshape(L1,L2,self.N_head,self.c)) * o).reshape(L1,L2,-1) 112 | o = self.olinear(o) 113 | return o 114 | def forward2(self,z_): 115 | z = z_.permute(1,0,2) 116 | L1,L2,D=z_.shape 117 | z = self.norm(z_) 118 | q = self.qlinear(z).reshape(L1,L2,self.N_head,self.c) 119 | k = self.klinear(z).reshape(L1,L2,self.N_head,self.c) 120 | v = self.vlinear(z).reshape(L1,L2,self.N_head,self.c) 121 | b = self.blinear(z) 122 | att = torch.einsum('blhc,bkhc->blkh',q,k)*self.sq_c + b[None,:,:,:] 123 | att = F.softmax(att,dim=2) 124 | o = torch.einsum('blkh,bkhc->blhc',att,v) 125 | o = (torch.sigmoid(self.glinear(z).reshape(L1,L2,self.N_head,self.c)) * o).reshape(L1,L2,-1) 126 | o = self.olinear(o) 127 | o = o.permute(1,0,2) 128 | return o 129 | class PairTrans(nn.Module): 130 | def __init__(self,z_dim,c_expand=2): 131 | super(PairTrans,self).__init__() 132 | self.z_dim=z_dim 133 | self.c_expand=c_expand 134 | self.norm = nn.LayerNorm(z_dim) 135 | self.linear1=basic.Linear(z_dim,z_dim*c_expand) 136 | self.linear2=basic.Linear(z_dim*c_expand,z_dim) 137 | def forward(self,z): 138 | a = self.linear1(self.norm(z)) 139 | a = self.linear2(F.relu(a)) 140 | return a 141 | 142 | 143 | 144 | 145 | if __name__ == "__main__": 146 | N=10 147 | L=30 148 | m_dim=16 149 | z_dim=8 150 | m=torch.rand(N,L,m_dim) 151 | z=torch.rand(L,L,z_dim) 152 | 153 | tr1=TriAttEnd(z_dim) 154 | tr2=PairTrans(z_dim) 155 | y=tr1(z) 156 | y2=tr1.forward2(z) 157 | y3=tr2(z) 158 | print(y3.shape) 159 | 160 | -------------------------------------------------------------------------------- /cfg_99/RNALM2/Evoformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from RNALM2 import basic,EvoPair,EvoMSA 5 | import math,sys 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | 9 | class EvoBlock(nn.Module): 10 | def __init__(self,m_dim,z_dim,docheck=False): 11 | super(EvoBlock,self).__init__() 12 | N_head = 16 13 | c = 16 14 | self.msa_row=EvoMSA.MSARow(m_dim,z_dim,N_head,c) 15 | self.msa_col=EvoMSA.MSACol(m_dim,N_head,c) 16 | self.msa_trans=EvoMSA.MSATrans(m_dim) 17 | 18 | self.msa_opm=EvoMSA.MSAOPM(m_dim,z_dim) 19 | 20 | self.pair_triout=EvoPair.TriOut(z_dim,72) 21 | self.pair_triin =EvoPair.TriIn(z_dim,72) 22 | self.pair_tristart=EvoPair.TriAttStart(z_dim) 23 | self.pair_triend =EvoPair.TriAttEnd(z_dim) 24 | self.pair_trans = EvoPair.PairTrans(z_dim) 25 | self.docheck=docheck 26 | if docheck: 27 | print('will do checkpoint') 28 | 29 | def layerfunc_msa_row(self,m,z): 30 | return self.msa_row(m,z) + m 31 | def layerfunc_msa_col(self,m): 32 | return self.msa_col(m) + m 33 | def layerfunc_msa_trans(self,m): 34 | return self.msa_trans(m) + m 35 | def layerfunc_msa_opm(self,m,z): 36 | return self.msa_opm(m) + z 37 | 38 | def layerfunc_pair_triout(self,z): 39 | return self.pair_triout(z) + z 40 | def layerfunc_pair_triin(self,z): 41 | return self.pair_triin(z) + z 42 | def layerfunc_pair_tristart(self,z): 43 | return self.pair_tristart(z) + z 44 | def layerfunc_pair_triend(self,z): 45 | return self.pair_triend(z) + z 46 | def layerfunc_pair_trans(self,z): 47 | return self.pair_trans(z) + z 48 | def forward(self,m,z): 49 | if True: 50 | m = m + self.msa_row(m,z) 51 | m = m + self.msa_col(m) 52 | m = m + self.msa_trans(m) 53 | z = z + self.msa_opm(m) 54 | z = z + self.pair_triout(z) 55 | z = z + self.pair_triin(z) 56 | #z = z + self.pair_tristart(z) 57 | #z = z + self.pair_triend(z) 58 | z = z + self.pair_trans(z) 59 | return m,z 60 | else: 61 | m=checkpoint(self.layerfunc_msa_row,m,z) 62 | m=checkpoint(self.layerfunc_msa_col,m) 63 | m=checkpoint(self.layerfunc_msa_trans,m) 64 | z=checkpoint(self.layerfunc_msa_opm,m,z) 65 | 66 | z=checkpoint(self.layerfunc_pair_triout,z) 67 | z=checkpoint(self.layerfunc_pair_triin,z) 68 | z=checkpoint(self.layerfunc_pair_tristart,z) 69 | z=checkpoint(self.layerfunc_pair_triend,z) 70 | z=checkpoint(self.layerfunc_pair_trans,z) 71 | 72 | return m,z 73 | 74 | 75 | class Evoformer(nn.Module): 76 | def __init__(self,m_dim,z_dim,N_elayers=12,docheck=False): 77 | super(Evoformer,self).__init__() 78 | self.layers=[N_elayers] 79 | self.docheck=docheck 80 | if docheck: 81 | pass 82 | #print('will do checkpoint') 83 | self.evos=nn.ModuleList([EvoBlock(m_dim,z_dim,True) for i in range(self.layers[0])]) 84 | 85 | def layerfunc(self,layermodule,m,z): 86 | m_,z_=layermodule(m,z) 87 | return m_,z_ 88 | 89 | 90 | def forward(self,m,z): 91 | 92 | if True: 93 | #print('will do checkpoint in Evoformer') 94 | for i in range(self.layers[0]): 95 | #m,z=checkpoint(self.layerfunc,self.evos[i],m,z) 96 | m,z=self.evos[i](m,z) 97 | return m,z 98 | else: 99 | for i in range(self.layers[0]): 100 | m,z=self.evos[i](m,z) 101 | 102 | return m,z 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | if __name__ == "__main__": 113 | N=10 114 | L=30 115 | m_dim=16 116 | z_dim=8 117 | m=torch.rand(N,L,m_dim) 118 | z=torch.rand(L,L,z_dim) 119 | model = Evoformer(m_dim,z_dim) 120 | m,z=model(m,z) 121 | print(model.parameters()) 122 | for param in model.parameters(): 123 | print(type(param), param.size()) 124 | print(m.shape,z.shape) -------------------------------------------------------------------------------- /cfg_99/RNALM2/Model.py: -------------------------------------------------------------------------------- 1 | from numpy import select 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from RNALM2 import basic,Evoformer 6 | import math,sys 7 | from torch.utils.checkpoint import checkpoint 8 | import numpy as np 9 | 10 | 11 | def one_d(idx_, d, max_len=2056*8): 12 | idx = idx_[None] 13 | K = torch.arange(d//2).to(idx.device) 14 | sin_e = torch.sin(idx[..., None] * math.pi / (max_len**(2*K[None]/d))).to(idx.device) 15 | cos_e = torch.cos(idx[..., None] * math.pi / (max_len**(2*K[None]/d))).to(idx.device) 16 | return torch.cat([sin_e, cos_e], axis=-1)[0] 17 | 18 | 19 | 20 | 21 | 22 | 23 | class RNAembedding(nn.Module): 24 | def __init__(self,cfg): 25 | super(RNAembedding,self).__init__() 26 | self.s_in_dim=cfg['s_in_dim'] 27 | self.z_in_dim=cfg['z_in_dim'] 28 | self.s_dim=cfg['s_dim'] 29 | self.z_dim=cfg['z_dim'] 30 | self.qlinear =basic.Linear(self.s_in_dim+1,self.z_dim) 31 | self.klinear =basic.Linear(self.s_in_dim+1,self.z_dim) 32 | self.slinear =basic.Linear(self.s_in_dim+1,self.s_dim) 33 | self.zlinear =basic.Linear(self.z_in_dim+1,self.z_dim) 34 | 35 | self.poslinears = basic.Linear(64,self.s_dim) 36 | self.poslinearz = basic.Linear(64,self.z_dim) 37 | def forward(self,in_dict): 38 | # msa N L D, seq L D 39 | # mask: maksing, L, 1 means masked 40 | # aa: L x s_in_dim 41 | # ss: L x L x 2 42 | # idx: L (LongTensor) 43 | L = in_dict['aa'].shape[0] 44 | aamask = in_dict['mask'][:,None] 45 | zmask = in_dict['mask'][:,None] + in_dict['mask'][None,:] 46 | zmask[zmask>0.5]=1 47 | zmask = zmask[...,None] 48 | s = torch.cat([aamask,(1-aamask)*in_dict['aa']],dim=-1) 49 | sq=self.qlinear(s) 50 | sk=self.klinear(s) 51 | z=sq[None,:,:]+sk[:,None,:] 52 | seq_idx = in_dict['idx'][None] 53 | relative_pos = seq_idx[:, :, None] - seq_idx[:, None, :] 54 | relative_pos = relative_pos.reshape([1, -1]) 55 | relative_pos =one_d(relative_pos,64) 56 | z = z + self.poslinearz( relative_pos.reshape([1, L, L, -1])[0] ) 57 | 58 | s = self.slinear(s) + self.poslinears( one_d(in_dict['idx'], 64) ) 59 | 60 | return s,z 61 | 62 | 63 | class RNA2nd(nn.Module): 64 | def __init__(self,cfg): 65 | super(RNA2nd,self).__init__() 66 | self.s_in_dim=cfg['s_in_dim'] 67 | self.z_in_dim=cfg['z_in_dim'] 68 | self.s_dim=cfg['s_dim'] 69 | self.z_dim=cfg['z_dim'] 70 | self.N_elayers =cfg['N_elayers'] 71 | self.emb = RNAembedding(cfg) 72 | self.evmodel=Evoformer.Evoformer(self.s_dim,self.z_dim,self.N_elayers) 73 | self.seq_head = basic.Linear(self.s_dim,self.s_in_dim) 74 | self.joint_head = basic.Linear(self.z_dim,self.s_in_dim*self.s_in_dim) 75 | 76 | 77 | 78 | 79 | def embedding(self,in_dict): 80 | s,z = self.emb(in_dict) 81 | s,z = self.evmodel(s[None,...],z) 82 | return s[0],z 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /cfg_99/RNALM2/basic.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import random 5 | class Linear(nn.Module): 6 | def __init__(self,dim_in,dim_out): 7 | super(Linear,self).__init__() 8 | self.linear = nn.Linear(dim_in,dim_out) 9 | def forward(self,x): 10 | x = self.linear(x) 11 | return x 12 | 13 | 14 | class LinearNoBias(nn.Module): 15 | def __init__(self,dim_in,dim_out): 16 | super(LinearNoBias,self).__init__() 17 | self.linear = nn.Linear(dim_in,dim_out,bias=False) 18 | def forward(self,x): 19 | x = self.linear(x) 20 | return x 21 | 22 | 23 | 24 | def transform(k,rotation,translation): 25 | # K L x 3 26 | # rotation 27 | return torch.matmul(k,rotation) + translation 28 | 29 | 30 | def batch_transform(k,rotation,translation): 31 | # k: L 3 32 | # rotation: L 3 x 3 33 | # translation: L 3 34 | return torch.einsum('ba,bad->bd',k,rotation) + translation 35 | 36 | def batch_atom_transform(k,rotation,translation): 37 | # k: L N 3 38 | # rotation: L 3 x 3 39 | # translation: L 3 40 | return torch.einsum('bja,bad->bjd',k,rotation) + translation[:,None,:] 41 | 42 | def IPA_transform(k,rotation,translation): 43 | # k: L d1, d2, 3 44 | # rotation: L 3 x 3 45 | # translation: L 3 46 | return torch.einsum('bija,bad->bijd',k,rotation)+translation[:,None,None,:] 47 | 48 | def IPA_inverse_transform(k,rotation,translation): 49 | # k: L d1, d2, 3 50 | # rotation: L 3 x 3 51 | # translation: L 3 52 | return torch.einsum('bija,bad->bijd',k-translation[:,None,None,:],rotation.transpose(-1,-2)) 53 | 54 | def update_transform(t,tr,rotation,translation): 55 | return torch.einsum('bja,bad->bjd',t,rotation),torch.einsum('ba,bad->bd',tr,rotation) +translation 56 | 57 | 58 | def quat2rot(q,L): 59 | scale= ((q**2).sum(dim=-1,keepdim=True) +1) [:,:,None] 60 | u=torch.empty([L,3,3],device=q.device) 61 | u[:,0,0]=1*1+q[:,0]*q[:,0]-q[:,1]*q[:,1]-q[:,2]*q[:,2] 62 | u[:,0,1]=2*(q[:,0]*q[:,1]-1*q[:,2]) 63 | u[:,0,2]=2*(q[:,0]*q[:,2]+1*q[:,1]) 64 | u[:,1,0]=2*(q[:,0]*q[:,1]+1*q[:,2]) 65 | u[:,1,1]=1*1-q[:,0]*q[:,0]+q[:,1]*q[:,1]-q[:,2]*q[:,2] 66 | u[:,1,2]=2*(q[:,1]*q[:,2]-1*q[:,0]) 67 | u[:,2,0]=2*(q[:,0]*q[:,2]-1*q[:,1]) 68 | u[:,2,1]=2*(q[:,1]*q[:,2]+1*q[:,0]) 69 | u[:,2,2]=1*1-q[:,0]*q[:,0]-q[:,1]*q[:,1]+q[:,2]*q[:,2] 70 | return u/scale 71 | 72 | 73 | def rotation_x(sintheta,costheta,ones,zeros): 74 | # L x 1 75 | return torch.stack([torch.stack([ones, zeros, zeros]), 76 | torch.stack([zeros, costheta, sintheta]), 77 | torch.stack([zeros, -sintheta, costheta])]) 78 | def rotation_y(sintheta,costheta,ones,zeros): 79 | # L x 1 80 | return torch.stack([torch.stack([costheta, zeros, sintheta]), 81 | torch.stack([zeros, ones, zeros]), 82 | torch.stack([-sintheta, zeros, costheta])]) 83 | def rotation_z(sintheta,costheta,ones,zeros): 84 | # L x 1 85 | return torch.stack([torch.stack([costheta, sintheta, zeros]), 86 | torch.stack([-sintheta, costheta, zeros]), 87 | torch.stack([zeros, zeros, ones])]) 88 | def batch_rotation(k,rotation): 89 | # k: L 3 90 | # rotation: L 3 x 3 91 | # translation: L 3 92 | return torch.einsum('ba,bad->bd',k,rotation) 93 | 94 | def compute_cb(bl,sin_angle,cos_angle,sin_torsion,cos_torsion): 95 | L=bl.shape[0] 96 | ones=torch.ones(L,device=bl.device) 97 | zeros=torch.zeros(L,device=bl.device) 98 | cb=torch.stack([bl,zeros,zeros]).permute(1,0) 99 | rotz=rotation_z(sin_angle,cos_angle,ones,zeros).permute(2,0,1) 100 | rotx=rotation_x(sin_torsion,cos_torsion,ones,zeros).permute(2,0,1) 101 | cb=batch_rotation(cb,rotz) 102 | cb=batch_rotation(cb,rotx) 103 | return cb 104 | 105 | def rigidFrom3Points_(x1,x2,x3): 106 | v1=x3-x2 107 | v2=x1-x2 108 | e1=v1/(torch.norm(v1,dim=-1,keepdim=True) + 1e-03) 109 | u2=v2 - e1*(torch.einsum('bn,bn->b',e1,v2)[:,None]) 110 | e2 = u2/(torch.norm(u2,dim=-1,keepdim=True) + 1e-03) 111 | e3=torch.cross(e1,e2,dim=-1) 112 | 113 | return torch.stack([e1,e2,e3],dim=1),x2[:,:] 114 | def rigidFrom3Points(x1,x2,x3):# L 3 115 | the_dim=1 116 | x = torch.stack([x1,x2,x3],dim=the_dim) 117 | x_mean = torch.mean(x,dim=the_dim,keepdim=True) 118 | x = x - x_mean 119 | 120 | 121 | m = x.view(-1, 3, 3) 122 | u, s, v = torch.svd(m) 123 | vt = torch.transpose(v, 1, 2) 124 | det = torch.det(torch.matmul(u, vt)) 125 | det = det.view(-1, 1, 1) 126 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) 127 | r = torch.matmul(u, vt) 128 | return r,x_mean.squeeze() 129 | def Kabsch_rigid(bases,x1,x2,x3): 130 | ''' 131 | return the direction from to_q to from_p 132 | ''' 133 | the_dim=1 134 | to_q = torch.stack([x1,x2,x3],dim=the_dim) 135 | biasq=torch.mean(to_q,dim=the_dim,keepdim=True) 136 | q=to_q-biasq 137 | m = torch.einsum('bnz,bny->bzy',bases,q) 138 | u, s, v = torch.svd(m) 139 | vt = torch.transpose(v, 1, 2) 140 | det = torch.det(torch.matmul(u, vt)) 141 | det = det.view(-1, 1, 1) 142 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) 143 | r = torch.matmul(u, vt) 144 | return r,biasq.squeeze() 145 | def Generate_msa_mask(n,l): 146 | # 1: 15% mask out 147 | randommatrix=torch.rand(n,l) 148 | mask = (randommatrix <0.1).float() 149 | # 2 random a segment 150 | seqlength = int(l*0.1) 151 | sindex=round(random.random()*(l-seqlength)) 152 | endindex=min(l,sindex+seqlength) 153 | mask[:,sindex:endindex]=1 154 | return mask 155 | 156 | 157 | -------------------------------------------------------------------------------- /cfg_99/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeyang/DRfold2/54129443ed151032e3d647a279814c5cebdca630/cfg_99/__init__.py -------------------------------------------------------------------------------- /cfg_99/base.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leeyang/DRfold2/54129443ed151032e3d647a279814c5cebdca630/cfg_99/base.npy -------------------------------------------------------------------------------- /cfg_99/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os,math 3 | import numpy as np 4 | from torch.utils.data import Dataset, DataLoader 5 | from random import sample 6 | from numpy import float32 7 | import random 8 | from scipy.spatial.distance import cdist 9 | from subprocess import Popen, PIPE, STDOUT 10 | expdir=os.path.dirname(os.path.abspath(__file__)) 11 | #lines = open(os.path.join(expdir,'newconfig')).readlines() 12 | #attdrop = lines[0].strip().split()[-1] == '1' 13 | # denoisee2e = lines[1].strip().split()[-1] == '1' 14 | # ss_type = lines[2].strip().split()[-1] 15 | code_standard = { 16 | 'A':'A','G':'G','C':'C','U':'U','a':'A','g':'G','c':'C','u':'U','T':'U','t':'U' 17 | } 18 | expdir=os.path.dirname(os.path.abspath(__file__)) 19 | parentdir = os.path.dirname(expdir) 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | def parse_seq(inseq): 30 | seqnpy=np.zeros(len(inseq)) 31 | seq1=np.array(list(inseq)) 32 | seqnpy[seq1=='A']=1 33 | seqnpy[seq1=='G']=2 34 | seqnpy[seq1=='C']=3 35 | seqnpy[seq1=='U']=4 36 | seqnpy[seq1=='T']=4 37 | return seqnpy 38 | 39 | 40 | def Get_base(seq,basenpy_standard): 41 | basenpy = np.zeros([len(seq),3,3]) 42 | seqnpy = np.array(list(seq)) 43 | basenpy[seqnpy=='A']=basenpy_standard[0] 44 | basenpy[seqnpy=='a']=basenpy_standard[0] 45 | 46 | basenpy[seqnpy=='G']=basenpy_standard[1] 47 | basenpy[seqnpy=='g']=basenpy_standard[1] 48 | 49 | basenpy[seqnpy=='C']=basenpy_standard[2] 50 | basenpy[seqnpy=='c']=basenpy_standard[2] 51 | 52 | basenpy[seqnpy=='U']=basenpy_standard[3] 53 | basenpy[seqnpy=='u']=basenpy_standard[3] 54 | 55 | basenpy[seqnpy=='T']=basenpy_standard[3] 56 | basenpy[seqnpy=='t']=basenpy_standard[3] 57 | return basenpy 58 | 59 | 60 | -------------------------------------------------------------------------------- /cfg_99/newconfig: -------------------------------------------------------------------------------- 1 | attdrop: 1 2 | denoisee2e: 1 3 | ss_type: attention -------------------------------------------------------------------------------- /cfg_99/test_modeldir.py: -------------------------------------------------------------------------------- 1 | import random 2 | random.seed(0) 3 | import numpy as np 4 | np.random.seed(0) 5 | import os,sys,re,random 6 | from numpy import select 7 | import torch 8 | torch.manual_seed(0) 9 | torch.backends.cudnn.deterministic = True 10 | torch.backends.cudnn.benchmark = False 11 | expdir=os.path.dirname(os.path.abspath(__file__)) 12 | 13 | # from pathlib import Path 14 | # path = Path(expdir) 15 | # parepath = path.parent.absolute() 16 | 17 | 18 | import torch.optim as opt 19 | from torch.nn import functional as F 20 | import data,util 21 | import EvoMSA2XYZ,basic 22 | import math 23 | import pickle 24 | Batch_size=3 25 | Num_cycle=3 26 | TEST_STEP=1000 27 | VISION_STEP=50 28 | device = sys.argv[1] 29 | 30 | 31 | expdir=os.path.dirname(os.path.abspath(__file__)) 32 | expround=expdir.split('_')[-1] 33 | model_path=os.path.join(expdir,'others','models') 34 | # if not os.path.isdir(model_path): 35 | # try: 36 | # os.makedirs(model_path) 37 | # except: 38 | # pass 39 | testdir=os.path.join(expdir,'others','preds') 40 | # if not os.path.isdir(testdir): 41 | # try: 42 | # os.makedirs(testdir) 43 | # except: 44 | # pass 45 | 46 | 47 | 48 | basenpy_standard= np.load( os.path.join(os.path.dirname(os.path.abspath(__file__)),'base.npy' ) ) 49 | def data_collect(pdb_seq): 50 | aa_type = data.parse_seq(pdb_seq) 51 | base = data.Get_base(pdb_seq,basenpy_standard) 52 | seq_idx = np.arange(len(pdb_seq)) + 1 53 | 54 | msa=aa_type[None,:] 55 | msa=torch.from_numpy(msa).to(device) 56 | msa=torch.cat([msa,msa],0) 57 | msa=F.one_hot(msa.long(),6).float() 58 | 59 | base_x = torch.from_numpy(base).float().to(device) 60 | seq_idx = torch.from_numpy(seq_idx).long().to(device) 61 | return msa,base_x,seq_idx 62 | predxs,plddts = model.pred(msa,seq_idx,ss,base_x,sample_1['alpha_0']) 63 | 64 | 65 | 66 | def classifier(infasta,out_prefix,model_dir): 67 | with torch.no_grad(): 68 | lines = open(infasta).readlines()[1:] 69 | seqs = [aline.strip() for aline in lines] 70 | seq = ''.join(seqs) 71 | msa,base_x,seq_idx = data_collect(seq) 72 | # seq_idx = np.genfromtxt(idxfile).astype(int) 73 | # seq_idx = torch.from_numpy(seq_idx).long().to(device) 74 | 75 | msa_dim=6+1 76 | m_dim,s_dim,z_dim = 64,64,64 77 | N_ensemble,N_cycle=3,8 78 | model=EvoMSA2XYZ.MSA2XYZ(msa_dim-1,msa_dim,N_ensemble,N_cycle,m_dim,s_dim,z_dim) 79 | model.to(device) 80 | model.eval() 81 | models = os.listdir( model_dir ) 82 | models = [amodel for amodel in models if 'model' in amodel and 'opt' not in amodel] 83 | 84 | models.sort() 85 | 86 | for amodel in models: 87 | saved_model=os.path.join(model_dir,amodel) 88 | model.load_state_dict(torch.load(saved_model,map_location='cpu'),strict=False) 89 | ret = model.pred(msa,seq_idx,None,base_x,np.array(list(seq))) 90 | 91 | util.outpdb(ret['coor'],seq_idx,seq,out_prefix+f'{amodel}.pdb') 92 | #ret = {'plddt':ret['plddt']} 93 | pickle.dump(ret,open(out_prefix+f'{amodel}.ret','wb')) 94 | # for akey in ret: 95 | # print(akey,ret[akey].shape) 96 | 97 | 98 | 99 | 100 | 101 | 102 | if __name__ == '__main__': 103 | infasta,out_prefix,model_dir = sys.argv[2],sys.argv[3],sys.argv[4] 104 | classifier(infasta,out_prefix,model_dir) -------------------------------------------------------------------------------- /cfg_99/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from subprocess import Popen, PIPE, STDOUT 4 | import os,sys 5 | 6 | 7 | 8 | def outpdb(coor,seq_idx,seq,savefile,start=0,end=10000,energystr=''): 9 | #rama=torch.cat([rama.view(self.L,2),self.betas],dim=-1) 10 | L = coor.shape[0] 11 | 12 | Atom_name=[' P '," C4'",' N1 '] 13 | Other_Atom_name = [" O5'"," C5'"," C3'"," O3'"," C1'"] 14 | other_last_name = ['O',"C","C","O","C"] 15 | 16 | 17 | last_name=['P','C','N'] 18 | wstr=[f'REMARK {str(energystr)}'] 19 | templet='%6s%5d %4s %3s %1s%4d %8.3f%8.3f%8.3f%6.2f%6.2f %2s%2s' 20 | count=1 21 | for i in range(L): 22 | if seq[i] in ['a','g','A','G']: 23 | Atom_name = [' P '," C4'",' N9 '] 24 | #atoms = ['P','C4'] 25 | 26 | elif seq[i] in ['c','u','C','U']: 27 | Atom_name = [' P '," C4'",' N1 '] 28 | for j in range(coor.shape[1]): 29 | outs=('ATOM ',count,Atom_name[j],seq[i],'A',seq_idx[i],coor[i][j][0],coor[i][j][1],coor[i][j][2],0,0,last_name[j],'') 30 | #outs=('ATOM ',count,Atom_name[j],'ALA','A',i+1,coor_np[i][j][0],coor_np[i][j][1],coor_np[i][j][2],1.0,90,last_name[j],'') 31 | #print(outs) 32 | if i>=start-1 and i < end: 33 | wstr.append(templet % outs) 34 | 35 | # for j in range(other_np.shape[1]): 36 | # outs=('ATOM ',count,Other_Atom_name[j],self.seq[i],'A',i+1,other_np[i][j][0],other_np[i][j][1],other_np[i][j][2],0,0,other_last_name[j],'') 37 | # #outs=('ATOM ',count,Atom_name[j],'ALA','A',i+1,coor_np[i][j][0],coor_np[i][j][1],coor_np[i][j][2],1.0,90,last_name[j],'') 38 | # #print(outs) 39 | # if i>=start-1 and i < end: 40 | # wstr.append(templet % outs) 41 | count+=1 42 | wstr.append('TER') 43 | wstr='\n'.join(wstr) 44 | wfile=open(savefile,'w') 45 | wfile.write(wstr) 46 | wfile.close() 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /cfg_for_folding.json: -------------------------------------------------------------------------------- 1 | { 2 | "weight_pp": 1, 3 | "weight_cc": 1, 4 | "weight_nn": 1, 5 | "weight_pccp": 0, 6 | "weight_cnnc": 0, 7 | "weight_pnnp": 0, 8 | "weight_pcc": 0, 9 | "weight_cnn": 0, 10 | "weight_pnn": 0, 11 | "weight_vdw": 1, 12 | "weight_nn_contact": 0, 13 | "weight_cc_contact": 0, 14 | "weight_beta": 0, 15 | "weight_fape": 2, 16 | "weight_bond": 5000, 17 | "pair_weight_power": 0.25, 18 | "pair_weight_min": 0.2, 19 | "pair_error_power": 3, 20 | "pair_rest_min_dist": 2, 21 | "FAPE_max": 30, 22 | "geo_scale": 450, 23 | "num_of_models": 5 24 | } -------------------------------------------------------------------------------- /cfg_for_selection.json: -------------------------------------------------------------------------------- 1 | { 2 | "weight_pp": 1, 3 | "weight_cc": 1, 4 | "weight_nn": 1, 5 | "weight_pccp": 0, 6 | "weight_cnnc": 0, 7 | "weight_pnnp": 0, 8 | "weight_pcc": 0, 9 | "weight_cnn": 0, 10 | "weight_pnn": 0, 11 | "weight_vdw": 1, 12 | "weight_nn_contact": 0, 13 | "weight_cc_contact": 0, 14 | "weight_beta": 0, 15 | "weight_fape": 1, 16 | "weight_bond": 1000, 17 | "pair_weight_power": 0.5, 18 | "pair_weight_min": 0.3, 19 | "pair_error_power": 3.5, 20 | "pair_rest_min_dist": 2, 21 | "FAPE_max": 30, 22 | "geo_scale": 450, 23 | "num_of_models": 5 24 | } -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | wget https://zhanglab.comp.nus.edu.sg/DRfold2/res/model_hub.tar.gz 2 | tar -xzvf model_hub.tar.gz 3 | git clone https://github.com/pylelab/Arena.git 4 | cd Arena 5 | make Arena 6 | -------------------------------------------------------------------------------- /script/refine.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | 3 | from simtk.openmm.app import * 4 | from simtk.openmm import * 5 | from simtk.unit import * 6 | import simtk.openmm as mm 7 | import subprocess 8 | 9 | def woutpdb( infile, outfile): 10 | lines = open(infile).readlines() 11 | nums=[] 12 | for aline in lines: 13 | if aline.startswith('ATOM'): 14 | num = aline.split()[5] 15 | nums.append(num) 16 | nums.sort() 17 | olines=[] 18 | for aline in lines: 19 | if aline.startswith('ATOM'): 20 | atom = aline.split()[2] 21 | num = int(aline.split()[5]) 22 | 23 | theline = list(aline) 24 | if num == nums[0]: 25 | theline[18:20] = theline[19:20] + ['5'] 26 | elif num == nums[-1]: 27 | theline[18:20] = theline[19:20] + ['3'] 28 | if not (("P" in atom and num == 1) or ("H" in atom ) ): 29 | olines.append(''.join(theline)) 30 | wfile = open(outfile,'w') 31 | wfile.write(''.join(olines)) 32 | wfile.close() 33 | 34 | def woutpdb2( infile, outfile): 35 | lines = open(infile).readlines() 36 | olines = [] 37 | for aline in lines: 38 | if aline.startswith('ATOM') and 'H' not in aline: 39 | olines.append(aline) 40 | wfile = open(outfile,'w') 41 | wfile.write(''.join(olines)) 42 | wfile.close() 43 | 44 | 45 | 46 | def opt(inpdb,outpdb,steps): 47 | 48 | # https://openmm.org/ 49 | pdb = PDBFile(inpdb) 50 | modeller = Modeller(pdb.topology, pdb.positions) 51 | forcefield = ForceField('amber14-all.xml', 'amber14/tip3pfb.xml') 52 | modeller.addHydrogens(forcefield) 53 | modeller.addSolvent(forcefield, padding=1 * nanometer) 54 | system = forcefield.createSystem(modeller.topology, nonbondedMethod=NoCutoff, nonbondedCutoff=1 * nanometer, 55 | constraints=HBonds) 56 | if False: 57 | #restraint = CustomExternalForce('k*periodicdistance(x, y, z, x0, y0, z0)^2') 58 | restraint = CustomExternalForce('k*((x-x0)^2+(y-y0)^2+(z-z0)^2)') 59 | system.addForce(restraint) 60 | restraint.addGlobalParameter('k', 100.0*kilojoules_per_mole/nanometer) 61 | restraint.addPerParticleParameter('x0') 62 | restraint.addPerParticleParameter('y0') 63 | restraint.addPerParticleParameter('z0') 64 | for atom in pdb.topology.atoms(): 65 | # print(atom.name) 66 | if atom.name == 'P': 67 | # print('added') 68 | restraint.addParticle(atom.index, pdb.positions[atom.index]) 69 | integrator = LangevinIntegrator(300 * kelvin, 1 / picosecond, 0.002 * picoseconds) 70 | simulation = Simulation(modeller.topology, system, integrator) 71 | simulation.context.setPositions(modeller.positions) 72 | simulation.reporters.append(StateDataReporter(sys.stdout, 1000, step=True, potentialEnergy=True, temperature=True)) 73 | simulation.minimizeEnergy(maxIterations=steps) 74 | position = simulation.context.getState(getPositions=True).getPositions() 75 | app.PDBFile.writeFile(simulation.topology, position, open(outpdb, 'w')) 76 | 77 | 78 | 79 | inpdb = sys.argv[1] 80 | outpdb = sys.argv[2] 81 | count = len(open(inpdb).readlines()) 82 | steps=(int(count/20.0)) * 1 83 | if len(sys.argv) == 4: 84 | steps=int(float(sys.argv[3]) * steps) 85 | steps=max(steps,10) 86 | steps=min(steps,1000) 87 | print('steps:',steps) 88 | woutpdb( inpdb, outpdb+'amber_tmp.pdb') 89 | opt(outpdb+'amber_tmp.pdb',outpdb+'amber_tmp2.pdb',steps) 90 | woutpdb2( outpdb+'amber_tmp2.pdb', outpdb) 91 | try: 92 | os.remove(outpdb+'amber_tmp.pdb') 93 | os.remove(outpdb+'amber_tmp2.pdb') 94 | except: 95 | pass -------------------------------------------------------------------------------- /test/seq.fasta: -------------------------------------------------------------------------------- 1 | >test 2 | UUGGGUUCCCUCACCCCAAUCAUAAAAAGG --------------------------------------------------------------------------------