├── README.md
├── coal.py
├── lik.py
├── palm.py
├── palm_utils
│   ├── deterministic_utils.py
│   └── tree_utils.py
└── snp_example
    ├── example.haps
    ├── example.quad_fit.npy
    ├── genetic_map.txt
    └── relate.haps

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PALM
Polygenic Adaptation Likelihood Method (PALM) / Joint PALM (J-PALM)

===================================

***See Wiki for documentation:***

https://github.com/35ajstern/palm/wiki

===================================

--------------------------------------------------------------------------------
/coal.py:
--------------------------------------------------------------------------------
import argparse
from argparse import ArgumentParser
import warnings
import numpy as np
import glob
from scipy.stats import chi2
import scipy.stats as stats
from scipy.special import logsumexp
from scipy.optimize import minimize
import progressbar
import sys
from numba import njit

from Bio import Phylo
from io import StringIO
import palm_utils.tree_utils as tree_utils

def locus_parse_coal_times(args):
    bedFile = args.treesFile
    derivedAllele = args.derivedAllele
    posn = args.posn
    sitesFile = args.sitesFile
    outFile = args.outFile
    timeScale = args.timeScale
    burnin = args.burnin
    thin = args.thin
    debug = args.debug

    if not args.sites:
        indLists = tree_utils._derived_carriers_from_haps(sitesFile,
                                                          posn,
                                                          args.offset,
                                                          relate=args.relate)
    else:
        indLists = tree_utils._derived_carriers_from_sites(sitesFile,
                                                           posn,
                                                           relate=args.relate,
                                                           derivedAllele=args.derivedAllele)
    derInds = indLists[0]
    ancInds = indLists[1]
    ancHap = indLists[2]

    n = len(derInds)
    m = len(ancInds)

    f = open(bedFile,'r')
    lines = f.readlines()
    lines = [line for line in lines if line[0] != '#' and line[0] != 'R' and line[0] != 'N'][burnin::thin]

    numImportanceSamples = len(lines)

    derTimesList = []
    ancTimesList = []

    #if debug:
    #    print('Parsing trees...',file=sys.stderr)
    #    bar = progressbar.ProgressBar(maxval=numImportanceSamples, \
    #          widgets=[progressbar.Bar('=', '[', ']'), ' ', progressbar.Percentage()])
    #    bar.start()
    for (k,line) in enumerate(lines):
        nwk = line.rstrip().split()[-1]
        derTree = Phylo.read(StringIO(nwk),'newick')
        ancTree = Phylo.read(StringIO(nwk),'newick')
        mixTree = Phylo.read(StringIO(nwk),'newick')

        derTimes,ancTimes,mixTimes = tree_utils._get_times_all_classes(derTree,ancTree,mixTree,
                                         derInds,ancInds,ancHap,n,m,sitesFile,
                                         timeScale=timeScale, prune=args.prune)
        derTimesList.append(derTimes)
        ancTimesList.append(ancTimes)

        #if args.debug:
        #    bar.update(k+1)

    #if args.debug:
    #    bar.finish()
    times = -1 * np.ones((2,n+m,numImportanceSamples))
    for k in range(numImportanceSamples):
        times[0,:len(derTimesList[k]),k] = np.array(derTimesList[k])
        times[1,:len(ancTimesList[k]),k] = np.array(ancTimesList[k])
    return times
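# The array returned above is padded with -1 so loci with different numbers of
# coalescence events fit one rectangular array. A minimal sketch of recovering
# the valid entries for importance sample k (assumes `times` as returned by
# locus_parse_coal_times):
#
#   derTimes_k = times[0, times[0, :, k] >= 0, k]  # derived-class coal times
#   ancTimes_k = times[1, times[1, :, k] >= 0, k]  # ancestral-class coal times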
def _args_passed_to_locus(args):
    locusDir = args.locusDir
    passed_args = args

    # reach into args and add additional attributes
    d = vars(passed_args)
    d['treesFile'] = locusDir+args.locusTreeFile
    d['sitesFile'] = locusDir+args.locusSitesFile

    d['popFreq'] = 0.50

    d['posn'] = args.posn
    d['derivedAllele'] = args.derivedAllele
    return passed_args


def _args(super_parser,main=False):
    if not main:
        parser = super_parser.add_parser('snp_extract',description=
                 'Parse/extract coalescence times in the derived & ancestral classes.')
    else:
        parser = super_parser
    # mandatory inputs:
    required = parser.add_argument_group('required arguments')
    required.add_argument('--locusDir',type=str)
    required.add_argument('--posn',type=int)
    required.add_argument('--derivedAllele',type=str)
    # options:
    parser.add_argument('-q','--quiet',action='store_true')
    parser.add_argument('-o','--output',dest='outFile',type=str,default=None)
    parser.add_argument('-debug','--debug',action='store_true')

    parser.add_argument('--locusTreeFile',type=str,default='mssel.tree')
    parser.add_argument('--locusSitesFile',type=str,default='relate.haps')
    parser.add_argument('--locusOutPrefix',type=str,default='mssel',help='prefix for outfiles (.ld, .der.npy, .anc.npy)')

    parser.add_argument('-timeScale','--timeScale',type=float,default=1,
                        help='Multiply the coal times in bedFile by this factor to get in terms of generations; e.g. use this on trees in units of 4N gens (--timeScale <4*N>)')
    parser.add_argument('--relate',action='store_true')
    parser.add_argument('--sites',action='store_true')
    parser.add_argument('-thin','--thin',type=int,default=1)
    parser.add_argument('-burnin','--burnin',type=int,default=0)
    parser.add_argument('--sep',type=str,default='\t')
    parser.add_argument('--offset',type=int,default=0)
    parser.add_argument('--prune',type=str,default=None)
    return parser

def freq(genoMat):
    n = genoMat.shape[1]
    return np.sum(genoMat,axis=1)/n

def r2(genoMat,posnFocal,posns,freqs):
    # note: despite the name, this returns the signed correlation r (not r^2),
    # matching the 'r' column header written by _write_ld_file
    ifiltfocal = list(posns).index(posnFocal)
    genoMatFilt = genoMat[:,:]
    l = genoMatFilt.shape[0]
    r2vec = np.zeros(l)
    n = genoMatFilt.shape[1]
    rowa = genoMatFilt[ifiltfocal,:]
    for j,rowb in enumerate(genoMatFilt):
        pab = (rowa & rowb).sum()/n
        pa = rowa.sum()/n
        pb = rowb.sum()/n
        #print(pab,pa,pb)
        r2el = ((pab - pa*pb)/np.sqrt(pa*(1-pa)*pb*(1-pb)))
        r2vec[j] = r2el
    return np.array(r2vec)


def _parse_haps_file(haps,focalPosn):
    genoMat = []
    posns = []

    for line in open(haps,'r'):
        if line[0] == 'N' or line[0] == 'R':
            continue

        cols = line.rstrip().split(' ')
        posn = int(cols[2])
        if posn == focalPosn:
            iFocal = len(posns)
        alleles = ''.join(cols[5:])
        ancAllele = '0'
        derAllele = '1'
        if alleles == ancAllele*len(alleles) or alleles == derAllele*len(alleles):
            continue
        genoMat.append([0 if char == ancAllele else 1 for char in alleles])
        posns.append(posn)
    genoMat = np.array(genoMat)

    freqs = freq(genoMat)
    freqs = np.array(freqs)
    posns = np.array(posns)
    r2vector = r2(genoMat,focalPosn,posns,freqs)
    return posns,freqs,r2vector
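# Usage sketch (focal position hypothetical): parse a SHAPEIT/Relate-style
# .haps file and compute each site's frequency and signed r with the focal
# SNP, as consumed by _write_ld_file below:
#
#   posns, freqs, rvec = _parse_haps_file('snp_example/example.haps', 25003)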
def _write_ld_file(args,posns,freqs,r2vector,focalPosn,focalFreq,locusDir):
    out = open(locusDir+args.locusOutPrefix+'.ld','w')
    out.write('#posn\tfreq\tr\n')
    out.write('##%d\t%f\n'%(focalPosn,focalFreq))
    for (p,f,r) in zip(posns,freqs,r2vector):
        out.write('%d\t%f\t%f\n'%(p,f,r))
    out.close()
    return


def _write_times_files(args,locusTimes):
    locusDir = args.locusDir
    i0 = np.argmax(locusTimes[0,:,0] < 0.0)
    i1 = np.argmax(locusTimes[1,:,0] < 0.0)
    a1 = locusTimes[0,:i0,:]
    a2 = locusTimes[1,:i1,:]
    a1 = a1.transpose()
    a2 = a2.transpose()
    np.save(locusDir+args.locusOutPrefix+'.der.npy',a1)
    np.save(locusDir+args.locusOutPrefix+'.anc.npy',a2)
    return

def _parse_locus_stats(args):
    passed_args = _args_passed_to_locus(args)
    locusTimes = locus_parse_coal_times(passed_args)
    _write_times_files(args,locusTimes)
    return

def _main(args):
    _parse_locus_stats(args)

if __name__ == '__main__':
    super_parser = argparse.ArgumentParser()
    parser = _args(super_parser,main=True)
    args = parser.parse_args()
    _main(args)

--------------------------------------------------------------------------------
/lik.py:
--------------------------------------------------------------------------------
import argparse
import pandas as pd
import warnings
import numpy as np
import glob
from scipy.stats import chi2
import scipy.stats as stats
import progressbar
import sys
from numba import njit

from palm_utils.deterministic_utils import log_l_importance_sampler
import gzip

def parse_clues(filename):
    with gzip.open(filename, 'rb') as fp:
        try:
            # parse file
            data = fp.read()
        except OSError:
            with open(filename, 'rb') as fp:
                try:
                    # parse file
                    data = fp.read()
                except OSError:
                    print('Error: Unable to open ' + filename)
                    exit(1)

    # get #mutations and #sampled trees per mutation
    filepos = 0
    num_muts, num_sampled_trees_per_mut = np.frombuffer(data[slice(filepos, filepos+8, 1)], dtype = np.int32)
    #print(num_muts, num_sampled_trees_per_mut)

    filepos += 8
    # iterate over mutations
    for m in range(0,num_muts):
        bp = np.frombuffer(data[slice(filepos, filepos+4, 1)], dtype = np.int32)
        filepos += 4
        anc, der = np.frombuffer(data[slice(filepos, filepos+2, 1)], dtype = 'c')
        filepos += 2
        daf, n = np.frombuffer(data[slice(filepos, filepos+8, 1)], dtype = np.int32)
        filepos += 8
        #print("BP: %d, anc: %s, der %s, DAF: %d, n: %d" % (bp, str(anc), str(der), daf, n))

        num_anctimes = 4*(n-daf-1)*num_sampled_trees_per_mut
        anctimes = np.reshape(np.frombuffer(data[slice(filepos, filepos+num_anctimes, 1)], dtype = np.float32), (num_sampled_trees_per_mut, n-daf-1))
        filepos += num_anctimes
        #print(anctimes)

        num_dertimes = 4*(daf-1)*num_sampled_trees_per_mut
        dertimes = np.reshape(np.frombuffer(data[slice(filepos, filepos+num_dertimes, 1)], dtype = np.float32), (num_sampled_trees_per_mut, daf-1))
        filepos += num_dertimes

    # returns the times of the last (typically the only) mutation in the file
    return dertimes,anctimes
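# For reference, the .timeb binary layout that parse_clues walks through:
#   int32 num_muts; int32 num_sampled_trees_per_mut;
#   then per mutation: int32 bp; char anc; char der; int32 daf; int32 n;
#   float32 anctimes[num_sampled_trees_per_mut][n-daf-1];
#   float32 dertimes[num_sampled_trees_per_mut][daf-1]
# (the 4* factors above are the float32 byte widths).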
def _args(super_parser,main=False):
    if not main:
        parser = super_parser.add_parser('lik',description=
                 'Locus selection likelihoods.')
    else:
        parser = super_parser
    # mandatory inputs:
    required = parser.add_argument_group('required arguments')
    required.add_argument('--times',type=str)
    # options:
    parser.add_argument('--popFreq',type=float,default=None)
    parser.add_argument('-q','--quiet',action='store_true')

    parser.add_argument('--locusAncientCounts',type=str,default=None)
    parser.add_argument('--out',type=str,default=None)
    # advanced options
    parser.add_argument('-N','--N',type=float,default=10**4)
    parser.add_argument('-coal','--coal',type=str,default=None,help='path to Relate .coal file. Negates --N option.')

    parser.add_argument('-w','--w',type=float,default=0.01)
    parser.add_argument('--sMax',type=float,default=0.1)
    parser.add_argument('-thin','--thin',type=int,default=1)
    parser.add_argument('-burnin','--burnin',type=int,default=0)
    parser.add_argument('--tCutoff',type=float,default=50000)
    parser.add_argument('--linspace',nargs=2,type=int,default=(50,1))
    parser.add_argument('--K',type=int,default=1,help='which epoch (bwd in time) selection started (e.g. K=1 & kappa=1 means selection started + ended in present day)')
    parser.add_argument('--kappa',type=int,default=1,help='# of epochs during which selection occurred, counting back from K')
    parser.add_argument('--timeScale',type=float,default=1.0)
    return parser

def _parse_locus_stats(args):
    locusDerTimes,locusAncTimes = parse_clues(args.times+'.timeb')

    if locusDerTimes.ndim == 0 or locusAncTimes.ndim == 0:
        raise ValueError
    elif locusAncTimes.ndim == 1 and locusDerTimes.ndim == 1:
        M = 1
        locusDerTimes = np.transpose(np.array([locusDerTimes]))
        locusAncTimes = np.transpose(np.array([locusAncTimes]))
    elif locusAncTimes.ndim == 2 and locusDerTimes.ndim == 1:
        locusDerTimes = np.array([locusDerTimes])[:,::args.thin]
        locusAncTimes = np.transpose(locusAncTimes)[:,::args.thin]
        M = locusDerTimes.shape[1]
    elif locusAncTimes.ndim == 1 and locusDerTimes.ndim == 2:
        locusAncTimes = np.array([locusAncTimes])[:,::args.thin]
        locusDerTimes = np.transpose(locusDerTimes)[:,::args.thin]
        M = locusDerTimes.shape[1]
    else:
        locusDerTimes = np.transpose(locusDerTimes)[:,::args.thin]
        locusAncTimes = np.transpose(locusAncTimes)[:,::args.thin]
        M = locusDerTimes.shape[1]
    n = locusDerTimes.shape[0] + 1
    m = locusAncTimes.shape[0] + 1
    ntot = n + m
    row0 = -1.0 * np.ones((ntot,M))
    row0[:locusDerTimes.shape[0],:] = locusDerTimes
    row1 = -1.0 * np.ones((ntot,M))
    row1[:locusAncTimes.shape[0],:] = locusAncTimes
    locusTimes = np.array([row0,row1]) * args.timeScale

    if args.popFreq == None:
        popFreq = n/ntot
    else:
        popFreq = args.popFreq
    return locusTimes,n,m,popFreq
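# Shape convention built above: locusTimes is (2, n+m, M), row 0 holding the
# derived-class and row 1 the ancestral-class coalescence times, padded with
# -1.0; M is the number of importance samples. When --popFreq is omitted the
# present-day derived frequency is estimated as n/(n+m) from the tree itself.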
def _print_sel_coeff_matrix(omega,args,epochs,se):
    print('\t'.join(['%d-%d'%(epochs[i],epochs[i+1]) for i in range(len(epochs[:-1]))]))
    O = omega.shape[0]
    if True:
        sig = np.zeros((omega.shape[0],3))
        for level in range(3):
            c = stats.norm.ppf(1-0.05/(2*O)*10**-level)
            sig[:,level] = np.logical_not((omega - c*se <= 0) & (omega + c*se >= 0))
        print('\t'.join(['%.3f%s'%(omega[i],'*'*int(np.sum(sig[i,:]))) for i in range(omega.shape[0])]))

    return

def _optimize_locus_likelihood(statistics,args):
    if args.coal != None:
        epochs = np.genfromtxt(args.coal,skip_header=1,skip_footer=1)
        N = 0.5/np.genfromtxt(args.coal,skip_header=2)[2:-1]
        N = np.array(list(N)+[N[-1]])
        K = args.K + args.kappa - 1
    else:
        epochs = np.linspace(0,args.linspace[0],args.linspace[1]+1)
        N = args.N*np.ones(len(epochs))
        K = len(epochs)-1
    if not args.quiet:
        print('Demographic model with diploid Ne:')
        print(N)
    icutoff = np.digitize(args.tCutoff,epochs)
    N = N[:icutoff]
    epochs = epochs[:icutoff]

    times,n,m,x0 = statistics
    tmp = np.swapaxes(times, 0, 2)
    times = tmp
    I = len(epochs)-1
    if not args.quiet:
        print('Analyzing selection over %d time periods...'%(K))
        print('# importance samples: %d'%(times.shape[0]))
        print('Optimizing likelihood surface...')

    ns = np.array([n,m])
    logL0 = 0.0
    theta = np.zeros(len(epochs))
    logL0 = log_l_importance_sampler(times,ns,epochs,theta,x0,N,tCutoff=args.tCutoff)

    S = np.linspace(-args.sMax,args.sMax,200)
    L = np.zeros(len(S))
    for i,s in enumerate(S):
        theta[0:args.kappa] = s
        logL1 = log_l_importance_sampler(times,ns,epochs,theta,x0,N,tCutoff=args.tCutoff)
        L[i] = logL1 - logL0
        if args.out == None:
            print(s,logL1 - logL0)
    I = np.abs(S-S[np.argmax(L)]) < args.w
    p = np.polyfit(S[I],L[I],deg=2)
    if args.out != None:
        np.save(args.out+'.quad_fit.npy',p)

    return

def _main(args):
    statistics = _parse_locus_stats(args)
    _optimize_locus_likelihood(statistics,args)

if __name__ == '__main__':
    super_parser = argparse.ArgumentParser()
    parser = _args(super_parser,main=True)
    args = parser.parse_args()
    _main(args)

--------------------------------------------------------------------------------
/palm.py:
--------------------------------------------------------------------------------
import sys
import argparse
import pandas as pd
import warnings
import numpy as np
import glob
import scipy.stats as stats
import progressbar
from numba import njit

def _args(super_parser,main=False):
    if not main:
        parser = super_parser.add_parser('trait',description=
                 'Trait selection tests.')
    else:
        parser = super_parser
    # mandatory inputs:
    required = parser.add_argument_group('required arguments')
    required.add_argument('--traitDir',type=str,help='A directory containing only directories, each representing a causal locus')
    required.add_argument('--metadata',type=str,help='A dataframe holding attributes for each SNP')
    parser.add_argument('--traits',type=str,help='Traits to analyze, separated by commas; only specify if metadata has betas indexed by trait(s) (e.g., joint analysis of traits)',default='NULL')
    # options:
    parser.add_argument('-q','--quiet',action='store_true')
    parser.add_argument('-o','--output',dest='outFile',type=str,default=None)

    parser.add_argument('--quad',type=str,default=None,help='prefix for the quadratic likelihood fits from lik.py')
    parser.add_argument('--out',type=str,default=None)
    parser.add_argument('--B',type=int,default=250)
    # advanced options
    parser.add_argument('--maxp',type=float,default=1)
    parser.add_argument('--seed',default=None,type=int)
    return parser
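# Expected --metadata layout (a sketch inferred from _parse_loci_stats below;
# values hypothetical). The first two columns form the (ld_block, variant)
# MultiIndex, variant IDs are chrom:bp:ref:alt, and per-trait columns follow
# the beta@<trait> / se@<trait> / pval@<trait> naming convention:
#
#   ld_block  variant       derived_allele  beta@height  se@height  pval@height
#   ld_1      1:25003:G:A   A               0.031        0.004      1.2e-14
#   ld_2      1:830012:T:C  T               -0.012       0.005      8.9e-3
#
# For a single-trait analysis (--traits unset) the columns are just
# 'beta', 'se', and 'pval'.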
def _parse_loci_stats(args):
    if args.seed != None:
        np.random.seed(args.seed)

    coeffs = []
    betas = []
    mults = []
    pvals = []
    df = pd.read_csv(args.metadata,sep='\t',index_col=(0,1),header=0)

    if args.traits == 'NULL':
        betaColumns = ['beta']
        pColumns = ['pval']
        seColumns = ['se']
        if not args.quiet:
            print()
            print('Analyzing trait...')
        traitNames = ['']
    else:
        pStr = 'pval@'
        seStr = 'se@'
        betaStr = 'beta@'
        betaColumns = [col for col in df.columns if np.any([betaStr+trait in col for trait in (args.traits).split(',')])]
        pColumns = [col for col in df.columns if np.any([pStr+trait in col for trait in (args.traits).split(',')])]
        seColumns = [col for col in df.columns if np.any([seStr+trait in col for trait in (args.traits).split(',')])]
        traitNames = [col[len(betaStr):] for col in betaColumns]
        if not args.quiet:
            print()
            print('Analyzing traits: %s'%(', '.join(traitNames)))

    dfFiltered = df
    idxs = dfFiltered.index.values
    K = len(idxs)
    if not args.quiet:
        print('Loading likelihoods...',file=sys.stderr)
        bar = progressbar.ProgressBar(maxval=K, \
              widgets=[progressbar.Bar('=', '[', ']'), ' ', progressbar.Percentage()])
        bar.start()

    for (k,idx) in enumerate(idxs):
        dfRow = df.loc[idx]
        mult = len(df.loc[idx[0]].index)
        variant = idx[1]
        cols = variant.split(':')
        chrom = int(cols[0])
        bp = int(cols[1])
        ref = cols[2]
        alt = cols[3]

        ld_block = int(str(idx[0]).lstrip('ld_'))

        locusDir = args.traitDir+'ld_%d/'%(ld_block)

        derived_allele = dfRow.derived_allele

        if derived_allele != alt:
            flipper = -1.0
        else:
            flipper = 1.0

        loc_betas = list(flipper * np.array(dfRow[betaColumns]))
        loc_ses = list(np.array(dfRow[seColumns],dtype=float))
        loc_pvals = np.array(dfRow[pColumns],dtype=float)
        if np.size(loc_ses) == 0:
            loc_ses = list(np.zeros(len(betaColumns)))
        if np.size(loc_pvals) == 0:
            loc_pvals = list(np.zeros(len(betaColumns)))
        if args.maxp < 1:
            if np.any(np.isnan(loc_betas)):
                continue
            if np.logical_not(np.any(stats.chi2.sf((np.array(loc_betas)/np.array(loc_ses))**2,df=1) < args.maxp)):
                continue

        try:
            if args.quad != None:
                coeff = np.load(locusDir + args.quad + '.npy')
            else:
                coeff = np.load(locusDir + 'bp%s.quad_fit.npy'%(bp))
        except:
            continue
        coeffs.append(coeff)
        betas.append(loc_betas)
        mults.append(1/mult)
        pvals.append(np.array(loc_pvals))
        if not args.quiet:
            bar.update(k)
    if not args.quiet:
        bar.finish()
    betas = np.array(betas)
    mults = np.array(mults)
    coeffs = np.array(coeffs)
    pvals = np.array(pvals)
    return coeffs,betas,mults,pvals,traitNames

def _print_omega(omega,ses,traitNames,marg=None,T=None):
    if marg is None and T is None:
        print('Trait\t\t\tSel\t(SE)\t\tZ')
    else:
        print('Trait\t\t\tSel\t(SE)\t\tZ\tZmarg\tR')
    print('='*90)
    for j,trait in enumerate(traitNames):
        if len(trait) < 16:
            traitFmt = trait+' '*(16-len(trait))
        else:
            traitFmt = trait[:16]
        if marg is None and T is None:
            print('%s\t%.3f\t(%.4f)\t%.3f'%(traitFmt,omega[j],ses[j],omega[j]/ses[j]))
        else:
            print('%s\t%.3f\t(%.4f)\t%.3f\t%.3f\t%.3f'%(traitFmt,omega[j],ses[j],omega[j]/ses[j],marg[j],T[j]))
    print('='*90)
    return

def _out(omega,ses,L_byTrait,L,out,marg=None,T=None):
    np.save(out+'.est.npy',omega)
    np.save(out+'.se.npy',ses)
    np.save(out+'.L.npy',np.concatenate(([L],L_byTrait)))
    if marg is not None:
        np.save(out+'.margZ.npy',marg)
    if T is not None:
        np.save(out+'.T.npy',T)
    return

def _bootstrap(stats):
    coeffs,betas,mults,pvals,traitNames = stats
    I = np.random.choice(coeffs.shape[0],coeffs.shape[0],replace=True)
    return coeffs[I,:],betas[I,:],mults[I],pvals[I,:],traitNames

def _opt_omega(stats):
    coeffs,betas,mults,pvals,traitNames = stats
    J = betas.shape[1]
    L = betas.shape[0]

    A = np.zeros((J,J))
    b = np.zeros(J)

    for l in range(L):
        A += 2 * mults[l] * coeffs[l,0] * np.outer(betas[l,:],betas[l,:])
        b += -mults[l] * coeffs[l,1] * betas[l,:]
    Ainv = np.linalg.inv(A)
    omega = np.dot(Ainv,b)
    return omega
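# A sketch of what _opt_omega solves: lik.py's quadratic fit gives each locus l
# a log-likelihood logL_l(s) ~ a_l*s^2 + b_l*s with s = beta_l . omega under
# the model (a_l = coeffs[l,0], b_l = coeffs[l,1]). Setting the gradient of
#   sum_l mult_l * (a_l*(beta_l.omega)^2 + b_l*(beta_l.omega))
# to zero gives the linear system A*omega = b with
#   A = sum_l 2*mult_l*a_l*outer(beta_l,beta_l),  b = -sum_l mult_l*b_l*beta_l,
# which is the Ainv.dot(b) computed above.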
def _nloci(stats,args):
    coeffs,betas,mults,pvals,traitNames = stats
    L = pvals.shape[0]
    J = pvals.shape[1]
    L_byTrait = np.sum(pvals < args.maxp,axis=0)
    return L_byTrait,L

def _inference(statistics,args):
    omega = _opt_omega(statistics)
    L = statistics[0].shape[0]
    J = len(omega)
    L_byTrait,L = _nloci(statistics,args)
    print('Analyzing %d loci...'%(L))
    B = args.B

    omegaJK = np.zeros((J,B))
    for b in range(B):
        statsDK = _bootstrap(statistics)
        omegaJK_b = _opt_omega(statsDK)
        omegaJK[:,b] = omegaJK_b
    ses = np.std(omegaJK,axis=1)

    return omega,ses

def _T_inference(statistics,args):
    coeffs,betas,mults,pvals,traitNames = statistics
    L_byTrait,L = _nloci(statistics,args)
    print('Analyzing %d loci...'%(L))
    omega = _opt_omega(statistics)
    J = betas.shape[1]
    margOmega = np.zeros(J)
    for j in range(J):
        Lj = L_byTrait[j]
        msig = pvals[:,j] < args.maxp
        mcoeffs = coeffs[msig,:]
        mbetas = np.reshape(betas[msig,j],(Lj,1))
        mmults = mults[msig]
        mpvals = np.reshape(pvals[msig,j],(Lj,1))
        mtraitNames = [traitNames[j]]
        mstats = mcoeffs,mbetas,mmults,mpvals,mtraitNames
        momega = _opt_omega(mstats)
        margOmega[j] = momega

    ## se estimation
    B = args.B
    omegaJK = np.zeros((J,B))
    margOmegaJK = np.zeros((J,B))
    for b in range(B):
        statsDK = _bootstrap(statistics)
        omegaJK_b = _opt_omega(statsDK)
        omegaJK[:,b] = omegaJK_b

        coeffs,betas,mults,pvals,traitNames = statsDK
        L_byTrait,L = _nloci(statsDK,args)

        for j in range(J):
            Lj = L_byTrait[j]
            msig = pvals[:,j] < args.maxp
            mcoeffs = coeffs[msig,:]
            mbetas = np.reshape(betas[msig,j],(Lj,1))
            mmults = mults[msig]
            mpvals = np.reshape(pvals[msig,j],(Lj,1))
            mtraitNames = [traitNames[j]]
            mstats = mcoeffs,mbetas,mmults,mpvals,mtraitNames
            momega = _opt_omega(mstats)
            margOmegaJK[j,b] = momega
    ses = np.std(omegaJK,axis=1)
    Dses = np.std(omegaJK.transpose()/np.std(omegaJK,axis=1)-margOmegaJK.transpose()/np.std(margOmegaJK,axis=1),axis=0).transpose()
    Mses = np.std(margOmegaJK,axis=1)
    D = omega/ses - margOmega/Mses
    return omega,ses,margOmega,Mses,D,Dses
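# Reading the joint-test output assembled below: for each trait j,
# Z = omega[j]/ses[j] tests selection on trait j conditional on the other
# traits, Zmarg = margOmega[j]/Mses[j] is the corresponding single-trait
# (marginal) statistic, and the R column reports T = D/Dses, the standardized
# difference between the joint and marginal standardized estimates -- an
# interpretive sketch based on the column names printed by _print_omega.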
def _main(args):
    statistics = _parse_loci_stats(args)
    L_byTrait,L = _nloci(statistics,args)
    coeffs,betas,mults,pvals,traitNames = statistics
    J = betas.shape[1]

    if J > 1:
        # run test jointly
        omega,ses,margOmega,Mses,D,Dses = _T_inference(statistics,args)
        Ts = D/Dses
        margZs = margOmega/Mses
    else:
        omega,ses = _inference(statistics,args)

    if args.out != None:
        if J > 1:
            _out(omega,ses,L_byTrait,L,args.out,marg=margZs,T=Ts)
        else:
            _out(omega,ses,L_byTrait,L,args.out)
    if not args.quiet:
        if J > 1:
            _print_omega(omega,ses,traitNames,marg=margZs,T=Ts)
        else:
            _print_omega(omega,ses,traitNames)
        print()
    return

if __name__ == '__main__':
    super_parser = argparse.ArgumentParser()
    parser = _args(super_parser,main=True)
    args = parser.parse_args()
    _main(args)

--------------------------------------------------------------------------------
/palm_utils/deterministic_utils.py:
--------------------------------------------------------------------------------
from numba import njit
import numpy as np
import scipy.stats as stats
from scipy.optimize import minimize

EPS = 1e-7
EPS2 = -1

@njit('float64(float64[:])',cache=True)
def _logsumexp(a):
    a_max = np.max(a)

    tmp = np.exp(a - a_max)

    s = np.sum(tmp)
    out = np.log(s)

    out += a_max
    return out

@njit('float64(float64[:],float64[:])',cache=True)
def _logsumexpb(a,b):

    a_max = np.max(a)

    tmp = b * np.exp(a - a_max)

    s = np.sum(tmp)
    out = np.log(s)

    out += a_max
    return out

@njit('float64[:](float64,float64[:],float64[:],float64[:],int64)',cache=True)
def _frequency_memos(x0,epochs,selEpochPtr,N,anc=0):
    # nudge s away from exactly 0 so the logistic closed form below stays finite
    bit = EPS
    if anc:
        selEpoch = -np.array(list(selEpochPtr)) - bit
        x0 = 1-x0
    else:
        selEpoch = np.array(list(selEpochPtr)) + bit

    freqs = np.zeros(len(epochs))
    freqs[0] = x0
    for ie in range(1,len(epochs)):
        t0 = epochs[ie-1]
        t1 = epochs[ie]
        s = selEpoch[ie-1]
        x0 = x0 * (x0 + (1-x0) * np.exp(s * (t1-t0)))**(-1)
        freqs[ie] = x0
    return freqs

@njit('float64(float64,float64[:],float64[:],float64[:],float64[:],int64)',cache=True)
def _freq_using_memos(t,epochs,selEpochPtr,freqMemosPtr,N,anc=0):
    iEpoch = int(np.digitize(np.array([t]),epochs)[0]-1)
    bit = EPS
    if anc:
        selEpoch = -np.array(list(selEpochPtr)) - bit
        freqMemos = 1-freqMemosPtr
    else:
        selEpoch = np.array(list(selEpochPtr)) + bit
        freqMemos = freqMemosPtr

    s = selEpoch[iEpoch]
    t1 = epochs[iEpoch]
    x0 = freqMemos[iEpoch]
    f = x0*(x0 + (1-x0)*np.exp(s*(t-t1)))**-1
    return f


@njit('float64[:](float64[:],float64[:],float64[:],float64[:],int64)',cache=True)
def _coal_intensity_memos(epochs,selEpochPtr,freqMemosPtr,N,anc=0):
    '''
    returns the intensity function evaluated at each epoch (for
    faster likelihood and gradient calculations)
    '''
    # s is shifted by EPS and so never exactly 0; with EPS2 = -1 the s==0
    # branch below is effectively disabled
    bit = EPS
    if anc:
        selEpoch = -np.array(list(selEpochPtr)) - bit
        freqMemos = 1-freqMemosPtr
    else:
        selEpoch = np.array(list(selEpochPtr)) + bit
        freqMemos = freqMemosPtr

    Lambda = np.zeros(len(epochs))
    N0 = N[0]
    for ie in range(1,len(epochs)):
        x0 = freqMemos[ie-1]
        t0 = epochs[ie-1]
        t1 = epochs[ie]
        s = selEpoch[ie-1]
        if np.abs(s) > EPS2:
            Lambda[ie] = (t1-t0) + (1-x0)/(x0*s)*(np.exp(s*(t1-t0)) - 1)
        else:
            Lambda[ie] = 1/x0 * (t1-t0)
        Lambda[ie] *= N0/N[ie-1]
        Lambda[ie] += Lambda[ie-1]
    return Lambda
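# Closed forms memoized above (restating the code, not a new derivation):
# backward in time under a constant selection coefficient s, the derived
# frequency decays logistically,
#     x(t) = x0 / (x0 + (1-x0)*exp(s*(t-t0))),
# and the within-class pair-coalescence intensity accumulates as
#     Lambda(t1) - Lambda(t0) = (N0/N)*((t1-t0) + (1-x0)/(x0*s)*(exp(s*(t1-t0)) - 1)),
# i.e. the integral of N0/(N*x(t)) over [t0, t1].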
@njit('float64(float64,float64[:],float64[:],float64[:],float64[:],float64[:],int64)',cache=True)
def _coal_intensity_using_memos(t,epochs,selEpochPtr,freqMemosPtr,intensityMemos,N,anc=0):
    iEpoch = int(np.digitize(np.array([t]),epochs)[0]-1)
    bit = EPS
    if anc:
        selEpoch = -np.array(list(selEpochPtr)) - bit
        freqMemos = 1-freqMemosPtr
    else:
        selEpoch = np.array(list(selEpochPtr)) + bit
        freqMemos = freqMemosPtr


    s = selEpoch[iEpoch]
    t1 = epochs[iEpoch]
    x0 = freqMemos[iEpoch]
    N0 = N[0]
    Lambda = intensityMemos[iEpoch]
    if np.abs(s) > EPS2:
        Lambda += N0/N[iEpoch] * ((t-t1) + (1-x0)/(x0*s)*(np.exp(s*(t-t1)) - 1))
    else:
        Lambda += N0/N[iEpoch] * 1/x0 * (t-t1)
    return Lambda

@njit('float64(float64[:],int64,float64[:],float64[:],float64,float64[:],int64,float64)',cache=True)
def _log_coal_density(times,n,epochs,selEpoch,x0,N,anc=0,tCutoff=5000.0):
    logp = 0
    prevt = 0
    prevLambda = 0
    times = times[times < tCutoff]
    times = times[times >= 0]
    mySelEpoch = selEpoch
    N0 = N[0]
    # memoize frequency and intensity
    myFreqMemos = _frequency_memos(x0,epochs,mySelEpoch,N,anc=0)
    myIntensityMemos = _coal_intensity_memos(epochs,mySelEpoch,myFreqMemos,N,anc=anc)
    for i,t in enumerate(times):
        k = n-i
        kchoose2 = k*(k-1)/(4*N0)
        Lambda = _coal_intensity_using_memos(t,epochs,mySelEpoch,myFreqMemos,myIntensityMemos,N,anc=anc)
        logpk = -np.log(_freq_using_memos(t,epochs,mySelEpoch,myFreqMemos,N,anc=anc)) \
                - kchoose2 * ( Lambda - prevLambda)
        logp += logpk

        prevt = t
        prevLambda = Lambda
    ## now add the probability of lineages not coalescing by tCutoff
    k -= 1
    kchoose2 = k*(k-1)/(4*N0)
    logPk = - kchoose2 * (_coal_intensity_using_memos(tCutoff,epochs,mySelEpoch,myFreqMemos,myIntensityMemos,N,anc=anc) \
            - prevLambda)

    logp += logPk
    return logp

@njit('float64(float64[:,:,:],int64[:],float64[:],float64[:],float64,float64[:],float64)',cache=True)
def log_l_importance_sampler(times,ns,epochs,selEpoch,x0,N,tCutoff=5000.0):
    M = times.shape[0]
    logls = np.zeros(M)
    neuSelEpoch = np.zeros(len(selEpoch))
    tCutoffNeu = 2000000
    for i in range(M):
        times_m = times[i,:,:]
        n = ns[0]
        m = ns[1]
        derTimes = times_m[:,0]
        ancTimes = times_m[:,1]
        val = _log_coal_density(derTimes,n,epochs,selEpoch,x0,N,anc=0,tCutoff=tCutoff)
        val += _log_coal_density(ancTimes,m,epochs,selEpoch,x0,N,anc=1,tCutoff=tCutoff)

        neuTimes = np.sort(np.concatenate((derTimes,ancTimes)))
        val -= _log_coal_density(neuTimes,n+m,epochs,neuSelEpoch,1.0-EPS,N,anc=0,tCutoff=tCutoffNeu)

        logls[i] = val
    logl = _logsumexp(logls) - np.log(M)

    if np.isnan(logl):
        logl = -np.inf
    return logl
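# In formula form, the sampler above computes
#     log L(s) = logsumexp_i( log p_s(times_i) - log p_0(times_i) ) - log M
# over M importance samples of coalescence times: p_s factors into the
# derived- and ancestral-class coalescent densities under selection s, and
# p_0 is the neutral density of the pooled times (the proposal distribution).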
--------------------------------------------------------------------------------
/palm_utils/tree_utils.py:
--------------------------------------------------------------------------------
from Bio import Phylo
from io import StringIO
import numpy as np

def _coal_times(clades):
    # get number of leaf nodes
    # assumes standard (non-mult-merger) coalescent tree (!)
    [left,right] = clades
    lbl = float(left.branch_length)
    rbl = float(right.branch_length)

    #print lbl, rbl
    if len(left.clades) == 0 and len(right.clades) == 0:
        return [rbl]

    elif len(left.clades) == 0:
        right_times = _coal_times(right.clades)
        return [lbl] + right_times

    elif len(right.clades) == 0:
        left_times = _coal_times(left.clades)
        return [rbl] + left_times

    else:
        left_times = _coal_times(left)
        right_times = _coal_times(right)

        if lbl < rbl:
            return [lbl + left_times[0]] + left_times + right_times
        else:
            return [rbl + right_times[0]] + left_times + right_times

def _branch_counts(coalTimes, timePts, eps=1):
    ## return number of lineages at each time point
    n = len(coalTimes) + 1
    C = [n]

    for tp in timePts:
        i = 0
        for (j,ct) in enumerate(coalTimes[i:]):
            if ct >= tp + eps:
                i += j
                C.append( n-j )
                break
    return C

def _derived_carriers_from_haps(hapsFile,posn,offset,relate=False):
    f = open(hapsFile,'r')
    lines = f.readlines()
    for line in lines:
        posnLine = int(line.split()[2])
        if posnLine != posn:
            continue
        if posnLine == posn:
            alleles = ''.join(line.rstrip().split()[5:])
            hapsDer = [str(i+1-int(relate)+offset) for i in range(len(alleles)) if alleles[i] == '1']
            hapsAnc = [str(i+1-int(relate)+offset) for i in range(len(alleles)) if alleles[i] != '1']
            return [hapsDer,hapsAnc,[]]

def _derived_carriers_from_sites(sitesFile,posn,derivedAllele='G',ancientHap=None,relate=False,nDer=None,nAnc=None):
    '''
    Takes the sitesFile
    Returns a list of individuals (labels in
    the header of sitesFile) who carry derived allele
    '''

    f = open(sitesFile,'r')
    lines = f.readlines()

    headerLine = lines[0]
    inds = headerLine.split()[1:]

    for line in lines:
        if line[0] == '#' or line[0] == 'N' or line[0] == 'R':
            continue
        cols = line.rstrip().split()
        thisPosn = int(cols[-2])
        alleles = cols[-1]
        if thisPosn < posn:
            continue
        elif thisPosn == posn:
            n = len(alleles)
            if not relate:
                if ancientHap:
                    raise NotImplementedError
                indsAnc = [str(_) for _ in range(1,nAnc+1)]
                indsDer = [str(_) for _ in range(nAnc+1,n+1)]
                return [indsDer,indsAnc,[]]
            if ancientHap != None:
                raise NotImplementedError
                idxsDerived = [i for (i,x) in enumerate(alleles) if x == derivedAllele and inds[i] != ancientHap]
                indsDerived = [inds[i] for i in idxsDerived]
                indsAnc = [ind for (i,ind) in enumerate(inds) if i not in idxsDerived and inds[i] != ancientHap]
                return [indsDerived,indsAnc,[ancientHap]]
            else:
                idxsDerived = [i for (i,x) in enumerate(alleles) if x == derivedAllele]
                indsDerived = [inds[i] for i in idxsDerived]
                indsAnc = [ind for (i,ind) in enumerate(inds) if i not in idxsDerived]
                return [indsDerived,indsAnc,[]]
        else:
            inds.remove(ancientHap)
            return [[],inds,[ancientHap]]
    #raise ValueError('Specified posn not specified in sitesFile')
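# Input layout assumed by _derived_carriers_from_haps above (SHAPEIT/Relate
# .haps: five leading columns, then one 0/1 character per haplotype); an
# illustrative line with a hypothetical variant:
#
#   1 SNP1 25003 G A 0 1 1 0
#
# Column 3 is matched against `posn`, and the trailing 0/1 fields are the
# haplotype alleles whose '1' entries define the derived carriers.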
def _get_times_all_classes(derTree,ancTree,mixTree,derInds,ancInds,ancHap,n,m,sitesFile,timeScale=1,prune=None):

    indsToPrune = []
    if prune != None:
        for line in open(prune,'r'):
            indsToPrune += [line.rstrip()]
        #print(indsToPrune)
    if sitesFile == None:
        ### assume all individuals are fixed for the derived type!
        if ancHap != None:
            raise NotImplementedError
        else:
            derTimes = timeScale * np.sort(_coal_times(derTree.clade.clades))
            ancTimes = timeScale * np.sort(_coal_times(ancTree.clade.clades))
            mixTimes = timeScale * np.sort(_coal_times(mixTree.clade.clades))

    if ancHap == None:
        ancHap = []
    if n >= 2 and m >= 2:
        for ind in set(ancInds + ancHap + indsToPrune):
            #print('der',ind)
            derTree.prune(ind)
        for ind in set(derInds + ancHap + indsToPrune):
            #print('anc',ind)
            ancTree.prune(ind)
        for ind in set(derInds[1:] + ancHap + indsToPrune):
            mixTree.prune(ind)
        derTimes = timeScale * np.sort(_coal_times(derTree.clade.clades))
        ancTimes = timeScale * np.sort(_coal_times(ancTree.clade.clades))
        mixTimes = timeScale * np.sort(_coal_times(mixTree.clade.clades))

    elif n == 1 and m >= 2:
        for ind in set(derInds + ancHap + indsToPrune):
            ancTree.prune(ind)
        for ind in set(derInds[1:] + ancHap + indsToPrune):
            mixTree.prune(ind)

        ancTimes = timeScale * np.sort(_coal_times(ancTree.clade.clades))
        mixTimes = timeScale * np.sort(_coal_times(mixTree.clade.clades))
        derTimes = np.array([])

    elif n >= 2 and m == 1:
        for ind in set(ancInds + ancHap + indsToPrune):
            derTree.prune(ind)
        for ind in set(derInds[1:] + ancHap + indsToPrune):
            mixTree.prune(ind)

        derTimes = timeScale * np.sort(_coal_times(derTree.clade.clades))
        mixTimes = timeScale * np.sort(_coal_times(mixTree.clade.clades))
        ancTimes = np.array([])

    elif n == 0 and m >= 2:
        Cder = [0]
        for ind in set(ancHap + indsToPrune):
            ancTree.prune(ind)
        ancTimes = timeScale * np.sort(_coal_times(ancTree.clade.clades))
        derTimes = np.array([])
        mixTimes = np.array([])

    elif n >= 2 and m == 0:
        Canc = [0]
        for ind in set(ancHap + indsToPrune):
            derTree.prune(ind)
        derTimes = timeScale * np.sort(_coal_times(derTree.clade.clades))
        ancTimes = np.array([])
        mixTimes = np.array([])
    return derTimes,ancTimes,mixTimes
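# The optional prune file read above lists one label per line, e.g.
#
#   ind7
#   ind12
#
# (labels hypothetical). Note _get_times_all_classes prunes the labels
# verbatim, while _get_branches_all_classes below appends the '_1'/'_2'
# haplotype suffixes before pruning.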
def _get_branches_all_classes(derTree,ancTree,mixTree,derInds,ancInds,ancHap,n,m,sitesFile,times,timeScale=1,prune=None):

    indsToPrune = []
    if prune != None:
        for line in open(prune,'r'):
            indsToPrune += [line.rstrip()+'_1',line.rstrip()+'_2']
    if sitesFile == None:
        ### assume all individuals are fixed for the derived type!
        if ancHap != None:
            raise NotImplementedError
        else:
            derTimes = timeScale * np.sort(_coal_times(derTree.clade.clades))
            ancTimes = timeScale * np.sort(_coal_times(ancTree.clade.clades))
            mixTimes = timeScale * np.sort(_coal_times(mixTree.clade.clades))
            Cder = _branch_counts(derTimes,times,eps=10**-10)[1:] + [1]
            Canc = _branch_counts(ancTimes,times,eps=10**-10)[1:] + [1]
            Cmix = _branch_counts(mixTimes,times,eps=10**-10)[1:] + [1]

    if ancHap == None:
        ancHap = []
    if n >= 2 and m >= 2:
        for ind in set(ancInds + ancHap + indsToPrune):
            #print(ind)
            derTree.prune(ind)
        for ind in set(derInds + ancHap + indsToPrune):
            ancTree.prune(ind)
        for ind in set(derInds[1:] + ancHap + indsToPrune):
            mixTree.prune(ind)
        derTimes = timeScale * np.sort(_coal_times(derTree.clade.clades))
        #print(derTimes[:20])
        #print(derTimes.astype(int)[:20])

        ancTimes = timeScale * np.sort(_coal_times(ancTree.clade.clades))
        #print(ancTimes.astype(int)[:20])
        mixTimes = timeScale * np.sort(_coal_times(mixTree.clade.clades))
        #print(mixTimes.astype(int)[:20])
        Cder = _branch_counts(derTimes,times,eps=10**-10)[1:] + [1]
        Canc = _branch_counts(ancTimes,times,eps=10**-10)[1:] + [1]
        Cmix = _branch_counts(mixTimes,times,eps=10**-10)[1:] + [1]
        #print(np.array(Cder).astype(int)[:20])
        #print(np.array(Canc).astype(int)[:20])
        #print(np.array(Cmix).astype(int)[:20])
        #print(Canc,Cmix)

    elif n == 1 and m >= 2:
        for ind in set(derInds + ancHap + indsToPrune):
            ancTree.prune(ind)
        for ind in set(derInds[1:] + ancHap + indsToPrune):
            mixTree.prune(ind)

        ancTimes = timeScale * np.sort(_coal_times(ancTree.clade.clades))
        mixTimes = timeScale * np.sort(_coal_times(mixTree.clade.clades))
        #print(n,m)
        Cder = [1]
        Canc = _branch_counts(ancTimes,times,eps=10**-10)[1:] + [1]
        Cmix = _branch_counts(mixTimes,times,eps=10**-10)[1:] + [1]
        #print(Canc,Cmix)

    elif n >= 2 and m == 1:
        for ind in set(ancInds + ancHap + indsToPrune):
            derTree.prune(ind)
        for ind in set(derInds[1:] + ancHap + indsToPrune):
            mixTree.prune(ind)

        derTimes = timeScale * np.sort(_coal_times(derTree.clade.clades))
        mixTimes = timeScale * np.sort(_coal_times(mixTree.clade.clades))
        #print(n,m)
        Cder = _branch_counts(derTimes,times,eps=10**-10)[1:] + [1]
        Canc = [1]
        Cmix = _branch_counts(mixTimes,times,eps=10**-10)[1:] + [1]
    elif n == 0 and m >= 2:
        Cder = [0]
        for ind in set(ancHap + indsToPrune):
            ancTree.prune(ind)
        ancTimes = timeScale * np.sort(_coal_times(ancTree.clade.clades))
        Canc = _branch_counts(ancTimes,times,eps=10**-10)[1:] + [1]
        Cmix = Canc

    elif n >= 2 and m == 0:
        Canc = [0]
        for ind in set(ancHap + indsToPrune):
            derTree.prune(ind)
        derTimes = timeScale * np.sort(_coal_times(derTree.clade.clades))
        Cder = _branch_counts(derTimes,times,eps=10**-10)[1:] + [1]
        Cmix = [1]

    return Cder,Canc,Cmix

--------------------------------------------------------------------------------
/snp_example/example.quad_fit.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/standard-aaron/palm/b3c12023b5e1938b049531d1d09e71dd1dca3f61/snp_example/example.quad_fit.npy
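The .quad_fit.npy files hold the three np.polyfit coefficients saved by
lik.py (highest degree first). A minimal loading sketch:

    import numpy as np
    p = np.load('snp_example/example.quad_fit.npy')  # [a, b, c]
    logLR = np.polyval(p, 0.01)  # approx. log-likelihood ratio at s = 0.01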
--------------------------------------------------------------------------------
/snp_example/genetic_map.txt:
--------------------------------------------------------------------------------
position COMBINED_rate.cM.Mb. Genetic_Map.cM.
0 1.0 0.0
1000000 0.0 1.0
--------------------------------------------------------------------------------
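A sketch of how the pieces fit together (flags taken from the parsers above;
file names and the focal position hypothetical):

    # extract derived/ancestral coalescence times from sampled trees at a SNP
    # (writes <prefix>.der.npy / <prefix>.anc.npy)
    python coal.py --locusDir snp_example/ --posn 25003 --derivedAllele A --relate

    # fit the per-locus selection likelihood from a Relate/CLUES <prefix>.timeb
    # file and save its quadratic approximation (<prefix>.quad_fit.npy)
    python lik.py --times snp_example/example --out snp_example/example

    # combine per-locus quadratic fits into a trait-level selection test
    python palm.py --traitDir traits/ --metadata metadata.tsv --out palm_results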