├── twisst2 ├── __init__.py ├── parse_newick.py ├── TopologySummary.py └── twisst2.py ├── example ├── admix_hiILS_123_l5e6_r1e8_mu1e8.vcf.gz └── groups_4_20.txt ├── plot_twisst ├── examples │ ├── admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.intervals.tsv.gz │ └── admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.topocounts.tsv.gz ├── example_plot.R └── plot_twisst.R ├── pyproject.toml ├── LICENSE └── README.md /twisst2/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.5" 2 | 3 | -------------------------------------------------------------------------------- /example/admix_hiILS_123_l5e6_r1e8_mu1e8.vcf.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simonhmartin/twisst2/HEAD/example/admix_hiILS_123_l5e6_r1e8_mu1e8.vcf.gz -------------------------------------------------------------------------------- /plot_twisst/examples/admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.intervals.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simonhmartin/twisst2/HEAD/plot_twisst/examples/admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.intervals.tsv.gz -------------------------------------------------------------------------------- /plot_twisst/examples/admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.topocounts.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/simonhmartin/twisst2/HEAD/plot_twisst/examples/admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.topocounts.tsv.gz -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "twisst2" 3 | authors = [{name = "Simon Martin", email = "simon.martin@ed.ac.uk"}] 4 | description = "Topology weighting from unphased genotypes of any ploidy" 5 | readme = "README.md" 6 | 
requires-python = ">=3.7" 7 | dynamic = ["version"] 8 | # version = "0.0.1" 9 | dependencies = ["numpy >= 1.21.5", "cyvcf2"] 10 | 11 | [tool.setuptools.dynamic] 12 | version = {attr = "twisst2.__version__"} 13 | 14 | [project.scripts] 15 | twisst2 = "twisst2.twisst2:main" 16 | 17 | [tool.setuptools] 18 | packages = ["twisst2"] 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Simon martin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /example/groups_4_20.txt: -------------------------------------------------------------------------------- 1 | tsk_0 group1 2 | tsk_1 group1 3 | tsk_2 group1 4 | tsk_3 group1 5 | tsk_4 group1 6 | tsk_5 group1 7 | tsk_6 group1 8 | tsk_7 group1 9 | tsk_8 group1 10 | tsk_9 group1 11 | tsk_10 group1 12 | tsk_11 group1 13 | tsk_12 group1 14 | tsk_13 group1 15 | tsk_14 group1 16 | tsk_15 group1 17 | tsk_16 group1 18 | tsk_17 group1 19 | tsk_18 group1 20 | tsk_19 group1 21 | tsk_20 group2 22 | tsk_21 group2 23 | tsk_22 group2 24 | tsk_23 group2 25 | tsk_24 group2 26 | tsk_25 group2 27 | tsk_26 group2 28 | tsk_27 group2 29 | tsk_28 group2 30 | tsk_29 group2 31 | tsk_30 group2 32 | tsk_31 group2 33 | tsk_32 group2 34 | tsk_33 group2 35 | tsk_34 group2 36 | tsk_35 group2 37 | tsk_36 group2 38 | tsk_37 group2 39 | tsk_38 group2 40 | tsk_39 group2 41 | tsk_40 group3 42 | tsk_41 group3 43 | tsk_42 group3 44 | tsk_43 group3 45 | tsk_44 group3 46 | tsk_45 group3 47 | tsk_46 group3 48 | tsk_47 group3 49 | tsk_48 group3 50 | tsk_49 group3 51 | tsk_50 group3 52 | tsk_51 group3 53 | tsk_52 group3 54 | tsk_53 group3 55 | tsk_54 group3 56 | tsk_55 group3 57 | tsk_56 group3 58 | tsk_57 group3 59 | tsk_58 group3 60 | tsk_59 group3 61 | tsk_60 group4 62 | tsk_61 group4 63 | tsk_62 group4 64 | tsk_63 group4 65 | tsk_64 group4 66 | tsk_65 group4 67 | tsk_66 group4 68 | tsk_67 group4 69 | tsk_68 group4 70 | tsk_69 group4 71 | tsk_70 group4 72 | tsk_71 group4 73 | tsk_72 group4 74 | tsk_73 group4 75 | tsk_74 group4 76 | tsk_75 group4 77 | tsk_76 group4 78 | tsk_77 group4 79 | tsk_78 group4 80 | tsk_79 group4 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # twisst2 2 | 3 | `twisst2` is a tool for [topology weighting](https://doi.org/10.1534/genetics.116.194720). 
Topology weighting summarises genealogies in terms of the relative abundance of different sub-tree topologies. It can be used to explore processes like introgression and it can aid the identification of trait-associated loci. 4 | 5 | `twisst2` has a number of important improvements over the original [`twisst`](https://github.com/simonhmartin/twisst) tool. Most importantly, `twisst2` **incorporates inference of the ancestral recombination graph (ARG) or tree sequence** - local genealogies and their breakpoints along the chromosome. It does this using [`sticcs`](https://github.com/simonhmartin/sticcs). `sticcs` is a model-free approach and it does not require phased data, so `twisst2` **can run on unphased genotypes of any ploidy**. 6 | 7 | The recommended way to run `twisst2` is to start from polarised genotype data. This means you either need to know the ancestral allele at each site, or you need an appropriate outgroup(s) to allow inference of the derived allele. 8 | 9 | An alternative way to run it is by first inferring the ancestral recombination graph (ARG) tree sequence using a different tool like [Relate](https://myersgroup.github.io/relate/) or [`tsinfer`](https://tskit.dev/tsinfer/docs/stable/index.html). However, this typically requires phased genotypes, and [my tests](https://doi.org/10.1093/genetics/iyaf181) suggest that `twisst2+sticcs` is more accurate than other methods anyway. 10 | 11 | # Publications 12 | - The general concept of topology weighting is described by [Martin and Van Belleghem 2017](https://doi.org/10.1534/genetics.116.194720). 13 | - Combining genealogy inference with `sticcs` and topology weighting with `twisst2` is described by [Martin 2025](https://doi.org/10.1093/genetics/iyaf181). 14 | 15 | 16 | ### Installation 17 | 18 | First install [`sticcs`](https://github.com/simonhmartin/sticcs) by following the instructions there. 
19 | 20 | If you would like to analyse tree sequence objects from tools like [`msprime`](https://tskit.dev/msprime/docs/stable/intro.html) and [tsinfer](https://tskit.dev/tsinfer/docs/stable/index.html), you will also need to install [`tskit`](https://tskit.dev/tskit/docs/stable/introduction.html) yourself. To install `twisst2`: 21 | 22 | ```bash 23 | git clone https://github.com/simonhmartin/twisst2.git 24 | 25 | cd twisst2 26 | 27 | pip install -e . 28 | ``` 29 | 30 | ### Command line tool 31 | 32 | #### Starting from unphased (or phased) genotypes 33 | 34 | To perform tree inference and topology weighting, `twisst2` takes as input a modified vcf file that contains a `DC` field, giving the count of derived alleles for each individual at each site. 35 | 36 | Once you have a vcf file for your genotype data, make the modified version using `sticcs` (this needs to be installed, see above): 37 | ```bash 38 | sticcs prep -i -o --outgroup 39 | ``` 40 | 41 | If the vcf file already has the ancestral allele (provided in the `AA` field in the `INFO` section), then you do not need to specify outgroups for polarising. 42 | 43 | Now you can run `twisst2` to count sub-tree topologies: 44 | 45 | ```bash 46 | twisst2 sticcstack -i -o --max_subtrees 512 --ploidy 2 --groups --groups_file 47 | ``` 48 | 49 | #### Starting from pre-inferred trees or ARG (e.g. Relate, tsinfer, argweaver, Singer) 50 | 51 | ```bash 52 | twisst2 trees -i -o --groups --groups_file 53 | ``` 54 | 55 | ### Output 56 | 57 | - `.topocounts.tsv.gz` gives the count of each group tree topology for each interval. 58 | - `.intervals.tsv.gz` gives the chromosome, start and end position of each interval. 59 | 60 | ### R functions for plotting 61 | 62 | Some functions for importing and plotting are provided in the `plot_twisst/plot_twisst.R` script. For examples of how to use these functions, see the `plot_twisst/example_plot.R` script. 
63 | 64 | -------------------------------------------------------------------------------- /plot_twisst/example_plot.R: -------------------------------------------------------------------------------- 1 | 2 | ################################# overview ##################################### 3 | 4 | # The main data produced by Twisst is a weights file which has columns for each 5 | # topology and their number of observations of that topology within each 6 | # genealogy. Weights files produced by Twisst also contain initial comment lines 7 | # specifying the topologies. 8 | 9 | # The other data file that may be of interest is the window data. That is, the 10 | # chromosome/scaffold and start and end positions for each of the regions or 11 | # windows represented in the weights file. 12 | 13 | # Both of the above files can be read into R, manipulated and plotted however 14 | # you like, but I have written some functions to make these tasks easier. 15 | # These functions are provided in the script plot_twisst.R 16 | 17 | ################### load helpful plotting functions ############################# 18 | 19 | source("plot_twisst.R") 20 | 21 | ############################## input files ###################################### 22 | 23 | # Each genomic region (chromosome, contig etc) should have two files: 24 | # A topocounts file giving the count of each topology (one per column) 25 | # And an intervals file, giving the chromosome name, start and end position 26 | 27 | # Here we just import one genomic region 28 | 29 | intervals_file <- "examples/admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.intervals.tsv.gz" 30 | 31 | topocounts_file <- "examples/admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.topocounts.tsv.gz" 32 | 33 | ################################# import data ################################## 34 | 35 | # The function import.twisst reads the topology counts and intervals data files into a list object 36 | # If there are multiple weights files, or a single file with different 
chromosomes/scaffolds/contigs 37 | # in the window data file, these will be separated when importing. 38 | 39 | twisst_data <- import.twisst(intervals_files=intervals_file, 40 | topocounts_files=topocounts_file) 41 | 42 | #Some additional arguments: 43 | 44 | # topos_file= 45 | # If you prefer to provide the topologies as a separate file (rather than as comment lines in the topocounts file) 46 | 47 | # ignore_extra_columns=TRUE 48 | # This option will ignore subtrees that did not match any of the defined topologies (usually due to polytomies). 49 | # If you use this option, the weightings for every tree will sum to 1, but you may be throwing away some information 50 | # If you use this option, you might want to set the min_subtrees option too. 51 | 52 | # min_subtrees=100 53 | # This option can be used in conjunction with ignore_extra_columns=TRUE. 54 | # For trees with polytomies there may be few subtrees that match any of the defined topologies. 55 | # These weightings would be less reliable (more noisy). 56 | # Any tree with fewer subtrees than the defined minimum will be ignored. 57 | 58 | # max_interval=10000 59 | # In parts of the genome with bad data, the tree interval can be very large. 60 | # You can simply exclude this information by setting the maximum interval. 
61 | 62 | # names= 63 | # If you have multiple regions they will be named according to the chromosome by default, 64 | # but you can give your own names if you prefer 65 | 66 | # reorder_by_start=TRUE 67 | # If the tree intervals are out of order, this option will reorder them by the start position 68 | 69 | 70 | ############################## combined plots ################################## 71 | # there are a functions available to plot both the weightings and the topologies 72 | 73 | #for all plots we will use the 15 colours that come with plot_twisst.R 74 | # But we will reorder them according to the most abundant topology 75 | topo_cols <- topo_cols[order(order(twisst_data$weights_overall_mean[1:length(twisst_data$topos)], decreasing=TRUE))] 76 | 77 | #a summary plot shows all the topologies and a bar plot of their relative weightings 78 | plot.twisst.summary(twisst_data, lwd=3, cex=0.7) 79 | 80 | plot.twisst.summary.boxplot(twisst_data) 81 | 82 | 83 | #or plot ALL the data across the chromosome(s) 84 | # Note, this is not recommended if there are large numbers of windows. 85 | # instead, it is recommended to first smooth the weghtings and plot the smoothed values 86 | # There are three plotting modes to try 87 | plot.twisst(twisst_data, mode=1, show_topos=TRUE, ncol_topos=15) 88 | plot.twisst(twisst_data, mode=2, show_topos=TRUE, ncol_topos=15) 89 | plot.twisst(twisst_data, mode=3, show_topos=TRUE, ncol_topos=15) 90 | 91 | 92 | # make smooth weightings and plot those across chromosomes 93 | twisst_data_smooth <- smooth.twisst(twisst_data, span_bp = 20000, spacing = 1000) 94 | plot.twisst(twisst_data_smooth, mode=2, ncol_topos=15) #mode 2 overlays polygons, mode 3 would stack them 95 | 96 | 97 | ##################### individual plots: raw weights ############################ 98 | 99 | #plot raw data in "stepped" style, with polygons stacked. 
100 | #specify stepped style by providing a matrix of starts and ends for positions 101 | par(mfrow = c(1,1), mar = c(4,4,1,1)) 102 | plot.weights(weights_dataframe=twisst_data$weights[[1]], positions=twisst_data$interval_data[[1]][,c("start","end")], 103 | line_cols=topo_cols, fill_cols=topo_cols, stacked=TRUE) 104 | 105 | #plot raw data in stepped style, with polygons unstacked (stacked =FLASE) 106 | #use semi-transparent colours for fill 107 | plot.weights(weights_dataframe=twisst_data$weights[[1]], positions=twisst_data$interval_data[[1]][,c("start","end")], 108 | line_cols=topo_cols, fill_cols=paste0(topo_cols,80), stacked=FALSE) 109 | 110 | 111 | #################### individual plots: smoothed weights ######################## 112 | 113 | #plot smoothed data with polygons stacked 114 | plot.weights(weights_dataframe=twisst_data_smooth$weights[[1]], positions=twisst_data_smooth$pos[[1]], 115 | line_cols=topo_cols, fill_cols=topo_cols, stacked=TRUE) 116 | 117 | #plot smoothed data with polygons unstacked 118 | plot.weights(weights_dataframe=twisst_data_smooth$weights[[1]], positions=twisst_data_smooth$pos[[1]], 119 | line_cols=topo_cols, fill_cols=paste0(topo_cols,80), stacked=FALSE) 120 | 121 | 122 | 123 | #################### subset to only the most abundant topologies ################# 124 | 125 | #get list of the most abundant topologies (top 2 in this case) 126 | top2_topos <- order(twisst_data$weights_overall_mean, decreasing=T)[1:2] 127 | 128 | #subset twisst object for these 129 | twisst_data_top2topos <- subset.twisst.by.topos(twisst_data, top2_topos) 130 | #this can then be used in all the same plotting functions above. 
131 | 132 | ######################## subset to only specific regions ######################### 133 | 134 | #regions to keep (more than one can be specified) 135 | regions <- c("chr1") 136 | 137 | #subset twisst object for these 138 | twisst_data_chr1 <- subset.twisst.by.regions(twisst_data, regions) 139 | #this can then be used in all the same plotting functions above. 140 | 141 | 142 | ########################### plot topologies using Ape ########################## 143 | #unroot trees if you want to 144 | # for (i in 1:length(twisst_data$topos)) twisst_data$topos[[i]] <- ladderize(unroot(twisst_data$topos[[i]])) 145 | 146 | par(mfrow = c(3,length(twisst_data$topos)/3), mar = c(1,1,2,1), xpd=NA) 147 | for (n in 1:length(twisst_data$topos)){ 148 | plot.phylo(twisst_data$topos[[n]], type = "cladogram", edge.color=topo_cols[n], edge.width=5, rotate.tree = 90, cex = 1, adj = .5, label.offset=.2) 149 | mtext(side=3,text=paste0("topo",n)) 150 | } 151 | 152 | 153 | -------------------------------------------------------------------------------- /twisst2/parse_newick.py: -------------------------------------------------------------------------------- 1 | def parse_newick(newick_string): 2 | """ 3 | Parse a Newick string into a dictionary representation. 4 | 5 | Returns: 6 | dict: Tree structure where keys are node IDs and values are lists of children 7 | dict: Node metadata (names, branch lengths, etc.) 8 | int: Root node ID 9 | """ 10 | 11 | # Remove whitespace and trailing semicolon 12 | newick = newick_string.strip().rstrip(';') 13 | 14 | # Global variables for the parser 15 | node_counter = 0 16 | node_children = {} 17 | node_label = {} 18 | branch_length = {} 19 | 20 | def get_next_node_id(): 21 | nonlocal node_counter 22 | node_id = node_counter 23 | node_counter += 1 24 | return node_id 25 | 26 | def parse_node(s, pos=0): 27 | """ 28 | Recursively parse a node from position 'pos' in string 's'. 
29 | 30 | Returns: 31 | tuple: (node_id, new_position) 32 | """ 33 | nonlocal node_children, node_label, branch_length 34 | 35 | current_node_id = get_next_node_id() 36 | node_children[current_node_id] = [] # Initialize children list 37 | node_label[current_node_id] = None 38 | branch_length[current_node_id] = None 39 | 40 | # Skip whitespace 41 | while pos < len(s) and s[pos].isspace(): pos += 1 42 | 43 | # Case 1: Internal node - starts with '(' 44 | if pos < len(s) and s[pos] == '(': 45 | pos += 1 # Skip opening parenthesis 46 | 47 | # Parse children until we hit the closing parenthesis 48 | while pos < len(s) and s[pos] != ')': 49 | # Skip whitespace and commas 50 | while pos < len(s) and s[pos] in ' \t,': pos += 1 51 | 52 | if pos < len(s) and s[pos] != ')': 53 | # Recursively parse child 54 | child_id, pos = parse_node(s, pos) 55 | node_children[current_node_id].append(child_id) 56 | 57 | if pos < len(s) and s[pos] == ')': pos += 1 # Skip closing parenthesis 58 | 59 | # Case 2 & 3: Parse node label and branch length (for both leaf and internal nodes) 60 | # Node label comes first 61 | label_start = pos 62 | while pos < len(s) and s[pos] not in '[:,();': pos += 1 63 | 64 | if pos > label_start: 65 | node_label[current_node_id] = s[label_start:pos].strip() 66 | 67 | # Branch length comes after ':' 68 | if pos < len(s) and s[pos] == ':': 69 | pos += 1 # Skip ':' 70 | 71 | # Parse branch length 72 | length_start = pos 73 | while pos < len(s) and s[pos] not in '[,();': pos += 1 74 | 75 | if pos > length_start: 76 | try: 77 | branch_length[current_node_id] = float(s[length_start:pos]) 78 | except: 79 | raise ValueError(f"{length_start:pos} not a valid branch length") 80 | #pass # Invalid branch length, keep as None 81 | 82 | #progress past anything included beyond branch length 83 | while pos < len(s) and s[pos] not in ',();': pos += 1 84 | 85 | return current_node_id, pos 86 | 87 | # Start parsing from the beginning 88 | root_id, _ = parse_node(newick, 0) 89 | 90 
| #now figure out which nodes are leaves and which are parents 91 | n_nodes = get_next_node_id() 92 | parents = [] 93 | leaves = [] 94 | for id in range(n_nodes): 95 | if node_children[id] == []: leaves.append(id) 96 | else: parents.append(id) 97 | 98 | return node_children, node_label, branch_length, leaves, parents, root_id 99 | 100 | 101 | from sticcs import sticcs 102 | 103 | 104 | # another function to convert a newick string to a sticcs tree. 105 | # This object uses numeric leaf IDs from 0 to N-1. 106 | # These can be passed to the function as a dictionary. 107 | # Otherwise the function attempts to convert leaf IDs to integers 108 | # Internal nodes don't need the same level of consistency, so those will 109 | # be kept in the order i which they were found, but renumbered so they come after the leaf indices 110 | def newick_to_sticcs_Tree(newick_string, leaf_idx_dict=None, allow_additional_leaves=False, interval=None): 111 | 112 | #first parse newick and number nodes and leaves by order of finding them 113 | node_children, node_label, branch_length, leaves, parents, root_id = parse_newick(newick_string) 114 | #nodes that come out of parse_newick are integers, but they are in the order they were encountered in reading the string, so they are meaningless 115 | #What matters is the node_labels, as these are what those nodes were called (i.e. 
leaf names, but internal nodes can theoretically have labels too) 116 | #Now the sticcs tree needs numeric leaves, but these need to start from zero - so they are NOT the same as the nodes that come from parse_newick() 117 | #Instead each node_label needs to map to an integer between zero and n_leaves 118 | #This can be provided in the leaf_idx_dict 119 | #If not, all node_labels need to be integers in the newick string 120 | 121 | n_leaves = len(leaves) 122 | 123 | #get the label for each leaf 124 | leaf_labels = [node_label[leaf_number] for leaf_number in leaves] 125 | assert len(set(leaf_labels)) == n_leaves, "Leaf labels must all be unique." 126 | 127 | #now, if there is no dictionary provided, we assume leaves are already numbered 0:n-1 128 | if not leaf_idx_dict: 129 | if not set(leaf_labels) == set([str(i) for i in range(n_leaves)]): 130 | raise ValueError("Leaf labels not consecutive integers. Please provide a leaf_idx_dict with consecutive indices from 0 to n_leaves-1") 131 | 132 | _leaf_idx_dict_ = dict([(label, int(label)) for label in leaf_labels]) 133 | 134 | else: 135 | #We already have the dictionary to convert leaf number to new index 136 | _leaf_idx_dict_ = {} 137 | _leaf_idx_dict_.update(leaf_idx_dict) 138 | assert set(_leaf_idx_dict_.values()) == set(range(len(_leaf_idx_dict_))), "leaf_idx_dict must provide consecutive indices from zero" 139 | #Check that all leaves are represented 140 | unrepresented_leaves = [label for label in leaf_labels if label not in leaf_idx_dict] 141 | if len(unrepresented_leaves) > 0: 142 | if allow_additional_leaves: 143 | i = len(_leaf_idx_dict_) 144 | for leaf_label in sorted(unrepresented_leaves): 145 | _leaf_idx_dict_[leaf_label] = i 146 | i += 1 147 | else: 148 | raise ValueError("Some leaves are not listed in leaf_idx_dict. 
Set allow_additional_leaves=True if you want to risk it.") 149 | 150 | new_leaf_IDs = sorted(_leaf_idx_dict_.values()) 151 | 152 | #now assign the new ID to each leaf number (remember leaf numbers in the parsed newick are simply in the order they were encountered) 153 | new_node_idx = dict(zip(leaves, [_leaf_idx_dict_[node_label[leaf_number]] for leaf_number in leaves])) 154 | 155 | #the new parents are just n_leaves along from where they were 156 | new_parents = [id+n_leaves for id in parents] 157 | 158 | new_node_idx.update(dict(zip(parents, new_parents))) 159 | 160 | #now update objects for the new sticcs Tree object 161 | new_node_children = dict() 162 | for item in node_children.items(): 163 | new_node_children[new_node_idx[item[0]]] = [new_node_idx[x] for x in item[1]] 164 | 165 | new_node_parent = dict() 166 | for item in new_node_children.items(): 167 | for child in item[1]: 168 | new_node_parent[child] = item[0] 169 | 170 | new_node_label = dict() 171 | for item in node_label.items(): 172 | new_node_label[new_node_idx[item[0]]] = item[1] 173 | 174 | objects = {"leaves":new_leaf_IDs, "root":n_leaves, "parents":new_parents, 175 | "node_children": new_node_children, "node_parent":new_node_parent, 176 | "interval":interval} 177 | 178 | return (sticcs.Tree(objects=objects), new_node_label) 179 | 180 | 181 | 182 | # a class that just holds a bunch of trees, but has a few attributes and methods that resemble a tskit treesequence object 183 | # This allos us to import newick trees and treat the list as a treesequence for a few things (like my quartet distance calculation) 184 | class TreeList: 185 | def __init__(self, trees): 186 | self.num_trees = len(trees) 187 | self.tree_list = trees 188 | 189 | def trees(self): 190 | for tree in self.tree_list: 191 | yield tree 192 | 193 | 194 | def parse_newick_file(newickfile, leaf_names=None): 195 | 196 | #if leaves are not integers, need to link each leaf to its index 197 | leaf_idx_dict = dict(zip(leaf_names, 
range(len(leaf_names)))) if leaf_names is not None else None 198 | 199 | trees = [] 200 | 201 | i = 1 202 | 203 | for line in newickfile: 204 | tree, node_labels = newick_to_sticcs_Tree(line.strip(), leaf_idx_dict=leaf_idx_dict, allow_additional_leaves=True, interval = (i, i)) 205 | trees.append(tree) 206 | i+=1 207 | 208 | return TreeList(trees) 209 | 210 | 211 | def parse_argweaver_smc(argfile, leaf_names=None): 212 | #annoyingly, the trees have numeric leaves, but these are NOT in the order of the haps 213 | #The order if given by the first line, so we can figure out which hap each leaf points to 214 | 215 | #get the order of haps from argweaver output 216 | #The position of each number in this list links it to a leaf (leaves are numeric) 217 | leaf_names_reordered = argfile.readline().split()[1:] 218 | 219 | n_leaves = len(leaf_names_reordered) 220 | 221 | _leaf_names = leaf_names if leaf_names is not None else [str(i) for i in range(n_leaves)] 222 | 223 | #link each leaf name to its real index. 
This is the index that will be used for topology weighting 224 | leaf_name_to_idx_dict = dict([(_leaf_names[i], i) for i in range(n_leaves)]) 225 | 226 | #link the leaf number in the current file (which is arbitrary) to its real index 227 | leaf_idx_dict = dict([(str(i), leaf_name_to_idx_dict[leaf_names_reordered[i]]) for i in range(n_leaves)]) 228 | 229 | trees = [] 230 | 231 | chrom, chrom_start, chrom_len = argfile.readline().split()[1:] 232 | for line in argfile: 233 | if line.startswith("TREE"): 234 | elements = line.split() 235 | interval=(int(elements[1]), int(elements[2]),) 236 | tree_newick = elements[3] 237 | tree, node_labels = newick_to_sticcs_Tree(tree_newick, leaf_idx_dict=leaf_idx_dict, allow_additional_leaves=True, interval=interval) 238 | trees.append(tree) 239 | 240 | return TreeList(trees) 241 | 242 | 243 | 244 | def sim_test_newick_parser(n=12, reps=5): 245 | import msprime 246 | from twisst2 import TopologySummary 247 | 248 | for i in range(reps): 249 | ts = msprime.sim_ancestry(n, ploidy=1) 250 | t = ts.first() 251 | sim_tree_summary = TopologySummary.TopologySummary(t) 252 | sim_tree_ID = sim_tree_summary.get_topology_ID() 253 | sim_tree_newick = t.as_newick(include_branch_lengths=False) 254 | print("Simulated tree:", sim_tree_newick) 255 | 256 | #now read it 257 | parsed_tree, node_labels = newick_to_sticcs_Tree(sim_tree_newick, leaf_idx_dict= dict([("n"+str(i), i) for i in range(n)])) 258 | 259 | parsed_tree_summary = TopologySummary.TopologySummary(parsed_tree) 260 | parsed_tree_ID = parsed_tree_summary.get_topology_ID() 261 | parsed_tree_ID == sim_tree_ID 262 | print("Parsed tree: ", parsed_tree.as_newick(node_labels=node_labels)) 263 | 264 | print("Match:", parsed_tree_ID == sim_tree_ID) 265 | 266 | 267 | 268 | # Test the parser 269 | if __name__ == "__main__": 270 | 271 | #Manual test 272 | #newick_string = '((n1,n4),((n0,n3),(n2,n5)));' 273 | #parsed_tree, node_labels = newick_to_sticcs_Tree(newick_string, leaf_idx_dict={"n0":0, 
"n1":1, "n2":2, "n3":3, "n4":4, "n5": 5}) 274 | #print(parsed_tree.as_newick(node_labels=node_labels)) 275 | 276 | #newick_string = '((1,4),((0,3),(2,5)));' 277 | #parsed_tree, node_labels = newick_to_sticcs_Tree(newick_string) 278 | #print(parsed_tree.as_newick()) 279 | 280 | #Auto test multiple times 281 | sim_test_newick_parser() 282 | 283 | -------------------------------------------------------------------------------- /twisst2/TopologySummary.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | from collections import deque 4 | import itertools, random 5 | 6 | class NodeChain(deque): 7 | def __init__(self, nodeList, dists=None): 8 | super(NodeChain, self).__init__(nodeList) 9 | if dists is None: self.dists = None 10 | else: 11 | assert len(dists) == len(self)-1, "incorrect number of iternode distances" 12 | self.dists = deque(dists) 13 | self._set_ = None 14 | 15 | def addNode(self, name, dist=0): 16 | self.append(name) 17 | if self.dists is not None: self.dists.append(dist) 18 | 19 | def addNodeLeft(self, name, dist=0): 20 | self.appendleft(name) 21 | if self.dists is not None: self.dists.appendleft(dist) 22 | 23 | def addNodeChain(self, chainToAdd, joinDist=0): 24 | self.extend(chainToAdd) 25 | if self.dists is not None: 26 | assert chainToAdd.dists is not None, "Cannot add a chain without distances to one with distances" 27 | self.dists.append(joinDist) 28 | self.dists.extend(chainToAdd.dists) 29 | 30 | def addNodeChainLeft(self, chainToAdd, joinDist=0): 31 | self.extendleft(chainToAdd) 32 | if self.dists is not None: 33 | assert chainToAdd.dists is not None, "Cannot add a chain without distances to one with distances" 34 | self.dists.appendleft(joinDist) 35 | self.dists.extendleft(chainToAdd.dists) 36 | 37 | def chopLeft(self): 38 | self.popleft() 39 | if self.dists is not None: self.dists.popleft() 40 | 41 | def chop(self): 42 | self.pop() 43 | if 
self.dists is not None: self.dists.pop() 44 | 45 | def fuseLeft(self, chainToFuse): 46 | new = NodeChain(self, self.dists) 47 | assert new[0] == chainToFuse[0], "No common nodes" 48 | i = 1 49 | while new[1] == chainToFuse[i]: 50 | new.chopLeft() 51 | i += 1 52 | m = len(chainToFuse) 53 | while i < m: 54 | new.addNodeLeft(chainToFuse[i], chainToFuse.dists[i-1] if self.dists is not None else None) 55 | i += 1 56 | return new 57 | 58 | def simplifyToEnds(self, newDist=None): 59 | if self.dists is not None: 60 | if not newDist: newDist = sum(self.dists) 61 | self.dists.clear() 62 | leftNode = self.popleft() 63 | rightNode = self.pop() 64 | self.clear() 65 | self.append(leftNode) 66 | self.append(rightNode) 67 | if self.dists is not None: 68 | self.dists.append(newDist) 69 | 70 | def setSet(self): 71 | self._set_ = set(self) 72 | 73 | 74 | def getChainsToLeaves(tree, node=None, simplifyDict = None): 75 | if node is None: node = tree.root 76 | children = tree.children(node) #this syntax for tskit trees 77 | #children = tree.node_children[node] 78 | if len(children) == 0: 79 | #if it has no children is is a child 80 | #if it's in the simplifyDict or there is not simplifyDict 81 | #just record a weight for the node and return is as a new 1-node chain 82 | if simplifyDict is None or node in simplifyDict: 83 | chain = NodeChain([node]) 84 | setattr(chain, "weight", 1) 85 | return [chain] 86 | else: 87 | return [] 88 | #otherwise get chains for all children 89 | childrenChains = [getChainsToLeaves(tree, child, simplifyDict) for child in children] 90 | #now we have the chains from all children, we need to add the current node 91 | for childChains in childrenChains: 92 | for chain in childChains: chain.addNodeLeft(node) 93 | 94 | #if collapsing, check groups for each node 95 | if simplifyDict: 96 | nodeGroupsAll = np.array([simplifyDict[chain[-1]] for childChains in childrenChains for chain in childChains]) 97 | nodeGroups = list(set(nodeGroupsAll)) 98 | nGroups = 
len(nodeGroups) 99 | 100 | if (nGroups == 1 and len(nodeGroupsAll) > 1) or (nGroups == 2 and len(nodeGroupsAll) > 2): 101 | #all chains end in a leaf from one or two groups, so we can simplify. 102 | #first list all chains 103 | chains = [chain for childChains in childrenChains for chain in childChains] 104 | #Start by getting index of each chain for each group 105 | indices = [(nodeGroupsAll == group).nonzero()[0] for group in nodeGroups] 106 | #the new weight for each chain we keep will be the total node weight of all from each group 107 | newWeights = [sum([chains[i].weight for i in idx]) for idx in indices] 108 | #now reduce to just a chain for each group 109 | chains = [chains[idx[0]] for idx in indices] 110 | for j in range(nGroups): 111 | chains[j].simplifyToEnds() 112 | chains[j].weight = newWeights[j] 113 | 114 | #if we couldn't simply collapse completely, we might still be able to merge down a side branch 115 | #Side branches are child chains ending in a single leaf 116 | #If there is a lower level child branch that is itself a side branch, we can merge to it 117 | elif (len(childrenChains) == 2 and 118 | ((len(childrenChains[0]) == 1 and len(childrenChains[1]) > 1) or 119 | (len(childrenChains[1]) == 1 and len(childrenChains[0]) > 1))): 120 | chains,sideChain = (childrenChains[1],childrenChains[0][0]) if len(childrenChains[0]) == 1 else (childrenChains[0],childrenChains[1][0]) 121 | #now check if any main chain is suitable (should be length 3, and the only one that is such. 
and have correct group 122 | targets = (np.array([len(chain) for chain in chains]) == 3).nonzero()[0] 123 | if len(targets) == 1 and simplifyDict[chains[targets[0]][-1]] == simplifyDict[sideChain[-1]]: 124 | #we have found a suitable internal chain to merge to 125 | targetChain = chains[targets[0]] 126 | newWeight = targetChain.weight + sideChain.weight 127 | targetChain.simplifyToEnds() 128 | targetChain.weight = newWeight 129 | else: 130 | #if we didn't find a suitable match, just add side chain 131 | chains.append(sideChain) 132 | else: 133 | #if there was no side chain, just list all chains 134 | chains = [chain for childChains in childrenChains for chain in childChains] 135 | #otherwise we are not collapsing, so just list all chains 136 | else: 137 | chains = [chain for childChains in childrenChains for chain in childChains] 138 | #now we have the chains from all children, we need to add the current node 139 | 140 | return chains 141 | 142 | 143 | def get_leaf_combos(leaf_groups, max_subtrees): 144 | total = np.prod([len(t) for t in leaf_groups]) 145 | if total <= max_subtrees: 146 | for combo in itertools.product(*leaf_groups): 147 | yield combo 148 | else: 149 | for i in range(max_subtrees): 150 | yield tuple(random.choice(group) for group in groups) 151 | 152 | def pairsDisjoint(pairOfPairs): 153 | if pairOfPairs[0][0] in pairOfPairs[1] or pairOfPairs[0][1] in pairOfPairs[1]: return False 154 | return True 155 | 156 | 157 | pairs_generic = dict([(n, tuple(itertools.combinations(range(n),2))) for n in range(4,8)]) 158 | 159 | pairPairs_generic = dict([(n, [pairPair for pairPair in itertools.combinations(pairs_generic[n],2) if pairsDisjoint(pairPair)],) for n in range(4,8)]) 160 | 161 | 162 | class TopologySummary: 163 | #Summarises a tree as a set of NodeChain objects 164 | #Provides a convenient way to summarise the topology as a unique ID 165 | #If group data is provided, tree is only summarised in terms of chains 166 | # between nodes from distinct 
groups, and redundant chains (e.g. monophyletic clades) 167 | # are simplified and can be weighted accordingly with recorded leaf weights 168 | 169 | def __init__(self, tree, simplifyDict=None): 170 | self.root = tree.root 171 | self.simplifyDict = simplifyDict 172 | chains = getChainsToLeaves(tree, simplifyDict=self.simplifyDict) 173 | self.leafWeights = {} 174 | self.chains = {} 175 | #record each root-leaf chain with th 176 | for chain in chains: 177 | self.chains[(chain[0], chain[-1])] = chain 178 | self.leafWeights[chain[-1]] = chain.weight 179 | 180 | self.leavesRetained = set(self.leafWeights.keys()) 181 | #now make chains for all pairs of leaves (or all pairs in separate groups if defined) 182 | if self.simplifyDict: 183 | leafPairs = [pair for pair in itertools.combinations(self.leavesRetained,2) if self.simplifyDict[pair[0]] != self.simplifyDict[pair[1]]] 184 | else: 185 | leafPairs = [pair for pair in itertools.combinations(self.leavesRetained,2)] 186 | 187 | for l0,l1 in leafPairs: 188 | self.chains[(l0,l1)] = self.chains[(tree.root,l0)].fuseLeft(self.chains[(tree.root,l1)]) 189 | 190 | #add a reversed entry for each chain, and set the set for each one 191 | for pair in list(self.chains.keys()): 192 | self.chains[pair].setSet() 193 | self.chains[pair[::-1]] = self.chains[pair] 194 | 195 | #add weight 1 for root (do this now only because we don't want to include the root in the leavesRetained set) 196 | self.leafWeights[tree.root] = 1 197 | 198 | #def get_topology_ID(self, leaves, unrooted=False): 199 | #if leaves is None: leaves = sorted(self.leavesRetained) 200 | #if not unrooted: leaves = list(leaves) + [self.root] 201 | #n = len(leaves) 202 | #return tuple([self.chains[(leaves[pairs[0][0]], 203 | #leaves[pairs[0][1]],)]._set_.isdisjoint(self.chains[(leaves[pairs[1][0]], 204 | #leaves[pairs[1][1]],)]._set_) 205 | #for pairs in pairPairs_generic[n]]) 206 | 207 | def get_quartet_ID(self, quartet): 208 | if 
self.chains[(quartet[0],quartet[1],)]._set_.isdisjoint(self.chains[(quartet[2],quartet[3],)]._set_): 209 | return 0 210 | elif self.chains[(quartet[0],quartet[2],)]._set_.isdisjoint(self.chains[(quartet[1],quartet[3],)]._set_): 211 | return 1 212 | elif self.chains[(quartet[0],quartet[3],)]._set_.isdisjoint(self.chains[(quartet[1],quartet[2],)]._set_): 213 | return 2 214 | return 3 215 | 216 | #def get_all_quartet_IDs(self, leaves=None, unrooted=False): 217 | #if leaves is None: leaves = sorted(self.leavesRetained) 218 | #if not unrooted: leaves = list(leaves) + [self.root] 219 | #return [self.get_quartet_ID(quartet) for quartet in itertools.combinations(leaves, 4)] 220 | 221 | def get_all_quartet_IDs(self, leaves=None, unrooted=False): 222 | if leaves is None: leaves = sorted(self.leavesRetained) 223 | if unrooted: 224 | return [self.get_quartet_ID(quartet) for quartet in itertools.combinations(leaves, 4)] 225 | else: 226 | return [self.get_quartet_ID(trio + (self.root,)) for trio in itertools.combinations(leaves, 3)] 227 | 228 | def get_topology_ID(self, leaves=None, unrooted=False): 229 | return tuple(self.get_all_quartet_IDs(leaves, unrooted)) 230 | 231 | def get_topology_counts(self, leaf_groups, max_subtrees, unrooted=False): 232 | _leaf_groups = [[leaf for leaf in group if leaf in self.leavesRetained] for group in leaf_groups] 233 | 234 | if self.simplifyDict is not None: 235 | total = np.prod([len(g) for g in _leaf_groups]) 236 | assert total <= max_subtrees, f"With groups {_leaf_groups}, there will be {total} subtrees, but you have requested only {max_subtrees}, you you need to turn off tree simplification or increase max_subtrees." 
237 | 238 | leaf_combos = get_leaf_combos(_leaf_groups, max_subtrees) 239 | counts = defaultdict(int) 240 | for combo in leaf_combos: 241 | comboWeight = np.prod([self.leafWeights[leaf] for leaf in combo]) 242 | ID = self.get_topology_ID(combo, unrooted) 243 | counts[ID] += comboWeight 244 | 245 | return counts 246 | 247 | def get_quartet_dist(tree1, tree2, unrooted=False, approximation_subset_size=None): 248 | topoSummary1 = TopologySummary(tree1) 249 | topoSummary2 = TopologySummary(tree2) 250 | leaves = list(topoSummary1.leavesRetained) 251 | 252 | if approximation_subset_size is None: 253 | quartetIDs1 = np.array(topoSummary1.get_all_quartet_IDs(unrooted=unrooted)) 254 | quartetIDs2 = np.array(topoSummary2.get_all_quartet_IDs(unrooted=unrooted)) 255 | else: 256 | #do approximate distance with random sets of quartets 257 | quartetIDs1 = np.zeros(approximation_subset_size, dtype=int) 258 | quartetIDs2 = np.zeros(approximation_subset_size, dtype=int) 259 | 260 | for i in range(approximation_subset_size): 261 | if unrooted: 262 | quartet = random.sample(leaves, 4) 263 | quartetIDs1[i] = topoSummary1.get_quartet_ID(quartet) 264 | quartetIDs2[i] = topoSummary2.get_quartet_ID(quartet) 265 | 266 | else: 267 | trio = random.sample(leaves, 3) 268 | quartetIDs1[i] = topoSummary1.get_quartet_ID(trio + [tree1.root]) 269 | quartetIDs2[i] = topoSummary2.get_quartet_ID(trio + [tree2.root]) 270 | 271 | dif = quartetIDs1 - quartetIDs2 272 | return np.mean(dif != 0) 273 | 274 | 275 | def get_min_quartet_dist(tree1, tree2, inds, max_itr=10, unrooted=False): 276 | topoSummary1 = TopologySummary(tree1) 277 | topoSummary2 = TopologySummary(tree2) 278 | 279 | #quartet IDs for tree 1 are unchanging 280 | quartetIDs1 = np.array(topoSummary1.get_all_quartet_IDs(unrooted=unrooted)) 281 | 282 | #for tree 2 we try with different permutations for each individual 283 | new_inds = inds[:] 284 | 285 | for itr in range(max_itr): 286 | 287 | for i in range(len(inds)): 288 | 289 | ind_orderings = 
list(itertools.permutations(inds[i])) 290 | 291 | dists = [] 292 | 293 | for ind_ordering in ind_orderings: 294 | current_inds = new_inds[:] 295 | current_inds[i] = ind_ordering 296 | quartetIDs2 = np.array(topoSummary2.get_all_quartet_IDs(leaves=[i for ind in current_inds for i in ind], unrooted=unrooted)) 297 | dif = quartetIDs1 - quartetIDs2 298 | dists.append(np.mean(dif != 0)) 299 | 300 | new_inds[i] = ind_orderings[np.argmin(dists)] 301 | 302 | if itr > 0 and new_inds == previous_new_inds: break 303 | else: 304 | previous_new_inds = new_inds[:] 305 | 306 | #get final dist 307 | quartetIDs2 = np.array(topoSummary2.get_all_quartet_IDs(leaves=[i for ind in new_inds for i in ind], unrooted=unrooted)) 308 | dif = quartetIDs1 - quartetIDs2 309 | 310 | return np.mean(dif != 0) 311 | 312 | -------------------------------------------------------------------------------- /plot_twisst/plot_twisst.R: -------------------------------------------------------------------------------- 1 | simple.loess.predict <- function(x, y, span, new_x=NULL, weights = NULL, max = NULL, min = NULL, family=NULL){ 2 | y.loess <- loess(y ~ x, span = span, weights = weights, family=family) 3 | if (is.null(new_x)) {y.predict <- predict(y.loess,x)} 4 | else {y.predict <- predict(y.loess,new_x)} 5 | if (is.null(min) == FALSE) {y.predict = ifelse(y.predict > min, y.predict, min)} 6 | if (is.null(max) == FALSE) {y.predict = ifelse(y.predict < max, y.predict, max)} 7 | y.predict 8 | } 9 | 10 | smooth.df <- function(x, df, span, new_x = NULL, col.names=NULL, weights=NULL, min=NULL, max=NULL, family=NULL){ 11 | if (is.null(new_x)) {smoothed <- df} 12 | else smoothed = df[1:length(new_x),] 13 | if (is.null(col.names)){col.names=colnames(df)} 14 | for (col.name in col.names){ 15 | print(paste("smoothing",col.name)) 16 | smoothed[,col.name] <- simple.loess.predict(x,df[,col.name],span = span, new_x = new_x, max = max, min = min, weights = weights, family=family) 17 | } 18 | smoothed 19 | } 20 | 21 | 
#loess helpers used by smooth.weights
simple.loess.predict <- function(x, y, span, new_x=NULL, weights = NULL, max = NULL, min = NULL, family=NULL){
    y.loess <- loess(y ~ x, span = span, weights = weights, family=family)
    if (is.null(new_x)) {y.predict <- predict(y.loess,x)}
    else {y.predict <- predict(y.loess,new_x)}
    #clamp predictions to the requested bounds
    if (is.null(min) == FALSE) {y.predict = ifelse(y.predict > min, y.predict, min)}
    if (is.null(max) == FALSE) {y.predict = ifelse(y.predict < max, y.predict, max)}
    y.predict
    }

#smooth each requested column of a data frame with loess
smooth.df <- function(x, df, span, new_x = NULL, col.names=NULL, weights=NULL, min=NULL, max=NULL, family=NULL){
    if (is.null(new_x)) {smoothed <- df}
    else smoothed = df[1:length(new_x),]
    if (is.null(col.names)){col.names=colnames(df)}
    for (col.name in col.names){
        print(paste("smoothing",col.name))
        smoothed[,col.name] <- simple.loess.predict(x,df[,col.name],span = span, new_x = new_x, max = max, min = min, weights = weights, family=family)
        }
    smoothed
    }

#smooth a weights data frame, then rescale each row to sum to 1
smooth.weights <- function(interval_positions, weights_dataframe, span, new_positions=NULL, interval_lengths=NULL){
    weights_smooth <- smooth.df(x=interval_positions,df=weights_dataframe, weights = interval_lengths,
                                span=span, new_x=new_positions, min=0, max=1)

    #return rescaled to sum to 1
    weights_smooth <- weights_smooth / apply(weights_smooth, 1, sum)

    weights_smooth[is.na(weights_smooth)] <- 0

    weights_smooth
    }

#cumulative upper/lower bounds per row, for stacked polygon plotting
stack <- function(mat){
    upper <- t(apply(mat, 1, cumsum))
    lower <- upper - mat
    list(upper=upper,lower=lower)
    }

#interleave two equal-length vectors: x1[1], x2[1], x1[2], x2[2], ...
interleave <- function(x1,x2){
    output <- vector(length= length(x1) + length(x2))
    output[seq(1,length(output),2)] <- x1
    output[seq(2,length(output),2)] <- x2
    output
    }

#sum groups of columns into a new data frame, one column per group
sum_df_columns <- function(df, columns_list){
    new_df <- df[,0]
    for (x in 1:length(columns_list)){
        if (length(columns_list[[x]]) > 1) new_df[,x] <- apply(df[,columns_list[[x]]], 1, sum, na.rm=T)
        else new_df[,x] <- df[,columns_list[[x]]]
        if (is.null(names(columns_list)[x]) == FALSE) names(new_df)[x] <- names(columns_list)[x]
        }
    new_df
    }

#plot topology weights, either as overlapping polygons or stacked
plot.weights <- function(weights_dataframe,positions=NULL,line_cols=NULL,fill_cols=NULL,density=NULL,lwd=1,xlim=NULL,ylim=c(0,1),stacked=FALSE,
                         ylab="Weighting", xlab = "Position", main="",xaxt=NULL,yaxt=NULL,bty="n", add=FALSE){
    #get x axis
    x = positions
    #if a two-column matrix is given - plot step-like weights with start and end of each interval
    if (dim(as.matrix(x))[2]==2) {
        x = interleave(positions[,1],positions[,2])
        yreps=2
        }
    else {
        if (is.null(x)==FALSE) x = positions
        else x = 1:nrow(weights_dataframe)
        yreps=1
        }

    #set x limits
    if(is.null(xlim)) xlim = c(min(x), max(x))

    #if not adding to an old plot, make a new plot
    if (add==FALSE) plot(0, pch = "", xlim = xlim, ylim=ylim, ylab=ylab, xlab=xlab, main=main,xaxt=xaxt,yaxt=yaxt,bty=bty)

    if (stacked == TRUE){
        y_stacked <- stack(weights_dataframe)
        for (n in 1:ncol(weights_dataframe)){
            y_upper = rep(y_stacked[["upper"]][,n],each=yreps)
            y_lower = rep(y_stacked[["lower"]][,n],each = yreps)
            polygon(c(x,rev(x)),c(y_upper, rev(y_lower)), col = fill_cols[n], density=density[n], border=NA)
            }
        }
    else{
        for (n in 1:ncol(weights_dataframe)){
            y = rep(weights_dataframe[,n],each=yreps)
            polygon(c(x,rev(x)),c(y, rep(0,length(y))), col=fill_cols[n], border=NA,density=density[n])
            lines(x,y, type = "l", col = line_cols[n],lwd=lwd)
            }
        }
    }

options(scipen = 7)

#Heres a set of 15 colourful colours from https://en.wikipedia.org/wiki/Help:Distinguishable_colors
topo_cols <- c(
    "#0075DC", #Blue
    "#2BCE48", #Green
    "#FFA405", #Orpiment
    "#5EF1F2", #Sky
    "#FF5005", #Zinnia
    "#005C31", #Forest
    "#00998F", #Turquoise
    "#FF0010", #Red
    "#9DCC00", #Lime
    "#003380", #Navy
    "#F0A3FF", #Amethyst
    "#740AFF", #Violet
    "#426600", #Quagmire
    "#C20088", #Mallow
    "#94FFB5") #Jade


########### Below are some more object-oriented tools for working with standard twisst output files

library(ape)
library(tools)

#a function that imports topocounts and computes weights
import.twisst <- function(topocounts_files, intervals_files=NULL, split_by_chrom=TRUE, reorder_by_start=FALSE, na.rm=TRUE, max_interval=Inf,
                          lengths=NULL, topos_file=NULL, ignore_extra_columns=FALSE, min_subtrees=1, recalculate_mid=FALSE, names=NULL){
    l = list()

    if (length(intervals_files) > 1){
        print("Reading topocounts and interval data")
        l$interval_data <- lapply(intervals_files, read.table ,header=TRUE)
        l$topocounts <- lapply(topocounts_files, read.table, header=TRUE)
        if (is.null(names) == FALSE) names(l$interval_data) <- names(l$topocounts) <- names
        }

    if (length(intervals_files) == 1){
        print("Reading topocounts and interval data")
        l$interval_data <- list(read.table(intervals_files, header=TRUE))
        l$topocounts <- list(read.table(topocounts_files, header=TRUE))
        if (split_by_chrom == TRUE){
            #first interval column is the chromosome
            l$topocounts <- split(l$topocounts[[1]], l$interval_data[[1]][,1])
            l$interval_data <- split(l$interval_data[[1]], l$interval_data[[1]][,1])
            }
        }

    if (is.null(intervals_files) == TRUE) {
        print("Reading topocounts")
        l$topocounts <- lapply(topocounts_files, read.table, header=TRUE)
        n <- nrow(l$topocounts[[1]])
        #no intervals provided: use dummy single-position intervals
        l$interval_data <- list(data.frame(chrom=rep(0,n), start=1:n, end=1:n))
        if (is.null(names) == FALSE) names(l$interval_data) <- names
        }

    l$n_regions <- length(l$topocounts)

    if (is.null(names(l$interval_data)) == TRUE) {
        names(l$interval_data) <- names(l$topocounts) <- paste0("region", 1:l$n_regions)
        }

    print(paste("Number of regions:", l$n_regions))

    if (reorder_by_start==TRUE & is.null(intervals_files) == FALSE){
        print("Reordering")
        orders = sapply(l$interval_data, function(df) order(df[,2]), simplify=FALSE)
        l$interval_data <- sapply(names(orders), function(x) l$interval_data[[x]][orders[[x]],], simplify=F)
        l$topocounts <- sapply(names(orders), function(x) l$topocounts[[x]][orders[[x]],], simplify=F)
        }

    print("Getting topologies")

    #attempt to retrieve topologies
    l$topos=NULL
    #first, check if a topologies file is provided
    if (is.null(topos_file) == FALSE) {
        l$topos <- read.tree(file=topos_file)
        if (is.null(names(l$topos)) == TRUE) names(l$topos) <- names(l$topocounts[[1]])
        }
    else{
        #otherwise we try to retrieve topologies from the (first) topocounts file header comments
        n_topos = ncol(l$topocounts[[1]]) - 1
        topos_text <- read.table(topocounts_files[1], nrow=n_topos, comment.char="", sep="\t", as.is=T)[,1]
        try(l$topos <- read.tree(text = topos_text))
        try(names(l$topos) <- sapply(names(l$topos), substring, 2))
        }

    print("Cleaning data")

    if (ignore_extra_columns == TRUE & is.null(l$topos)==FALSE) {
        for (i in 1:l$n_regions){
            #columns that are unwanted
            l$topocounts[[i]] <- l$topocounts[[i]][,1:length(l$topos)]
            }
        }

    if (na.rm==TRUE){
        for (i in 1:l$n_regions){
            #remove rows containing NA values, too few subtrees, or over-long intervals
            row_sums = apply(l$topocounts[[i]],1,sum)
            good_rows = which(is.na(row_sums) == F & row_sums >= min_subtrees &
                              l$interval_data[[i]]$end - l$interval_data[[i]]$start + 1 <= max_interval)
            l$topocounts[[i]] <- l$topocounts[[i]][good_rows,]
            l$interval_data[[i]] = l$interval_data[[i]][good_rows,]
            }
        }

    print("Computing summaries")

    l$weights <- sapply(l$topocounts, function(raw) raw/apply(raw, 1, sum), simplify=FALSE)

    l$weights_mean <- t(sapply(l$weights, apply, 2, mean, na.rm=T))

    #weighting per region as a total. This will be used for getting the overall mean
    weights_totals <- apply(t(sapply(l$weights, apply, 2, sum, na.rm=T)), 2, sum)

    l$weights_overall_mean <- weights_totals / sum(weights_totals)

    if (is.null(lengths) == TRUE) l$lengths <- sapply(l$interval_data, function(df) tail(df$end,1), simplify=TRUE)
    else l$lengths = lengths

    for (i in 1:length(l$interval_data)) {
        if (is.null(l$interval_data[[i]]$mid) == TRUE | recalculate_mid == TRUE) {
            l$interval_data[[i]]$mid <- (l$interval_data[[i]]$start + l$interval_data[[i]]$end)/2
            }
        }

    l$pos=sapply(l$interval_data, function(df) df$mid, simplify=FALSE)

    l
    }


#loess-smooth all regions of a twisst object
smooth.twisst <- function(twisst_object, span=0.05, span_bp=NULL, new_positions = NULL, spacing=NULL) {
    l=list()

    l$topos <- twisst_object$topos

    l$n_regions <- twisst_object$n_regions

    l$weights <- list()

    l$lengths = twisst_object$lengths

    l$pos <- list()

    for (i in 1:l$n_regions){
        #fixed: was twisst_object$length (relied on $ partial matching)
        if (is.null(span_bp) == FALSE) region_span <- span_bp/twisst_object$lengths[[i]]
        else region_span <- span

        #fixed: positions and spacing are now computed per region — previously
        #values computed for the first region leaked into all later regions
        region_new_positions <- new_positions
        if (is.null(region_new_positions) == TRUE){
            region_spacing <- spacing
            if (is.null(region_spacing) == TRUE) region_spacing <- twisst_object$lengths[[i]]*region_span*.1
            region_new_positions <- seq(twisst_object$pos[[i]][1], tail(twisst_object$pos[[i]],1), region_spacing)
            }

        l$pos[[i]] <- region_new_positions

        interval_lengths <- twisst_object$interval_data[[i]][,3] - twisst_object$interval_data[[i]][,2] + 1

        l$weights[[i]] <- smooth.weights(twisst_object$pos[[i]], twisst_object$weights[[i]], new_positions = region_new_positions,
                                         span = region_span, interval_lengths = interval_lengths)
        }

    names(l$weights) <- names(twisst_object$weights)

    l
    }

#check whether a string is a valid "#RRGGBB" or "#RRGGBBAA" hex colour
is.hex.col <- function(string){
    strvec <- strsplit(string, "")[[1]]
    if (strvec[1] != "#") return(FALSE)
    if (length(strvec) != 7 & length(strvec) != 9) return(FALSE)
    for (character in strvec[-1]){
        if (!(character %in% c("0","1","2","3","4","5","6","7","8","9","a","b","c","d","e","f","A","B","C","D","E","F"))) return(FALSE)
        }
    TRUE
    }

#set (or replace) the alpha channel of a hex colour
hex.transparency <- function(hex, transstring="88"){
    if (is.hex.col(hex)==FALSE){
        print("WARNING: colour not hexadecimal. Cannot modify transparency.")
        return(hex)
        }
    if (nchar(hex) == 7) return(paste0(hex, transstring))
    else {
        substr(hex,8,9) <- transstring
        return(hex)
        }
    }


#plot topologies and their weights along the genome, in one of 5 display modes
plot.twisst <- function(twisst_object, show_topos=TRUE, rel_height=3, ncol_topos=NULL, regions=NULL, ncol_weights=1,
                        cols=topo_cols, tree_type="clad", tree_x_lim=c(0,5), xlim=NULL, ylim=NULL, xlab=NULL, xaxt=NULL,
                        mode=2, margins = c(4,4,2,2), concatenate=FALSE, gap=0, include_region_names=FALSE){

    #check if there are enough colours
    if (length(twisst_object$topos) > length(cols)){
        print("Not enough colours provided (option 'cols'), using rainbow instead")
        cols = rainbow(length(twisst_object$topos))
        }

    if (mode==5) {
        stacked=TRUE
        weights_order=1:length(twisst_object$topos)
        fill_cols = cols
        line_cols = NA
        lwd = 0
        }

    if (mode==4) {
        stacked=FALSE
        #fixed: was twisst_data$weights_mean (an unrelated global) — use the argument
        weights_order=order(apply(twisst_object$weights_mean, 2, mean, na.rm=T)[1:length(twisst_object$topos)], decreasing=T) #so that the largest values are at the back
        fill_cols = cols
        line_cols = NA
        lwd=par("lwd")
        }

    if (mode==3) {
        stacked=FALSE
        weights_order=order(apply(twisst_object$weights_mean, 2, mean, na.rm=T)[1:length(twisst_object$topos)], decreasing=T) #so that the largest values are at the back
        fill_cols = sapply(cols, hex.transparency, transstring="80")
        line_cols = cols
        lwd=par("lwd")
        }

    if (mode==2) {
        stacked=FALSE
        weights_order=order(apply(twisst_object$weights_mean, 2, mean, na.rm=T)[1:length(twisst_object$topos)], decreasing=T) #so that the largest values are at the back
        fill_cols = sapply(cols, hex.transparency, transstring="80")
        line_cols = NA
        lwd=par("lwd")
        }

    if (mode==1) {
        stacked=FALSE
        weights_order=1:length(twisst_object$topos)
        fill_cols = NA
        line_cols = cols
        lwd=par("lwd")
        }

    if (is.null(regions)==TRUE) regions <- 1:twisst_object$n_regions

    if (concatenate == TRUE) ncol_weights <- 1

    if (show_topos==TRUE){
        n_topos <- length(twisst_object$topos)

        if (is.null(ncol_topos)) ncol_topos <- n_topos

        #if we have too few topologies to fill the spaces in the plot, we can pad in the remainder
        topos_pad <- (n_topos * ncol_weights) %% (ncol_topos*ncol_weights)

        topos_layout_matrix <- matrix(c(rep(1:n_topos, each=ncol_weights), rep(0, topos_pad)),
                                      ncol=ncol_topos*ncol_weights, byrow=T)
        }
    else {
        ncol_topos <- 1
        n_topos <- 0
        topos_layout_matrix <- matrix(NA, nrow= 0, ncol=ncol_topos*ncol_weights)
        }

    #if we have too few regions to fill the spaces in the plot, we pad in the remainder
    data_pad <- (length(regions)*ncol_topos) %% (ncol_topos*ncol_weights)

    if (concatenate==TRUE) weights_layout_matrix <- matrix(rep(n_topos+1,ncol_topos), nrow=1)
    else {
        weights_layout_matrix <- matrix(c(rep(n_topos+(1:length(regions)), each=ncol_topos),rep(0,data_pad)),
                                        ncol=ncol_topos*ncol_weights, byrow=T)
        }

    layout(rbind(topos_layout_matrix, weights_layout_matrix),
           height=c(rep(1, nrow(topos_layout_matrix)), rep(rel_height, nrow(weights_layout_matrix))))

    if (show_topos == TRUE){
        if (tree_type=="unrooted"){
            par(mar=c(1,2,3,2), xpd=NA)

            for (i in 1:n_topos){
                plot.phylo(twisst_object$topos[[i]], type = tree_type, edge.color=cols[i],
                           edge.width=5, label.offset=0.3, cex=1, rotate.tree = 90, x.lim=tree_x_lim)
                mtext(side=3,text=paste0("topo",i), cex=0.75)
                }
            }
        else{
            par(mar=c(1,1,1,1), xpd=NA)

            for (i in 1:n_topos){
                plot.phylo(twisst_object$topos[[i]], type = tree_type, edge.color=cols[i],
                           edge.width=5, label.offset=0.3, cex=1, rotate.tree = 0, x.lim=tree_x_lim)
                mtext(side=3,text=paste0("topo",i," "), cex=0.75)
                }
            }
        }

    if (is.null(ylim)==TRUE) ylim <- c(0,1)

    par(mar=margins, xpd=FALSE)

    if (concatenate == TRUE) {
        chrom_offsets = cumsum(twisst_object$lengths + gap) - (twisst_object$lengths + gap)
        chrom_ends <- chrom_offsets + twisst_object$lengths

        #fixed: removed a stray empty argument (", ,") from this call
        plot(0, pch = "",xlim = c(chrom_offsets[1],tail(chrom_ends,1)), ylim = ylim,
             ylab = "", yaxt = "n", xlab = xlab, xaxt = "n", bty = "n", main = "")

        for (j in regions) {
            if (is.null(twisst_object$interval_data[[j]])) positions <- twisst_object$pos[[j]] + chrom_offsets[j]
            else positions <- twisst_object$interval_data[[j]][,c("start","end")] + chrom_offsets[j]
            plot.weights(twisst_object$weights[[j]][weights_order], positions, xlim=xlim,
                         fill_cols = fill_cols[weights_order], line_cols=line_cols[weights_order],lwd=lwd,stacked=stacked, add=T)
            }
        }
    else{
        for (j in regions){
            if (is.null(twisst_object$interval_data[[j]])) positions <- twisst_object$pos[[j]]
            else positions <- twisst_object$interval_data[[j]][,c("start","end")]
            plot.weights(twisst_object$weights[[j]][weights_order], positions, xlim=xlim, ylim = ylim,
                         xlab=xlab, fill_cols = fill_cols[weights_order], line_cols=line_cols[weights_order],lwd=lwd,stacked=stacked, xaxt=xaxt)
            if (include_region_names==TRUE) mtext(3, text=names(twisst_object$weights)[j], adj=0, cex=0.75)
            }
        }
    }
uses ape to get node positions 425 | draw.tree <- function(phy, x, y, x_scale=1, y_scale=1, method=1, direction="right", 426 | col="black", col.label="black", add_labels=TRUE, add_symbols=FALSE, 427 | label_offset = 1, symbol_offset=0, col.symbol="black",symbol_bg="NA", 428 | pch=19, cex=NULL, lwd=NULL, label_alias=NULL){ 429 | 430 | n_tips = length(phy$tip.label) 431 | 432 | if (direction=="right") { 433 | node_x = (node.depth(phy, method=method) - 1) * x_scale * -1 434 | node_y = node.height(phy) * y_scale 435 | label_x = node_x[1:n_tips] + label_offset 436 | label_y = node_y[1:n_tips] 437 | adj_x = 0 438 | adj_y = .5 439 | symbol_x = node_x[1:n_tips] + symbol_offset 440 | symbol_y = node_y[1:n_tips] 441 | } 442 | if (direction=="down") { 443 | node_y = (node.depth(phy, method=method) - 1) * y_scale * 1 444 | node_x = node.height(phy) * x_scale 445 | label_x = node_x[1:n_tips] 446 | label_y = node_y[1:n_tips] - label_offset 447 | adj_x = .5 448 | adj_y = 1 449 | symbol_x = node_x[1:n_tips] 450 | symbol_y = node_y[1:n_tips] - symbol_offset 451 | } 452 | 453 | #draw edges 454 | segments(x + node_x[phy$edge[,1]], y + node_y[phy$edge[,1]], 455 | x + node_x[phy$edge[,2]], y + node_y[phy$edge[,2]], col=col, lwd=lwd) 456 | 457 | if (is.null(label_alias) == FALSE) tip_labels <- label_alias[phy$tip.label] 458 | else tip_labels <- phy$tip.label 459 | 460 | if (add_labels=="TRUE") text(x + label_x, y + label_y, col = col.label, labels=tip_labels, adj=c(adj_x,adj_y),cex=cex) 461 | if (add_symbols=="TRUE") points(x + symbol_x, y + symbol_y, pch = pch, col=col.symbol, bg=symbol_bg) 462 | 463 | } 464 | 465 | #code for plotting a summary barplot 466 | plot.twisst.summary <- function(twisst_object, order_by_weights=TRUE, only_best=NULL, cols=topo_cols, 467 | x_scale=0.12, y_scale=0.15, direction="right", col="black", col.label="black", 468 | label_offset = 0.05, lwd=NULL, cex=NULL){ 469 | 470 | #check if there are enough colours 471 | if (length(twisst_object$topos) > 
length(cols)){ 472 | print("Not enough colours provided (option 'cols'), using rainbow instead") 473 | cols = rainbow(length(twisst_object$topos)) 474 | } 475 | 476 | # Either order 1-15 or order with highest weigted topology first 477 | 478 | if (order_by_weights == TRUE) { 479 | ord <- order(twisst_object$weights_overall_mean[1:length(twisst_object$topos)], decreasing=T) 480 | if (is.null(only_best) == FALSE) ord=ord[1:only_best] 481 | } 482 | else ord <- 1:length(twisst_object$topos) 483 | 484 | N=length(ord) 485 | 486 | #set the plot layout, with the tree panel one third the height of the barplot panel 487 | layout(matrix(c(2,1)), heights=c(1,3)) 488 | 489 | par(mar = c(1,4,.5,1)) 490 | 491 | #make the barplot 492 | x=barplot(twisst_object$weights_overall_mean[ord], col = cols[ord], 493 | xaxt="n", las=1, ylab="Average weighting", space = 0.2, xlim = c(0.2, 1.2*N)) 494 | 495 | #draw the trees 496 | #first make an empty plot for the trees. Ensure left and right marhins are the same 497 | par(mar=c(0,4,0,1)) 498 | plot(0,cex=0,xlim = c(0.2, 1.2*N), xaxt="n",yaxt="n",xlab="",ylab="",ylim=c(0,1), bty="n") 499 | 500 | #now run the draw.tree function for each topology. You can set x_scale and y_scale to alter the tree width and height. 
501 | for (i in 1:length(ord)){ 502 | draw.tree(twisst_object$topos[[ord[i]]], x=x[i]+.2, y=0, x_scale=x_scale, y_scale=y_scale, 503 | col=cols[ord[i]], label_offset=label_offset, cex=cex, lwd=lwd) 504 | } 505 | 506 | #add labels for each topology 507 | text(x,.9,names(twisst_object$topos)[ord],col=cols[ord]) 508 | } 509 | 510 | 511 | #code for plotting a summary boxplot 512 | plot.twisst.summary.boxplot <- function(twisst_object, order_by_weights=TRUE, only_best=NULL, trees_below=FALSE, cols=topo_cols, 513 | x_scale=0.12, y_scale=0.15, direction="right", col="black", col.label="black", 514 | label_offset = 0.05, lwd=NULL, label_alias=NULL, cex=NULL, outline=FALSE, 515 | cex.outline=NULL, lwd.box=NULL, topo_names=NULL){ 516 | 517 | #check if there are enough colours 518 | if (length(twisst_object$topos) > length(cols)){ 519 | print("Not enough colours provided (option 'cols'), using rainbow instead") 520 | cols = rainbow(length(twisst_object$topos)) 521 | } 522 | 523 | # Either order 1-15 or order with highest weigted topology first 524 | 525 | if (order_by_weights == TRUE) { 526 | ord <- order(twisst_object$weights_overall_mean[1:length(twisst_object$topos)], decreasing=T) 527 | if (is.null(only_best) == FALSE) ord=ord[1:only_best] 528 | } 529 | else ord <- 1:length(twisst_object$topos) 530 | 531 | N=length(ord) 532 | 533 | #set the plot layout, with the tree panel one third the height of the barplot panel 534 | if (trees_below==FALSE) layout(matrix(c(2,1)), heights=c(1,3)) 535 | else layout(matrix(c(1,2)), heights=c(3,1)) 536 | 537 | 538 | par(mar = c(1,4,.5,1), bty="n") 539 | 540 | #make the barplot 541 | boxplot(as.data.frame(rbindlist(twisst_object$weights))[,ord], col = cols[ord], 542 | xaxt="n", las=1, xlim = c(.5, N+.5), ylab="Weighting", outline=outline, cex=cex.outline, lwd=lwd.box) 543 | 544 | #draw the trees 545 | #first make an empty plot for the trees. 
Ensure left and right marhins are the same 546 | par(mar=c(0,4,0,1)) 547 | plot(0,cex=0, xlim = c(.5, N+.5), xaxt="n",yaxt="n",xlab="",ylab="",ylim=c(0,1), bty="n") 548 | 549 | #now run the draw.tree function for each topology. You can set x_scale and y_scale to alter the tree width and height. 550 | for (i in 1:N){ 551 | draw.tree(twisst_object$topos[[ord[i]]], i+.1, y=0, x_scale=x_scale, y_scale=y_scale, 552 | col=cols[ord[i]], label_offset=label_offset, cex=cex, lwd=lwd, label_alias=label_alias) 553 | } 554 | 555 | if (is.null(topo_names)==TRUE) topo_names <- names(twisst_object$topos) 556 | 557 | #add labels for each topology 558 | text(1:N,.9,topo_names[ord],col=cols[ord], cex=cex) 559 | } 560 | 561 | #function for subsetting the twisst object by a set of topologies 562 | subset.twisst.by.topos <- function(twisst_object, topos){ 563 | l <- list() 564 | regions <- names(twisst_object$weights) 565 | l$interval_data <- twisst_object$interval_data 566 | l$n_regions <- twisst_object$n_regions 567 | l$lengths <- twisst_object$lengths 568 | l$pos <- twisst_object$pos 569 | l$topocounts <- sapply(regions, function(region) twisst_data$topocounts[[region]][,topos], simplify=F) 570 | l$weights <- sapply(regions, function(region) twisst_object$weights[[region]][,topos], simplify=F) 571 | l$weights_mean <- l$weights_mean[,topos] 572 | l$weights_overall_mean <- twisst_object$weights_overall_mean[topos] 573 | l$topos <- twisst_object$topos[topos] 574 | l 575 | } 576 | 577 | #function for subsetting the twisst object by specific regions 578 | subset.twisst.by.regions <- function(twisst_object, regions){ 579 | l <- list() 580 | regions <- names(twisst_object$weights[regions]) 581 | l$interval_data <- twisst_object$interval_data[regions] 582 | l$n_regions <- length(regions) 583 | l$lengths <- twisst_object$lengths[regions] 584 | l$pos <- twisst_object$pos[regions] 585 | l$topocounts <- twisst_object$topocounts[regions] 586 | l$weights <- twisst_object$weights[regions] 587 | 
#function for subsetting the twisst object by specific regions
subset.twisst.by.regions <- function(twisst_object, regions){
    l <- list()
    regions <- names(twisst_object$weights[regions])
    l$interval_data <- twisst_object$interval_data[regions]
    l$n_regions <- length(regions)
    l$lengths <- twisst_object$lengths[regions]
    l$pos <- twisst_object$pos[regions]
    l$topocounts <- twisst_object$topocounts[regions]
    l$weights <- twisst_object$weights[regions]
    l$weights_mean <- twisst_object$weights_mean[regions]
    #overall mean is recomputed from the retained regions only
    weights_totals <- apply(t(sapply(l$weights, apply, 2, sum, na.rm=T)), 2, sum)
    l$weights_overall_mean <- weights_totals / sum(weights_totals)
    l$topos <- twisst_object$topos
    l
    }

#bind a list of data frames by row.
#Improvement: the previous incremental rbind loop was O(n^2); do.call binds all at once.
rbindlist <- function(l) do.call(rbind, l)


palettes = list(
    sashamaps = c( #https://sashamaps.net/docs/resources/20-colors/
        '#4363d8', #azure blue
        '#3cb44b', #grass green
        '#ffe119', #custard yellow
        '#e6194b', #pomergranate red
        '#f58231', #chalk orange
        '#911eb4', #violet
        '#46f0f0', #cyan
        '#f032e6', #hot pink
        '#000075', #navy blue
        '#fabebe', #pale blush
        '#008080', #teal
        '#e6beff', #mauve
        '#9a6324', #coconut brown
        '#800000', #red leather
        '#aaffc3', #light sage
        '#808000', #olive
        '#ffd8b1', #white skin
        '#bcf60c', #green banana
        '#808080', #grey
        '#fffac8'), #beach sand

    alphabet = c( #https://en.wikipedia.org/wiki/Help:Distinguishable_colors
        "#0075DC", #Blue
        "#FFA405", #Orpiment
        "#2BCE48", #Green
        "#003380", #Navy
        "#F0A3FF", #Amethyst
        "#990000", #Wine
        "#9DCC00", #Lime
        "#8F7C00", #Khaki
        "#FFE100", #Yellow
        "#C20088", #Mallow
        "#FFCC99", #Honeydew
        "#5EF1F2", #Sky
        "#00998F", #Turquoise
        "#FF5005", #Zinnia
        "#4C005C", #Damson
        "#426600", #Quagmire
        "#005C31", #Forest
        "#FFA8BB", #Pink
        "#740AFF", #Violet
        "#E0FF66", #Uranium
        "#993F00", #Caramel
        "#94FFB5", #Jade
        "#FFFF80", #Xanthin
        "#FF0010", #Red
        "#191919", #Ebony
        "#808080"), #Iron

    krzywinski = c( #https://mk.bcgsc.ca/colorblind/palettes.mhtml#15-color-palette-for-colorbliness
        "#68023F", #imperial purple
        "#008169", #deep sea
        "#EF0096", #vivid cerise
        "#00DCB5", #aquamarine
        "#FFCFE2", #light pink
        "#003C86", #royal blue
        "#9400E6", #vivid purple
        "#009FFA", #azure
        "#FF71FD", #blush pink
        "#7CFFFA", #light brilliant cyan
        "#6A0213", #burnt crimson
        "#008607", #india green
        "#F60239", #american rose
        "#00E307", #vivid sap green
        "#FFDC3D"), #gargoyle gas

    safe = c( #from cartomap package
        "#88CCEE",
        "#CC6677",
        "#DDCC77",
        "#117733",
        "#332288",
        "#AA4499",
        "#44AA99",
        "#999933",
        "#882255",
        "#661100",
        "#6699CC",
        "#888888"),

    cud = c(
        "#E69F00",
        "#56B4E9",
        "#009E73",
        "#F0E442",
        "#0072B2",
        "#D55E00",
        "#CC79A7",
        "#000000",
        "#999999"))

#draw each palette as a row of labelled colour swatches
show_palettes <- function(){
    par(mfrow = c(length(palettes), 1), mar = c(2,0,2,0))

    for (name in names(palettes)){
        n = length(palettes[[name]])
        plot(0, cex=0, xlim = c(1,n+1), ylim = c(0,1), xaxt="n", yaxt="n", bty="n", main=name)
        rect(1:n, 0, 2:(n+1), 1, col=palettes[[name]])
        axis(1, at=(1:n)+0.5, labels = palettes[[name]], tick=F)
        }
    }

-------------------------------------------------------------------------------- /twisst2/twisst2.py: --------------------------------------------------------------------------------
#!/usr/bin/env python

import random
import argparse
import gzip
import sys
import itertools
import numpy as np
from twisst2.TopologySummary import *
from twisst2.parse_newick import *
from sticcs import sticcs
import cyvcf2

try: import tskit
except ImportError:
    pass
def get_combos(groups, max_subtrees, ploidies, random_seed=None):
    """Yield tuples of sample indices, one member drawn from each group.

    If the exhaustive set of combinations would contribute no more than
    max_subtrees subtrees (a sample's subtree contribution is the product of
    the chosen samples' ploidies), every combination is yielded in order.
    Otherwise combinations are drawn at random — reproducibly via
    random_seed — until the subtree budget is reached.
    """
    total_subtrees = np.prod([ploidies[group].sum() for group in groups])
    if total_subtrees <= max_subtrees:
        #small enough: yield the full cartesian product of the groups
        yield from itertools.product(*groups)
    else:
        #subsampling: keep a running tally of subtrees contributed so far
        rng = random.Random(random_seed)  #seeded generator for repeatability
        subtrees_done = 0
        while True:
            combo = tuple(rng.choice(group) for group in groups)
            yield combo
            subtrees_done += np.prod(ploidies[list(combo)])
            if subtrees_done >= max_subtrees:
                break

#NOTE: an older variant of get_combos that capped the number of combinations
#(rather than subtrees) was removed here; the number of subtrees per combo
#depends on ploidy, so capping subtrees gives more predictable totals.

#NOTE: retired helpers for stacking topocounts with non-identical interval
#sets (contains_interval, get_unique_intervals) were removed here; the
#sweep-line implementation of stack_topocounts below replaced them.


def show_tskit_tree(tree, node_labels=None):
    """Print a unicode drawing of a tskit tree to stdout."""
    print(tree.draw(format="unicode", node_labels=node_labels))
def is_bifurcating(tree):
    """Return True if no node in the tree has more than two children."""
    return all(len(tree.children(node)) <= 2 for node in tree.nodes())

def number_of_possible_trees(n_tips, include_ploytomies=False):
    """Count rooted tree topologies with n_tips tips.

    Polytomous topologies are counted only if include_ploytomies is True.
    (Parameter name kept as-is for backward compatibility with keyword callers.)
    """
    return sum(1 for tree in tskit.all_trees(n_tips)
               if include_ploytomies or is_bifurcating(tree))

def list_all_topologies_tskit(n, include_polytomies = False):
    """Return all rooted topologies with n tips as tskit trees (n <= 7)."""
    assert n <= 7, "There are 135135 rooted topologies with 8 tips (660032 including polytomies). I doubt you want to list all of those."
    if include_polytomies: return list(tskit.all_trees(n))
    return [tree for tree in tskit.all_trees(n) if is_bifurcating(tree)]

def all_possible_patterns(n, lower=2, upper=None):
    """Yield every length-n binary presence pattern with between lower and upper ones.

    upper defaults to n. Patterns are yielded as int arrays, ordered by
    increasing number of ones and then by combination order.
    """
    if upper: assert upper <= n
    else: upper = n
    for k in range(lower, upper + 1):
        #no need to materialise the combinations; iterate them lazily
        for indices in itertools.combinations(range(n), k):
            pattern = np.zeros(n, dtype=int)
            pattern[list(indices)] = 1
            yield pattern


def contains_array(arrayA, arrayB):
    """Return True if every row of arrayB occurs among the rows of arrayA.

    Both arrays must have the same number of columns.
    """
    assert arrayA.shape[1] == arrayB.shape[1]
    if arrayA.shape[0] < arrayB.shape[0]: return False
    #short-circuits on the first row of B that is missing from A
    return all(any(np.array_equal(a, b) for a in arrayA) for b in arrayB)
def all_possible_clusters(n, lower=2, upper=None):
    """Enumerate maximal sets of mutually compatible binary patterns over n tips.

    Compatibility is tested pairwise with sticcs.passes_derived_gamete_test
    (assuming haploid tips). Grows cluster size until no compatible set of
    that size exists; clusters contained within a larger compatible cluster
    are dropped. Each returned cluster corresponds to a tree topology.
    """
    if upper: assert upper <= n
    else: upper = n-1
    ploidies = np.array([1]*n)
    clusters = []
    size = 0
    while True:
        size += 1
        found_one = False
        for new_cluster in itertools.combinations(all_possible_patterns(n, lower, upper), size):
            fail = False
            new_cluster = np.array(new_cluster)
            #every pair of patterns in the candidate cluster must be compatible
            for i in range(1, size):
                for j in range(i):
                    if not sticcs.passes_derived_gamete_test(new_cluster[i], new_cluster[j], ploidies):
                        fail = True
                        break
                if fail:
                    break

            if not fail:
                found_one = True
                #discard any old cluster that is contained within the new one
                clusters = [cluster for cluster in clusters if not contains_array(new_cluster, cluster)] + [new_cluster]

        if not found_one:
            return clusters

def list_all_topologies(n):
    """Return all rooted topologies with n tips as sticcs trees (n <= 7)."""
    assert n <= 7, "There are 135135 rooted topologies with 8 tips (660032 including polytomies). I doubt you want to list all of those."
    clusters = all_possible_clusters(n)
    return [sticcs.patterns_to_tree(cluster) for cluster in clusters]

def make_topoDict(n, unrooted=False):
    """Build {"topos": [...], "topoIDs": [...]} for all topologies with n tips.

    When unrooted=True, multiple rooted topologies share an ID, so only the
    first representative of each unrooted topology is kept.
    """
    topos = []
    topoIDs = []
    for topo in list_all_topologies(n):
        topoSummary = TopologySummary(topo)
        ID = topoSummary.get_topology_ID(list(range(n)), unrooted=unrooted)
        if ID not in topoIDs: #multiple rooted topos are identical when unrooted, so we keep just the first
            topoIDs.append(ID)
            topos.append(topo)

    return {"topos": topos, "topoIDs": topoIDs}

def make_topoDict_tskit(n, include_polytomies=False):
    """Build {"topos": [...], "ranks": [...]} for tskit topologies with n tips."""
    topos = list_all_topologies_tskit(n, include_polytomies=include_polytomies)
    ranks = [t.rank() for t in topos]
    return {"topos": topos, "ranks": ranks}


class Topocounts:
    """Per-interval counts of subtree topologies along a chromosome.

    Attributes:
        topos: list of topology objects (column order of counts).
        counts: 2D array, one row per interval, one column per topology.
        totals: per-interval total number of subtrees considered (defaults to
            the row sums of counts).
        intervals: optional 2D array of inclusive [start, end] positions.
        label_dict: optional mapping from tip index to group name.
        rooted: whether the topologies are rooted (None if unspecified).
    """

    def __init__(self, topos, counts, totals=None, intervals=None, label_dict=None, rooted=None):
        assert counts.shape[1] == len(topos)
        if intervals is not None:
            assert intervals.shape[1] == 2
            assert counts.shape[0] == intervals.shape[0], f"There are {counts.shape[0]} rows of counts but {intervals.shape[0]} intervals."
        self.topos = topos
        self.counts = counts
        if totals is not None:
            #totals may exceed the counted topologies (e.g. 'other' topologies), but never be less
            assert np.all(totals >= self.counts.sum(axis=1).round(3))
            self.totals = totals
        else: self.totals = counts.sum(axis=1)
        self.intervals = intervals
        self.label_dict = label_dict
        self.rooted = rooted

    def unroot(self):
        """Return a new Topocounts reduced to the unrooted version of each topology.

        Counts of rooted topologies that share an unrooted form are summed.
        """
        topoSummaries = [TopologySummary(topo) for topo in self.topos]

        topoIDs_unrooted = [topoSummary.get_topology_ID(unrooted=True) for topoSummary in topoSummaries]

        #group column indices by their unrooted topology ID
        indices = defaultdict(list)
        for i, ID in enumerate(topoIDs_unrooted): indices[ID].append(i)

        counts_unrooted = np.column_stack([self.counts[:, idx].sum(axis=1) for idx in indices.values()])

        new_topos = [self.topos[idx[0]] for idx in indices.values()]

        return Topocounts(new_topos, counts_unrooted, totals=self.totals, intervals=self.intervals, label_dict=self.label_dict, rooted=False)

    def fill_gaps(self, seq_start=None, seq_end=None):
        """Close gaps between adjacent intervals in place.

        Each internal boundary is moved to the midpoint of the gap; the first
        and last intervals are optionally extended to seq_start/seq_end.
        """
        if seq_start is not None:
            assert seq_start <= self.intervals[0, 0], "First interval starts before the stated start"
            self.intervals[0, 0] = seq_start

        if seq_end is not None:
            assert seq_end >= self.intervals[-1, 1], "Last interval ends after the stated end"
            self.intervals[-1, 1] = seq_end

        for i in range(1, self.intervals.shape[0]):
            new_start = np.ceil((self.intervals[i-1, 1] + self.intervals[i, 0]) / 2)
            self.intervals[i, 0] = new_start
            self.intervals[i-1, 1] = new_start - 1

    def simplify(self, fill_gaps=True, seq_start=None, seq_end=None):
        """Return a new Topocounts with adjacent intervals of identical counts merged.

        Optionally fills gaps between the merged intervals (see fill_gaps).
        """
        #bug fix: rows must be copied — numpy row slices are views, so the
        #in-place end-extension below used to mutate self.intervals as well
        new_counts_list = [self.counts[0]]
        new_intervals_list = [self.intervals[0].copy()]
        new_totals_list = [self.totals[0]]

        for j in range(1, self.intervals.shape[0]):

            if np.all(new_counts_list[-1] == self.counts[j]):
                #identical counts: extend the previous interval to cover this one
                new_intervals_list[-1][1] = self.intervals[j][-1]
            else:
                new_intervals_list.append(self.intervals[j].copy())
                new_counts_list.append(self.counts[j])
                new_totals_list.append(self.totals[j])

        simplified = Topocounts(self.topos, np.array(new_counts_list, dtype=int),
                                totals=np.array(new_totals_list, dtype=int),
                                intervals=np.array(new_intervals_list), label_dict=self.label_dict, rooted=self.rooted)

        if fill_gaps:
            simplified.fill_gaps(seq_start=seq_start, seq_end=seq_end)

        return simplified

    def split_intervals(self, split_len=None, new_intervals=None):
        """Return a new Topocounts split at the boundaries of a narrower interval set.

        If new_intervals is given, splits occur at those points (plus any
        existing splits between them); otherwise fixed-length windows of
        split_len are used.
        """
        if split_len is not None:
            new_intervals = np.array([(x, x + split_len - 1) for x in range(1, int(self.intervals[-1][1]), split_len)])
        else: assert new_intervals is not None, "Either provide a split length or new intervals to define the split locations"

        ncol = self.counts.shape[1]
        nrow = new_intervals.shape[0]

        #stacking against a zero-count dummy performs the split
        dummy_topocounts = Topocounts(topos=self.topos, counts=np.zeros(shape=(nrow, ncol), dtype=int), intervals=new_intervals)

        return stack_topocounts([self, dummy_topocounts], silent=True)

    def recast(self, interval_len=None, new_intervals=None):
        """Project counts onto a new set of intervals as length-weighted averages.

        Only possible when all intervals share the same total (i.e. no
        subsampling of combinations). To merely split into smaller intervals
        use split_intervals. The resulting counts are floats because they
        represent weighted averages.
        """
        assert len(set(self.totals)) == 1, "Cannot merge intervals with different total number of combinations."

        if interval_len is not None:
            new_intervals = np.array([(x, x + interval_len - 1) for x in range(1, int(self.intervals[-1][1]), interval_len)])
        else: assert new_intervals is not None, "Either provide a split length or new intervals to define the split locations"

        n_new_intervals = len(new_intervals)

        #first split at all the new boundaries
        topocounts_split = self.split_intervals(new_intervals=new_intervals)

        _totals_ = np.array([self.totals[0]] * n_new_intervals)

        #now merge the split pieces into each new interval
        _counts_ = np.zeros((n_new_intervals, self.counts.shape[1]), dtype=float)

        k = 0
        max_k = n_new_intervals
        current_indices = []
        for j in range(topocounts_split.intervals.shape[0]):
            #collect split intervals nested within the current new interval
            if new_intervals[k, 0] <= topocounts_split.intervals[j, 0] and topocounts_split.intervals[j, 1] <= new_intervals[k, 1]:
                current_indices.append(j)

                if topocounts_split.intervals[j, 1] == new_intervals[k, 1]:
                    #end of the new interval: record the length-weighted mean
                    interval_lengths = np.diff(topocounts_split.intervals[current_indices], axis=1)[:, 0] + 1 # weights are the lengths
                    _counts_[k] = np.average(topocounts_split.counts[current_indices], axis=0, weights=interval_lengths)
                    current_indices = []
                    k += 1
                    if k == max_k: break

        return Topocounts(self.topos, _counts_, _totals_, new_intervals, self.label_dict)

    def get_interval_lengths(self):
        """Return the length of each interval (inclusive of both ends)."""
        return np.diff(self.intervals, axis=1)[:, 0] + 1

    def get_weights(self):
        """Return counts normalised by per-interval totals (rows of weights)."""
        return self.counts / self.totals.reshape((len(self.totals), 1))

    def write(self, outfile, include_topologies=True, include_header=True):
        """Write the counts table (plus an 'other' remainder column) to outfile.

        Optionally preceded by '#topoN <newick>' comment lines and a header row.
        """
        nTopos = len(self.topos)
        if include_topologies:
            for x in range(nTopos): outfile.write("#topo" + str(x+1) + " " + self.topos[x].as_newick(node_labels=self.label_dict) + "\n")
        if include_header:
            outfile.write("\t".join(["topo" + str(x+1) for x in range(nTopos)]) + "\tother\n")
        #write counts; 'other' is whatever the totals exceed the counted topologies by
        output_array = np.column_stack([self.counts, self.totals - self.counts.sum(axis=1)]).astype(str)
        outfile.write("\n".join(["\t".join(row) for row in output_array]) + "\n")

    def write_intervals(self, outfile, chrom="chr1", include_header=True):
        """Write one 'chrom<TAB>start<TAB>end' line per interval to outfile."""
        if include_header: outfile.write("chrom\tstart\tend\n")
        outfile.write("\n".join(["\t".join([chrom, str(self.intervals[i, 0]), str(self.intervals[i, 1])]) for i in range(len(self.totals))]) + "\n")
#more efficient function from claude.ai
def stack_topocounts(topocounts_list, silent=False):
    """Sum counts/totals across Topocounts objects with non-identical intervals.

    Uses a sweep-line over interval start/end events: intervals are treated as
    inclusive [start, end] (end events are placed at end+1, i.e. exclusive).
    The result's intervals are the maximal segments over which the set of
    overlapping input intervals is constant; counts and totals are summed over
    the inputs active in each segment. Topologies and labels are taken from the
    first object (all inputs are assumed to share the same topology columns).
    Unless silent, a progress dot is printed per ~1% of events processed.
    """
    if len(topocounts_list) == 1:
        return topocounts_list[0]

    # Get number of columns from first object
    n_cols = len(topocounts_list[0].topos)

    # Step 1: Collect all breakpoints and create events
    events = []  # (position, event_type, tc_idx, interval_idx)
    # event_type: 0 = start, 1 = end

    for tc_idx, tc in enumerate(topocounts_list):
        for interval_idx, (start, end) in enumerate(tc.intervals):
            events.append((start, 0, tc_idx, interval_idx))  # start event
            events.append((end + 1, 1, tc_idx, interval_idx))  # end event (exclusive)

    # Sort events by position, then by type (starts before ends)
    events.sort(key=lambda x: (x[0], x[1]))

    # Step 2: Sweep through events to build merged intervals
    active_intervals = set()  # Set of (tc_idx, interval_idx) tuples
    merged_intervals = []
    merged_counts = []
    merged_totals = []

    nevents = len(events)

    # only report progress for non-trivial event lists (>= 100 events)
    if not silent and nevents >= 100:
        report = True
        onePercent = int(np.ceil(nevents/100))
    else:
        report = False

    i = 0
    while i < nevents:
        current_pos = events[i][0]

        # Process all events at current position
        while i < nevents and events[i][0] == current_pos:
            pos, event_type, tc_idx, interval_idx = events[i]

            if event_type == 0:  # start event
                active_intervals.add((tc_idx, interval_idx))
            else:  # end event
                active_intervals.discard((tc_idx, interval_idx))
            i += 1
            if report and i % onePercent == 0: print(".", end="", file=sys.stderr, flush=True)

        # If we have active intervals, create a merged interval
        if active_intervals:
            # Find the next position where the active set changes
            next_pos = current_pos
            if i < nevents:
                next_pos = events[i][0]
            else:
                # No more events, extend to the last end position
                # Find the maximum end position among active intervals
                max_end = current_pos
                for tc_idx, interval_idx in active_intervals:
                    end_pos = topocounts_list[tc_idx].intervals[interval_idx][1]
                    max_end = max(max_end, end_pos)
                next_pos = max_end + 1

            # Create interval from current_pos to next_pos-1
            if next_pos > current_pos:
                interval_start = current_pos
                interval_end = next_pos - 1

                # Sum counts from all active intervals
                summed_counts = np.zeros(n_cols, dtype=topocounts_list[0].counts.dtype)
                summed_total = 0
                for tc_idx, interval_idx in active_intervals:
                    summed_counts += topocounts_list[tc_idx].counts[interval_idx]
                    summed_total += topocounts_list[tc_idx].totals[interval_idx]

                merged_intervals.append([interval_start, interval_end])
                merged_counts.append(summed_counts)
                merged_totals.append(summed_total)

    if report: print("", file=sys.stderr, flush=True)

    return Topocounts(topos=topocounts_list[0].topos, counts=np.array(merged_counts), totals=np.array(merged_totals), intervals=np.array(merged_intervals), label_dict=topocounts_list[0].label_dict)
def get_topocounts_tskit(ts, leaf_groups=None, group_names=None, topoDict=None, include_polytomies=False):
    """Count subtree topologies per tree of a tskit TreeSequence.

    If leaf_groups is None, sample groups are taken from the tree sequence's
    embedded population data. Returns a Topocounts with one row per tree.
    """
    if leaf_groups is None:
        #use populations from ts
        assert list(ts.populations()) != [], "Either specify groups or provide a treesequence with embedded population data."
        if group_names is None: group_names = [str(pop.id) for pop in ts.populations()]
        #bug fix: this previously iterated an undefined name 'taxonNames'
        leaf_groups = [[s for s in ts.samples() if str(ts.get_population(s)) == name] for name in group_names]

    ngroups = len(leaf_groups)
    label_dict = dict(zip(range(ngroups), group_names if group_names else ("group" + str(i) for i in range(1, ngroups + 1))))

    if not topoDict:
        topoDict = make_topoDict_tskit(ngroups, include_polytomies=include_polytomies)

    topos = topoDict["topos"]
    ranks = topoDict["ranks"]

    intervals = np.array([tree.interval for tree in ts.trees()], dtype=int if ts.discrete_genome else float)

    #NOTE: intervals are left as tskit's 0-based half-open coordinates here;
    #any shift to 1-based is deliberately left to the caller

    counts = np.zeros((ts.num_trees, len(ranks)), dtype=int)

    totals = np.zeros(ts.num_trees, dtype=int)

    counter_generator = ts.count_topologies(leaf_groups)

    key = tuple(range(ngroups)) #we only want to count subtree topologies with a tip for each of the leaf_groups

    for i, counter in enumerate(counter_generator):
        counts[i] = [counter[key][rank] for rank in ranks]
        totals[i] = sum(counter[key].values())

    return Topocounts(topos, counts, totals, intervals, label_dict)


def get_topocounts(trees, leaf_groups, max_subtrees, simplify=True, group_names=None, topoDict=None, unrooted=False):
    """Count subtree topologies per tree using TopologySummary.

    trees is an iterable of tree objects exposing .interval; leaf_groups is a
    list of lists of leaf indices. Returns a Topocounts with one row per tree.
    """
    ngroups = len(leaf_groups)
    label_dict = dict(zip(range(ngroups), group_names if group_names else ("group" + str(i) for i in range(1, ngroups + 1))))

    if not topoDict:
        topoDict = make_topoDict(ngroups, unrooted)

    topos = topoDict["topos"]
    topoIDs = topoDict["topoIDs"]

    counts = []
    totals = []
    intervals = []

    #when simplifying, trees are first collapsed by leaf group membership
    leafGroupDict = makeGroupDict(leaf_groups) if simplify else None

    for tree in trees:
        topoSummary = TopologySummary(tree, leafGroupDict)
        counts_dict = topoSummary.get_topology_counts(leaf_groups, max_subtrees=max_subtrees, unrooted=unrooted)
        counts.append([counts_dict[ID] for ID in topoIDs])
        totals.append(sum(counts_dict.values()))
        intervals.append(tree.interval)

    return Topocounts(topos, np.array(counts, dtype=int), np.array(totals, dtype=int), np.array(intervals, dtype=float), label_dict)
def get_topocounts_stacking_sticcs(der_counts, positions, ploidies, groups, max_subtrees, group_names=None,
                                   unrooted=False, second_chances=False, multi_pass=True,
                                   chrom_start=None, chrom_len=None, random_seed=None, silent=True):
    """Infer trees with sticcs for sampled combos of individuals and stack their topology counts.

    For each combination of one individual per group (see get_combos), a tree
    sequence is inferred from the derived allele counts and its per-interval
    topology counts are computed, then all combinations are stacked into a
    single Topocounts.
    """
    comboGenerator = get_combos(groups, max_subtrees=max_subtrees, ploidies=ploidies, random_seed=random_seed)

    topocounts_iterations = []

    for iteration, combo in enumerate(comboGenerator):

        if not silent:
            print(f"\nSample combo {iteration+1} indices: {', '.join([str(idx) for idx in combo])}", file=sys.stderr, flush=True)
            print(f"\nSample combo {iteration+1} will contribute {np.prod(ploidies[list(combo)])} subtree(s)", file=sys.stderr, flush=True)
            print(f"\nInferring tree sequence for combo {iteration+1}.", file=sys.stderr, flush=True)

        der_counts_sub = der_counts[:, combo]

        #ploidies for just the chosen individuals
        ploidies_sub = ploidies[(combo,)]

        #find sites with non-missing genotypes for these individuals
        no_missing = np.all(der_counts_sub >= 0, axis=1)

        #keep only sites that are informative within this subset
        site_sum = der_counts_sub.sum(axis=1)
        variable = (1 < site_sum) & (site_sum < sum(ploidies_sub))

        usable_sites = np.where(no_missing & variable)

        #der counts and positions restricted to usable sites
        der_counts_sub = der_counts_sub[usable_sites]

        positions_sub = positions[usable_sites]

        patterns, matches, n_matches = sticcs.get_patterns_and_matches(der_counts_sub)

        clusters = sticcs.get_clusters(patterns, matches, positions_sub, ploidies=ploidies_sub, second_chances=second_chances,
                                       seq_start=chrom_start, seq_len=chrom_len, silent=True)

        trees = sticcs.infer_trees(patterns, ploidies_sub, clusters, multi_pass=multi_pass, silent=True)

        if not silent:
            print(f"\nCounting topologies for combo {iteration+1}.", file=sys.stderr, flush=True)

        #we pass the overall max_subtrees, but each combination will usually
        #contribute far fewer subtrees than the overall maximum requested
        topocounts = get_topocounts(trees, leaf_groups=make_numeric_groups(ploidies_sub),
                                    group_names=group_names, max_subtrees=max_subtrees, unrooted=unrooted)

        topocounts_iterations.append(topocounts.simplify())

    if not silent:
        print(f"\nStacking", file=sys.stderr)

    return stack_topocounts(topocounts_iterations, silent=silent)


def makeGroupDict(groups, names=None):
    """Map each group member to its group index (or group name, if names given)."""
    groupDict = {}
    for x, group in enumerate(groups):
        for y in group:
            groupDict[y] = x if not names else names[x]
    return groupDict


def make_numeric_groups(sizes):
    """Split the indices 0..sum(sizes)-1 into consecutive groups of the given sizes."""
    partitions = np.cumsum(sizes)[:-1]
    return np.split(np.arange(sum(sizes)), partitions)
638 | 639 | if args.groups: 640 | assert len(args.groups) == ngroups, "Number of groups does not much number of group names" 641 | groups = [g.split(",") for g in args.groups] 642 | 643 | elif args.groups_file: 644 | groups = [[] for i in range(ngroups)] 645 | with open(args.groups_file, "rt") as gf: groupDict = dict([ln.split() for ln in gf.readlines()]) 646 | for sample in groupDict.keys(): 647 | try: groups[group_names.index(groupDict[sample])].append(sample) 648 | except: pass 649 | 650 | if args.verbose: 651 | for i in range(ngroups): 652 | print(f"{group_names[i]}: {', '.join(groups[i])}\n", file=sys.stderr) 653 | 654 | assert min([len(g) for g in groups]) >= 1, "Please specify at least one sample ID per group." 655 | 656 | sampleIDs = [ID for group in groups for ID in group] 657 | assert len(sampleIDs) == len(set(sampleIDs)), "Each sample should only be in one group." 658 | 659 | label_dict = dict(zip(range(ngroups), group_names)) 660 | 661 | topoDict = make_topoDict(ngroups, unrooted=args.unrooted) 662 | 663 | if args.output_topos: 664 | with open(args.output_topos, "wt") as tf: 665 | tf.write("\n".join([t.as_newick(node_labels=label_dict) for t in topoDict["topos"]]) + "\n") 666 | 667 | sys.stderr.write("\n".join([t.as_newick(node_labels=label_dict) for t in topoDict["topos"]]) + "\n") 668 | 669 | return (group_names, groups) 670 | 671 | 672 | def sticcstack_command_line(args): 673 | 674 | group_names, groups = parse_groups_command_line(args) 675 | groups_numeric = make_numeric_groups([len(g) for g in groups]) 676 | 677 | sampleIDs = [ID for group in groups for ID in group] 678 | 679 | #VCF file 680 | vcf = cyvcf2.VCF(args.input_vcf, samples=sampleIDs) 681 | 682 | for ID in sampleIDs: assert ID in vcf.samples, f"ID {ID} not found in vcf file header line." 
    # The dac vcf reader does not give us the dac values in the right order,
    # so we need to record the sample indices
    sample_indices = [vcf.samples.index(ID) for ID in sampleIDs]

    ploidies, ploidyDict = sticcs.parsePloidyArgs(args, sampleIDs)

    ploidies = np.array(ploidies)

    # Generator yielding (chrom, positions, derived allele counts) per chromosome
    dac_generator = sticcs.parse_vcf_with_DC_field(vcf)

    chromLenDict = dict(zip(vcf.seqnames, vcf.seqlens))

    print(f"\nReading first chromosome...", file=sys.stderr)

    for chrom, positions, der_counts in dac_generator:

        assert positions.max() <= chromLenDict[chrom], f"\tSNP at position {positions.max()} exceeds chromosome length {chromLenDict[chrom]} for {chrom}."

        # Either analyse only the span covered by variants, or the whole chromosome
        if args.variant_range_only:
            chrom_start = positions[0]
            chrom_len = positions[-1]
        else:
            chrom_start = 1
            chrom_len = chromLenDict[chrom]

        print(f"\nAnalysing {chrom}. {positions.shape[0]} usable SNPs found.", file=sys.stderr)

        #correct sample order so that groups are together
        der_counts = der_counts[:,sample_indices]

        topocounts_stacked = get_topocounts_stacking_sticcs(der_counts, positions, ploidies=ploidies, groups=groups_numeric,
                                                            group_names=group_names, max_subtrees=args.max_subtrees,
                                                            unrooted=args.unrooted, multi_pass=not args.single_pass,
                                                            second_chances = not args.no_second_chances,
                                                            chrom_start=chrom_start, chrom_len=chrom_len,
                                                            silent= not args.verbose)

        # Could potentially add step merging identical intervals here.
        # One pair of output files (topocounts + intervals) per chromosome.
        with gzip.open(args.out_prefix + "." + chrom + ".topocounts.tsv.gz", "wt") as outfile:
            topocounts_stacked.write(outfile)

        with gzip.open(args.out_prefix + "." + chrom + ".intervals.tsv.gz", "wt") as outfile:
            topocounts_stacked.write_intervals(outfile, chrom=chrom)

    print(f"\nEnd of file reached. 
Looks like my work is done.", file=sys.stderr) 728 | 729 | 730 | def trees_command_line(args): 731 | 732 | group_names, groups = parse_groups_command_line(args) 733 | groups_numeric = make_numeric_groups([len(g) for g in groups]) 734 | leaf_names = [ID for group in groups for ID in group] 735 | 736 | max_subtrees = np.prod([len(g) for g in groups]) 737 | 738 | if args.input_format == "tskit": 739 | ts = tskit.load(args.input_file) 740 | 741 | elif args.input_format == "argweaver": 742 | with gzip.open(args.input_file, "rt") if args.input_file.endswith(".gz") else open(args.input_file, "rt") as treesfile: 743 | ts = parse_argweaver_file(treesfile, leaf_names=leaf_names) 744 | 745 | elif args.input_format == "newick": 746 | with gzip.open(args.input_file, "rt") if args.input_file.endswith(".gz") else open(args.input_file, "rt") as treesfile: 747 | ts = parse_newick_file(treesfile, leaf_names=leaf_names) 748 | 749 | topocounts = get_topocounts(ts.trees(), leaf_groups=groups_numeric, max_subtrees=max_subtrees, unrooted=args.unrooted) 750 | 751 | with gzip.open(args.out_prefix + ".topocounts.tsv.gz", "wt") as outfile: 752 | topocounts.write(outfile) 753 | 754 | if (args.input_format == "tskit" or args.input_format == "argweaver"): 755 | with gzip.open(args.out_prefix + ".intervals.tsv.gz", "wt") as outfile: 756 | topocounts.write_intervals(outfile, chrom=args.chrom_name) 757 | 758 | 759 | 760 | def main(): 761 | parser = argparse.ArgumentParser(prog="twisst2") 762 | 763 | subparsers = parser.add_subparsers(title = "subcommands", dest="mode") 764 | subparsers.required = True 765 | 766 | sticcstack_parser = subparsers.add_parser("sticcstack", help="Compute topology weights from input VCF using sticcstack method") 767 | 768 | sticcstack_parser.add_argument("-i", "--input_vcf", help="Input VCF file with DC field (make thsi with sticcs prep)", action = "store", required=True) 769 | sticcstack_parser.add_argument("-o", "--out_prefix", help="Output file prefix", action = 
"store", required=True) 770 | sticcstack_parser.add_argument("--unrooted", help="Unroot topologies (results in fewer topologies)", action = "store_true") 771 | sticcstack_parser.add_argument("--ploidy", help="Sample ploidy if all the same. Use --ploidy_file if samples differ.", action = "store", type=int) 772 | sticcstack_parser.add_argument("--ploidy_file", help="File with samples names and ploidy as columns", action = "store") 773 | sticcstack_parser.add_argument("--max_subtrees", help="Maximum number of subtrees to consider (note that each combination of diploids represents multiple subtrees)", action = "store", type=int, required=True) 774 | sticcstack_parser.add_argument("--no_second_chances", help="Do not consider SNPs that are separated by incompatible SNPs", action='store_true') 775 | sticcstack_parser.add_argument("--single_pass", help="Single pass when building trees (only relevant for ploidy > 1, but not recommended)", action='store_true') 776 | #sticcstack_parser.add_argument("--inputTopos", help="Input file for user-defined topologies (optional)", action = "store", required = False) 777 | sticcstack_parser.add_argument("--output_topos", help="Output file for topologies used", action = "store", required = False) 778 | sticcstack_parser.add_argument("--group_names", help="Name for each group (separated by spaces)", action='store', nargs="+", required = True) 779 | sticcstack_parser.add_argument("--groups", help="Sample IDs for each individual (separated by commas), for each group (separated by spaces)", action='store', nargs="+") 780 | sticcstack_parser.add_argument("--groups_file", help="Optional file with a column for sample ID and group", action = "store", required = False) 781 | sticcstack_parser.add_argument("--variant_range_only", help="Verbose output", action="store_true") 782 | sticcstack_parser.add_argument("--verbose", help="Verbose output", action="store_true") 783 | 784 | inputtrees_parser = subparsers.add_parser("trees", help="Input trees in 
tskit, argweaver or newick format")

    inputtrees_parser.add_argument("-i", "--input_file", help="Input trees file", action = "store", required=True)
    inputtrees_parser.add_argument("-f", "--input_format", help="Input file format (tskit, argweaver or newick)", choices=["tskit", "argweaver", "newick"], action = "store", required=True)
    inputtrees_parser.add_argument("-o", "--out_prefix", help="Output file prefix", action = "store", required=True)
    inputtrees_parser.add_argument("--chrom_name", help="Chromosome name for output intervals file", action = "store", default="unknown_chrom")
    inputtrees_parser.add_argument("--unrooted", help="Unroot topologies (results in fewer topologies)", action = "store_true")
    #inputtrees_parser.add_argument("--inputTopos", help="Input file for user-defined topologies (optional)", action = "store", required = False)
    inputtrees_parser.add_argument("--output_topos", help="Output file for topologies used", action = "store", required = False)
    inputtrees_parser.add_argument("--group_names", help="Name for each group (separated by spaces)", action='store', nargs="+", required = True)
    inputtrees_parser.add_argument("--groups", help="Sample IDs for each individual (separated by commas), for each group (separated by spaces)", action='store', nargs="+")
    inputtrees_parser.add_argument("--groups_file", help="Optional file with a column for sample ID and group", action = "store", required = False)
    inputtrees_parser.add_argument("--verbose", help="Verbose output", action="store_true")

    ### parse arguments
    args = parser.parse_args()

    # Dispatch to the chosen subcommand (subparsers.required ensures one was given)
    if args.mode == "sticcstack":
        sticcstack_command_line(args)

    if args.mode == "trees":
        trees_command_line(args)




if __name__ == '__main__':
    main()