├── twisst2
│   ├── __init__.py
│   ├── parse_newick.py
│   ├── TopologySummary.py
│   └── twisst2.py
├── example
│   ├── admix_hiILS_123_l5e6_r1e8_mu1e8.vcf.gz
│   └── groups_4_20.txt
├── plot_twisst
│   ├── examples
│   │   ├── admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.intervals.tsv.gz
│   │   └── admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.topocounts.tsv.gz
│   ├── example_plot.R
│   └── plot_twisst.R
├── pyproject.toml
├── LICENSE
└── README.md
/twisst2/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.5"
2 |
3 |
--------------------------------------------------------------------------------
/example/admix_hiILS_123_l5e6_r1e8_mu1e8.vcf.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/simonhmartin/twisst2/HEAD/example/admix_hiILS_123_l5e6_r1e8_mu1e8.vcf.gz
--------------------------------------------------------------------------------
/plot_twisst/examples/admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.intervals.tsv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/simonhmartin/twisst2/HEAD/plot_twisst/examples/admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.intervals.tsv.gz
--------------------------------------------------------------------------------
/plot_twisst/examples/admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.topocounts.tsv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/simonhmartin/twisst2/HEAD/plot_twisst/examples/admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.topocounts.tsv.gz
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "twisst2"
3 | authors = [{name = "Simon Martin", email = "simon.martin@ed.ac.uk"}]
4 | description = "Topology weighting from unphased genotypes of any ploidy"
5 | readme = "README.md"
6 | requires-python = ">=3.7"
7 | dynamic = ["version"]
8 | # version = "0.0.1"
9 | dependencies = ["numpy >= 1.21.5", "cyvcf2"]
10 |
11 | [tool.setuptools.dynamic]
12 | version = {attr = "twisst2.__version__"}
13 |
14 | [project.scripts]
15 | twisst2 = "twisst2.twisst2:main"
16 |
17 | [tool.setuptools]
18 | packages = ["twisst2"]
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Simon Martin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/example/groups_4_20.txt:
--------------------------------------------------------------------------------
1 | tsk_0 group1
2 | tsk_1 group1
3 | tsk_2 group1
4 | tsk_3 group1
5 | tsk_4 group1
6 | tsk_5 group1
7 | tsk_6 group1
8 | tsk_7 group1
9 | tsk_8 group1
10 | tsk_9 group1
11 | tsk_10 group1
12 | tsk_11 group1
13 | tsk_12 group1
14 | tsk_13 group1
15 | tsk_14 group1
16 | tsk_15 group1
17 | tsk_16 group1
18 | tsk_17 group1
19 | tsk_18 group1
20 | tsk_19 group1
21 | tsk_20 group2
22 | tsk_21 group2
23 | tsk_22 group2
24 | tsk_23 group2
25 | tsk_24 group2
26 | tsk_25 group2
27 | tsk_26 group2
28 | tsk_27 group2
29 | tsk_28 group2
30 | tsk_29 group2
31 | tsk_30 group2
32 | tsk_31 group2
33 | tsk_32 group2
34 | tsk_33 group2
35 | tsk_34 group2
36 | tsk_35 group2
37 | tsk_36 group2
38 | tsk_37 group2
39 | tsk_38 group2
40 | tsk_39 group2
41 | tsk_40 group3
42 | tsk_41 group3
43 | tsk_42 group3
44 | tsk_43 group3
45 | tsk_44 group3
46 | tsk_45 group3
47 | tsk_46 group3
48 | tsk_47 group3
49 | tsk_48 group3
50 | tsk_49 group3
51 | tsk_50 group3
52 | tsk_51 group3
53 | tsk_52 group3
54 | tsk_53 group3
55 | tsk_54 group3
56 | tsk_55 group3
57 | tsk_56 group3
58 | tsk_57 group3
59 | tsk_58 group3
60 | tsk_59 group3
61 | tsk_60 group4
62 | tsk_61 group4
63 | tsk_62 group4
64 | tsk_63 group4
65 | tsk_64 group4
66 | tsk_65 group4
67 | tsk_66 group4
68 | tsk_67 group4
69 | tsk_68 group4
70 | tsk_69 group4
71 | tsk_70 group4
72 | tsk_71 group4
73 | tsk_72 group4
74 | tsk_73 group4
75 | tsk_74 group4
76 | tsk_75 group4
77 | tsk_76 group4
78 | tsk_77 group4
79 | tsk_78 group4
80 | tsk_79 group4
81 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # twisst2
2 |
3 | `twisst2` is a tool for [topology weighting](https://doi.org/10.1534/genetics.116.194720). Topology weighting summarises genealogies in terms of the relative abundance of different sub-tree topologies. It can be used to explore processes like introgression and it can aid the identification of trait-associated loci.
4 |
5 | `twisst2` has a number of important improvements over the original [`twisst`](https://github.com/simonhmartin/twisst) tool. Most importantly, `twisst2` **incorporates inference of the ancestral recombination graph (ARG) or tree sequence** - local genealogies and their breakpoints along the chromosome. It does this using [`sticcs`](https://github.com/simonhmartin/sticcs). `sticcs` is a model-free approach and it does not require phased data, so `twisst2` **can run on unphased genotypes of any ploidy**.
6 |
7 | The recommended way to run `twisst2` is to start from polarised genotype data. This means you either need to know the ancestral allele at each site, or you need appropriate outgroup(s) to allow inference of the derived allele.
8 |
9 | An alternative way to run it is by first inferring the ancestral recombination graph (ARG) or tree sequence using a different tool like [Relate](https://myersgroup.github.io/relate/) or [`tsinfer`](https://tskit.dev/tsinfer/docs/stable/index.html). However, this typically requires phased genotypes, and [my tests](https://doi.org/10.1093/genetics/iyaf181) suggest that `twisst2+sticcs` is more accurate than other methods anyway.
10 |
11 | # Publications
12 | - The general concept of topology weighting is described by [Martin and Van Belleghem 2017](https://doi.org/10.1534/genetics.116.194720).
13 | - Combining genealogy inference with `sticcs` and topology weighting with `twisst2` is described by [Martin 2025](https://doi.org/10.1093/genetics/iyaf181).
14 |
15 |
16 | ### Installation
17 |
18 | First install [`sticcs`](https://github.com/simonhmartin/sticcs) by following the instructions there.
19 |
20 | If you would like to analyse tree sequence objects from tools like [`msprime`](https://tskit.dev/msprime/docs/stable/intro.html) and [tsinfer](https://tskit.dev/tsinfer/docs/stable/index.html), you will also need to install [`tskit`](https://tskit.dev/tskit/docs/stable/introduction.html) yourself. To install `twisst2`:
21 |
22 | ```bash
23 | git clone https://github.com/simonhmartin/twisst2.git
24 |
25 | cd twisst2
26 |
27 | pip install -e .
28 | ```
29 |
30 | ### Command line tool
31 |
32 | #### Starting from unphased (or phased) genotypes
33 |
34 | To perform tree inference and topology weighting, `twisst2` takes as input a modified vcf file that contains a `DC` field, giving the count of derived alleles for each individual at each site.
35 |
36 | Once you have a vcf file for your genotype data, make the modified version using `sticcs` (this needs to be installed, see above):
37 | ```bash
38 | sticcs prep -i <input.vcf.gz> -o <output prefix> --outgroup <outgroup sample name(s)>
39 | ```
40 |
41 | If the vcf file already has the ancestral allele (provided in the `AA` field in the `INFO` section), then you do not need to specify outgroups for polarising.
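
For orientation, a record in the prepped VCF might look something like this (illustrative only: the exact header text that `sticcs` writes may differ; the key point is the per-sample `DC` value giving each individual's derived allele count, 0, 1 or 2 for a diploid):

```
##INFO=<ID=AA,Number=1,Type=String,Description="Ancestral allele">
##FORMAT=<ID=DC,Number=1,Type=Integer,Description="Derived allele count">
#CHROM  POS   ID  REF  ALT  QUAL  FILTER  INFO  FORMAT  tsk_0  tsk_1
chr1    1234  .   A    G    .     .       AA=A  GT:DC   0/1:1  1/1:2
```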
42 |
43 | Now you can run `twisst2` to count sub-tree topologies:
44 |
45 | ```bash
46 | twisst2 sticcstack -i <prepped.vcf.gz> -o <output prefix> --max_subtrees 512 --ploidy 2 --groups <group names> --groups_file <groups file>
47 | ```
48 |
49 | #### Starting from pre-inferred trees or ARG (e.g. Relate, tsinfer, argweaver, Singer)
50 |
51 | ```bash
52 | twisst2 trees -i <trees file> -o <output prefix> --groups <group names> --groups_file <groups file>
53 | ```
54 |
55 | ### Output
56 |
57 | - `.topocounts.tsv.gz` gives the count of each group tree topology for each interval.
58 | - `.intervals.tsv.gz` gives the chromosome, start and end position of each interval.
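
If you prefer to work with these files outside of R, computing weightings is simple: the topology definitions are stored as `#`-prefixed comment lines at the top of the topocounts file, and the weightings are just the counts normalised by their row sums. A minimal sketch, assuming `pandas` (not a `twisst2` dependency) and the bundled example files:

```python
import pandas as pd

# Topology definitions sit in '#' comment lines at the top of the topocounts
# file, so skip them with comment='#'.
topocounts = pd.read_csv(
    "plot_twisst/examples/admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.topocounts.tsv.gz",
    sep="\t", comment="#")
intervals = pd.read_csv(
    "plot_twisst/examples/admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.intervals.tsv.gz",
    sep="\t")

# Normalise each interval's topology counts to weightings that sum to 1
weights = topocounts.div(topocounts.sum(axis=1), axis=0)
```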
59 |
60 | ### R functions for plotting
61 |
62 | Some functions for importing and plotting are provided in the `plot_twisst/plot_twisst.R` script. For examples of how to use these functions, see the `plot_twisst/example_plot.R` script.
63 |
64 |
--------------------------------------------------------------------------------
/plot_twisst/example_plot.R:
--------------------------------------------------------------------------------
1 |
2 | ################################# overview #####################################
3 |
4 | # The main data produced by Twisst is a weights file, which has a column for
5 | # each topology giving the number of observations of that topology within each
6 | # genealogy. Weights files produced by Twisst also contain initial comment lines
7 | # specifying the topologies.
8 |
9 | # The other data file that may be of interest is the window data. That is, the
10 | # chromosome/scaffold and start and end positions for each of the regions or
11 | # windows represented in the weights file.
12 |
13 | # Both of the above files can be read into R, manipulated and plotted however
14 | # you like, but I have written some functions to make these tasks easier.
15 | # These functions are provided in the script plot_twisst.R
16 |
17 | ################### load helpful plotting functions #############################
18 |
19 | source("plot_twisst.R")
20 |
21 | ############################## input files ######################################
22 |
23 | # Each genomic region (chromosome, contig etc.) should have two files:
24 | # A topocounts file giving the count of each topology (one per column)
25 | # And an intervals file, giving the chromosome name, start and end position
26 |
27 | # Here we just import one genomic region
28 |
29 | intervals_file <- "examples/admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.intervals.tsv.gz"
30 |
31 | topocounts_file <- "examples/admix_hiILS_123_l5e6_r1e8_mu1e8.chr1.topocounts.tsv.gz"
32 |
33 | ################################# import data ##################################
34 |
35 | # The function import.twisst reads the topology counts and intervals data files into a list object
36 | # If there are multiple weights files, or a single file with different chromosomes/scaffolds/contigs
37 | # in the window data file, these will be separated when importing.
38 |
39 | twisst_data <- import.twisst(intervals_files=intervals_file,
40 | topocounts_files=topocounts_file)
41 |
42 | #Some additional arguments:
43 |
44 | # topos_file=
45 | # If you prefer to provide the topologies as a separate file (rather than as comment lines in the topocounts file)
46 |
47 | # ignore_extra_columns=TRUE
48 | # This option will ignore subtrees that did not match any of the defined topologies (usually due to polytomies).
49 | # If you use this option, the weightings for every tree will sum to 1, but you may be throwing away some information
50 | # If you use this option, you might want to set the min_subtrees option too.
51 |
52 | # min_subtrees=100
53 | # This option can be used in conjunction with ignore_extra_columns=TRUE.
54 | # For trees with polytomies there may be few subtrees that match any of the defined topologies.
55 | # These weightings would be less reliable (more noisy).
56 | # Any tree with fewer matching subtrees than the defined minimum will be ignored.
57 |
58 | # max_interval=10000
59 | # In parts of the genome with bad data, the tree interval can be very large.
60 | # You can simply exclude this information by setting the maximum interval.
61 |
62 | # names=
63 | # If you have multiple regions they will be named according to the chromosome by default,
64 | # but you can give your own names if you prefer
65 |
66 | # reorder_by_start=TRUE
67 | # If the tree intervals are out of order, this option will reorder them by the start position
68 |
69 |
70 | ############################## combined plots ##################################
71 | # there are functions available to plot both the weightings and the topologies
72 |
73 | #for all plots we will use the 15 colours that come with plot_twisst.R
74 | # But we will reorder them according to the most abundant topology
75 | topo_cols <- topo_cols[order(order(twisst_data$weights_overall_mean[1:length(twisst_data$topos)], decreasing=TRUE))]
76 |
77 | #a summary plot shows all the topologies and a bar plot of their relative weightings
78 | plot.twisst.summary(twisst_data, lwd=3, cex=0.7)
79 |
80 | plot.twisst.summary.boxplot(twisst_data)
81 |
82 |
83 | #or plot ALL the data across the chromosome(s)
84 | # Note, this is not recommended if there are large numbers of windows.
85 | # instead, it is recommended to first smooth the weightings and plot the smoothed values
86 | # There are three plotting modes to try
87 | plot.twisst(twisst_data, mode=1, show_topos=TRUE, ncol_topos=15)
88 | plot.twisst(twisst_data, mode=2, show_topos=TRUE, ncol_topos=15)
89 | plot.twisst(twisst_data, mode=3, show_topos=TRUE, ncol_topos=15)
90 |
91 |
92 | # make smooth weightings and plot those across chromosomes
93 | twisst_data_smooth <- smooth.twisst(twisst_data, span_bp = 20000, spacing = 1000)
94 | plot.twisst(twisst_data_smooth, mode=2, ncol_topos=15) #mode 2 overlays polygons, mode 3 would stack them
95 |
96 |
97 | ##################### individual plots: raw weights ############################
98 |
99 | #plot raw data in "stepped" style, with polygons stacked.
100 | #specify stepped style by providing a matrix of starts and ends for positions
101 | par(mfrow = c(1,1), mar = c(4,4,1,1))
102 | plot.weights(weights_dataframe=twisst_data$weights[[1]], positions=twisst_data$interval_data[[1]][,c("start","end")],
103 | line_cols=topo_cols, fill_cols=topo_cols, stacked=TRUE)
104 |
105 | #plot raw data in stepped style, with polygons unstacked (stacked=FALSE)
106 | #use semi-transparent colours for fill
107 | plot.weights(weights_dataframe=twisst_data$weights[[1]], positions=twisst_data$interval_data[[1]][,c("start","end")],
108 | line_cols=topo_cols, fill_cols=paste0(topo_cols,80), stacked=FALSE)
109 |
110 |
111 | #################### individual plots: smoothed weights ########################
112 |
113 | #plot smoothed data with polygons stacked
114 | plot.weights(weights_dataframe=twisst_data_smooth$weights[[1]], positions=twisst_data_smooth$pos[[1]],
115 | line_cols=topo_cols, fill_cols=topo_cols, stacked=TRUE)
116 |
117 | #plot smoothed data with polygons unstacked
118 | plot.weights(weights_dataframe=twisst_data_smooth$weights[[1]], positions=twisst_data_smooth$pos[[1]],
119 | line_cols=topo_cols, fill_cols=paste0(topo_cols,80), stacked=FALSE)
120 |
121 |
122 |
123 | #################### subset to only the most abundant topologies #################
124 |
125 | #get list of the most abundant topologies (top 2 in this case)
126 | top2_topos <- order(twisst_data$weights_overall_mean, decreasing=T)[1:2]
127 |
128 | #subset twisst object for these
129 | twisst_data_top2topos <- subset.twisst.by.topos(twisst_data, top2_topos)
130 | #this can then be used in all the same plotting functions above.
131 |
132 | ######################## subset to only specific regions #########################
133 |
134 | #regions to keep (more than one can be specified)
135 | regions <- c("chr1")
136 |
137 | #subset twisst object for these
138 | twisst_data_chr1 <- subset.twisst.by.regions(twisst_data, regions)
139 | #this can then be used in all the same plotting functions above.
140 |
141 |
142 | ########################### plot topologies using Ape ##########################
143 | #unroot trees if you want to
144 | # for (i in 1:length(twisst_data$topos)) twisst_data$topos[[i]] <- ladderize(unroot(twisst_data$topos[[i]]))
145 |
146 | par(mfrow = c(3,length(twisst_data$topos)/3), mar = c(1,1,2,1), xpd=NA)
147 | for (n in 1:length(twisst_data$topos)){
148 | plot.phylo(twisst_data$topos[[n]], type = "cladogram", edge.color=topo_cols[n], edge.width=5, rotate.tree = 90, cex = 1, adj = .5, label.offset=.2)
149 | mtext(side=3,text=paste0("topo",n))
150 | }
151 |
152 |
153 |
--------------------------------------------------------------------------------
/twisst2/parse_newick.py:
--------------------------------------------------------------------------------
1 | def parse_newick(newick_string):
2 | """
3 | Parse a Newick string into a dictionary representation.
4 |
5 |     Returns:
6 |         tuple: (node_children, node_label, branch_length, leaves, parents, root_id)
7 |             node_children maps node ID -> list of child node IDs; node_label and branch_length
8 |             map node ID -> label / branch length (or None); leaves and parents are lists of node IDs
9 | """
10 |
11 | # Remove whitespace and trailing semicolon
12 | newick = newick_string.strip().rstrip(';')
13 |
14 | # Parser state shared by the nested helper functions
15 | node_counter = 0
16 | node_children = {}
17 | node_label = {}
18 | branch_length = {}
19 |
20 | def get_next_node_id():
21 | nonlocal node_counter
22 | node_id = node_counter
23 | node_counter += 1
24 | return node_id
25 |
26 | def parse_node(s, pos=0):
27 | """
28 | Recursively parse a node from position 'pos' in string 's'.
29 |
30 | Returns:
31 | tuple: (node_id, new_position)
32 | """
33 | nonlocal node_children, node_label, branch_length
34 |
35 | current_node_id = get_next_node_id()
36 | node_children[current_node_id] = [] # Initialize children list
37 | node_label[current_node_id] = None
38 | branch_length[current_node_id] = None
39 |
40 | # Skip whitespace
41 | while pos < len(s) and s[pos].isspace(): pos += 1
42 |
43 | # Case 1: Internal node - starts with '('
44 | if pos < len(s) and s[pos] == '(':
45 | pos += 1 # Skip opening parenthesis
46 |
47 | # Parse children until we hit the closing parenthesis
48 | while pos < len(s) and s[pos] != ')':
49 | # Skip whitespace and commas
50 | while pos < len(s) and s[pos] in ' \t,': pos += 1
51 |
52 | if pos < len(s) and s[pos] != ')':
53 | # Recursively parse child
54 | child_id, pos = parse_node(s, pos)
55 | node_children[current_node_id].append(child_id)
56 |
57 | if pos < len(s) and s[pos] == ')': pos += 1 # Skip closing parenthesis
58 |
59 | # Case 2 & 3: Parse node label and branch length (for both leaf and internal nodes)
60 | # Node label comes first
61 | label_start = pos
62 | while pos < len(s) and s[pos] not in '[:,();': pos += 1
63 |
64 | if pos > label_start:
65 | node_label[current_node_id] = s[label_start:pos].strip()
66 |
67 | # Branch length comes after ':'
68 | if pos < len(s) and s[pos] == ':':
69 | pos += 1 # Skip ':'
70 |
71 | # Parse branch length
72 | length_start = pos
73 | while pos < len(s) and s[pos] not in '[,();': pos += 1
74 |
75 | if pos > length_start:
76 | try:
77 | branch_length[current_node_id] = float(s[length_start:pos])
78 | except ValueError:
79 | raise ValueError(f"'{s[length_start:pos]}' is not a valid branch length")
80 | #pass # Invalid branch length, keep as None
81 |
82 | #progress past anything included beyond branch length
83 | while pos < len(s) and s[pos] not in ',();': pos += 1
84 |
85 | return current_node_id, pos
86 |
87 | # Start parsing from the beginning
88 | root_id, _ = parse_node(newick, 0)
89 |
90 | #now figure out which nodes are leaves and which are parents
91 | n_nodes = get_next_node_id()
92 | parents = []
93 | leaves = []
94 | for id in range(n_nodes):
95 | if node_children[id] == []: leaves.append(id)
96 | else: parents.append(id)
97 |
98 | return node_children, node_label, branch_length, leaves, parents, root_id
99 |
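
# Illustrative walk-through of parse_newick (node IDs are assigned in the
# order nodes are encountered, starting from the root):
#   node_children, node_label, branch_length, leaves, parents, root_id = \
#       parse_newick("((A:1,B:1):2,C:3);")
#   # root_id == 0, leaves == [2, 3, 4], parents == [0, 1]
#   # node_label[2] == 'A', branch_length[2] == 1.0, node_children[1] == [2, 3]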
100 |
101 | from sticcs import sticcs
102 |
103 |
104 | # Another function to convert a newick string to a sticcs tree.
105 | # This object uses numeric leaf IDs from 0 to N-1.
106 | # These can be passed to the function as a dictionary.
107 | # Otherwise the function attempts to convert leaf IDs to integers.
108 | # Internal nodes don't need the same level of consistency, so those will
109 | # be kept in the order in which they were found, but renumbered so they come after the leaf indices
110 | def newick_to_sticcs_Tree(newick_string, leaf_idx_dict=None, allow_additional_leaves=False, interval=None):
111 |
112 | #first parse newick and number nodes and leaves by order of finding them
113 | node_children, node_label, branch_length, leaves, parents, root_id = parse_newick(newick_string)
114 | #nodes that come out of parse_newick are integers, but they are in the order they were encountered in reading the string, so they are meaningless
115 | #What matters is the node_labels, as these are what those nodes were called (i.e. leaf names, but internal nodes can theoretically have labels too)
116 | #Now the sticcs tree needs numeric leaves, but these need to start from zero - so they are NOT the same as the nodes that come from parse_newick()
117 | #Instead each node_label needs to map to an integer between zero and n_leaves
118 | #This can be provided in the leaf_idx_dict
119 | #If not, all node_labels need to be integers in the newick string
120 |
121 | n_leaves = len(leaves)
122 |
123 | #get the label for each leaf
124 | leaf_labels = [node_label[leaf_number] for leaf_number in leaves]
125 | assert len(set(leaf_labels)) == n_leaves, "Leaf labels must all be unique."
126 |
127 | #now, if there is no dictionary provided, we assume leaves are already numbered 0:n-1
128 | if not leaf_idx_dict:
129 | if not set(leaf_labels) == set([str(i) for i in range(n_leaves)]):
130 | raise ValueError("Leaf labels not consecutive integers. Please provide a leaf_idx_dict with consecutive indices from 0 to n_leaves-1")
131 |
132 | _leaf_idx_dict_ = dict([(label, int(label)) for label in leaf_labels])
133 |
134 | else:
135 | #We already have the dictionary to convert leaf number to new index
136 | _leaf_idx_dict_ = {}
137 | _leaf_idx_dict_.update(leaf_idx_dict)
138 | assert set(_leaf_idx_dict_.values()) == set(range(len(_leaf_idx_dict_))), "leaf_idx_dict must provide consecutive indices from zero"
139 | #Check that all leaves are represented
140 | unrepresented_leaves = [label for label in leaf_labels if label not in leaf_idx_dict]
141 | if len(unrepresented_leaves) > 0:
142 | if allow_additional_leaves:
143 | i = len(_leaf_idx_dict_)
144 | for leaf_label in sorted(unrepresented_leaves):
145 | _leaf_idx_dict_[leaf_label] = i
146 | i += 1
147 | else:
148 | raise ValueError("Some leaves are not listed in leaf_idx_dict. Set allow_additional_leaves=True if you want to risk it.")
149 |
150 | new_leaf_IDs = sorted(_leaf_idx_dict_.values())
151 |
152 | #now assign the new ID to each leaf number (remember leaf numbers in the parsed newick are simply in the order they were encountered)
153 | new_node_idx = dict(zip(leaves, [_leaf_idx_dict_[node_label[leaf_number]] for leaf_number in leaves]))
154 |
155 | #the new parents are just n_leaves along from where they were
156 | new_parents = [id+n_leaves for id in parents]
157 |
158 | new_node_idx.update(dict(zip(parents, new_parents)))
159 |
160 | #now update objects for the new sticcs Tree object
161 | new_node_children = dict()
162 | for item in node_children.items():
163 | new_node_children[new_node_idx[item[0]]] = [new_node_idx[x] for x in item[1]]
164 |
165 | new_node_parent = dict()
166 | for item in new_node_children.items():
167 | for child in item[1]:
168 | new_node_parent[child] = item[0]
169 |
170 | new_node_label = dict()
171 | for item in node_label.items():
172 | new_node_label[new_node_idx[item[0]]] = item[1]
173 |
174 | objects = {"leaves":new_leaf_IDs, "root":n_leaves, "parents":new_parents,
175 | "node_children": new_node_children, "node_parent":new_node_parent,
176 | "interval":interval}
177 |
178 | return (sticcs.Tree(objects=objects), new_node_label)
179 |
180 |
181 |
182 | # a class that just holds a bunch of trees, but has a few attributes and methods that resemble a tskit treesequence object
183 | # This allows us to import newick trees and treat the list as a treesequence for a few things (like my quartet distance calculation)
184 | class TreeList:
185 | def __init__(self, trees):
186 | self.num_trees = len(trees)
187 | self.tree_list = trees
188 |
189 | def trees(self):
190 | for tree in self.tree_list:
191 | yield tree
192 |
193 |
194 | def parse_newick_file(newickfile, leaf_names=None):
195 |
196 | #if leaves are not integers, need to link each leaf to its index
197 | leaf_idx_dict = dict(zip(leaf_names, range(len(leaf_names)))) if leaf_names is not None else None
198 |
199 | trees = []
200 |
201 | i = 1
202 |
203 | for line in newickfile:
204 | tree, node_labels = newick_to_sticcs_Tree(line.strip(), leaf_idx_dict=leaf_idx_dict, allow_additional_leaves=True, interval = (i, i))
205 | trees.append(tree)
206 | i+=1
207 |
208 | return TreeList(trees)
209 |
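
# Illustrative usage of parse_newick_file: 'newickfile' can be any iterable of
# newick strings, such as an open file with one tree per line:
#   with open("trees.newick") as f:
#       tree_list = parse_newick_file(f, leaf_names=["A", "B", "C", "D"])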
210 |
211 | def parse_argweaver_smc(argfile, leaf_names=None):
212 | #annoyingly, the trees have numeric leaves, but these are NOT in the order of the haps
213 | #The order is given by the first line, so we can figure out which hap each leaf points to
214 |
215 | #get the order of haps from argweaver output
216 | #The position of each number in this list links it to a leaf (leaves are numeric)
217 | leaf_names_reordered = argfile.readline().split()[1:]
218 |
219 | n_leaves = len(leaf_names_reordered)
220 |
221 | _leaf_names = leaf_names if leaf_names is not None else [str(i) for i in range(n_leaves)]
222 |
223 | #link each leaf name to its real index. This is the index that will be used for topology weighting
224 | leaf_name_to_idx_dict = dict([(_leaf_names[i], i) for i in range(n_leaves)])
225 |
226 | #link the leaf number in the current file (which is arbitrary) to its real index
227 | leaf_idx_dict = dict([(str(i), leaf_name_to_idx_dict[leaf_names_reordered[i]]) for i in range(n_leaves)])
228 |
229 | trees = []
230 |
231 | chrom, chrom_start, chrom_len = argfile.readline().split()[1:]
232 | for line in argfile:
233 | if line.startswith("TREE"):
234 | elements = line.split()
235 | interval=(int(elements[1]), int(elements[2]),)
236 | tree_newick = elements[3]
237 | tree, node_labels = newick_to_sticcs_Tree(tree_newick, leaf_idx_dict=leaf_idx_dict, allow_additional_leaves=True, interval=interval)
238 | trees.append(tree)
239 |
240 | return TreeList(trees)
241 |
242 |
243 |
244 | def sim_test_newick_parser(n=12, reps=5):
245 | import msprime
246 | from twisst2 import TopologySummary
247 |
248 | for i in range(reps):
249 | ts = msprime.sim_ancestry(n, ploidy=1)
250 | t = ts.first()
251 | sim_tree_summary = TopologySummary.TopologySummary(t)
252 | sim_tree_ID = sim_tree_summary.get_topology_ID()
253 | sim_tree_newick = t.as_newick(include_branch_lengths=False)
254 | print("Simulated tree:", sim_tree_newick)
255 |
256 | #now read it
257 | parsed_tree, node_labels = newick_to_sticcs_Tree(sim_tree_newick, leaf_idx_dict= dict([("n"+str(i), i) for i in range(n)]))
258 |
259 | parsed_tree_summary = TopologySummary.TopologySummary(parsed_tree)
260 | parsed_tree_ID = parsed_tree_summary.get_topology_ID()
261 |
262 | print("Parsed tree: ", parsed_tree.as_newick(node_labels=node_labels))
263 |
264 | print("Match:", parsed_tree_ID == sim_tree_ID)
265 |
266 |
267 |
268 | # Test the parser
269 | if __name__ == "__main__":
270 |
271 | #Manual test
272 | #newick_string = '((n1,n4),((n0,n3),(n2,n5)));'
273 | #parsed_tree, node_labels = newick_to_sticcs_Tree(newick_string, leaf_idx_dict={"n0":0, "n1":1, "n2":2, "n3":3, "n4":4, "n5": 5})
274 | #print(parsed_tree.as_newick(node_labels=node_labels))
275 |
276 | #newick_string = '((1,4),((0,3),(2,5)));'
277 | #parsed_tree, node_labels = newick_to_sticcs_Tree(newick_string)
278 | #print(parsed_tree.as_newick())
279 |
280 | #Auto test multiple times
281 | sim_test_newick_parser()
282 |
283 |
--------------------------------------------------------------------------------
/twisst2/TopologySummary.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import defaultdict
3 | from collections import deque
4 | import itertools, random
5 |
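# A NodeChain represents a path through a tree: a deque of node IDs, optionally
# with the distances between successive nodes. Root-to-leaf chains are built
# first, and pairs of them are fused (fuseLeft) to give leaf-to-leaf chains.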
6 | class NodeChain(deque):
7 | def __init__(self, nodeList, dists=None):
8 | super(NodeChain, self).__init__(nodeList)
9 | if dists is None: self.dists = None
10 | else:
11 | assert len(dists) == len(self)-1, "incorrect number of internode distances"
12 | self.dists = deque(dists)
13 | self._set_ = None
14 |
15 | def addNode(self, name, dist=0):
16 | self.append(name)
17 | if self.dists is not None: self.dists.append(dist)
18 |
19 | def addNodeLeft(self, name, dist=0):
20 | self.appendleft(name)
21 | if self.dists is not None: self.dists.appendleft(dist)
22 |
23 | def addNodeChain(self, chainToAdd, joinDist=0):
24 | self.extend(chainToAdd)
25 | if self.dists is not None:
26 | assert chainToAdd.dists is not None, "Cannot add a chain without distances to one with distances"
27 | self.dists.append(joinDist)
28 | self.dists.extend(chainToAdd.dists)
29 |
30 | def addNodeChainLeft(self, chainToAdd, joinDist=0):
31 | self.extendleft(chainToAdd)
32 | if self.dists is not None:
33 | assert chainToAdd.dists is not None, "Cannot add a chain without distances to one with distances"
34 | self.dists.appendleft(joinDist)
35 | self.dists.extendleft(chainToAdd.dists)
36 |
37 | def chopLeft(self):
38 | self.popleft()
39 | if self.dists is not None: self.dists.popleft()
40 |
41 | def chop(self):
42 | self.pop()
43 | if self.dists is not None: self.dists.pop()
44 |
45 | def fuseLeft(self, chainToFuse):
46 | new = NodeChain(self, self.dists)
47 | assert new[0] == chainToFuse[0], "No common nodes"
48 | i = 1
49 | while new[1] == chainToFuse[i]:
50 | new.chopLeft()
51 | i += 1
52 | m = len(chainToFuse)
53 | while i < m:
54 | new.addNodeLeft(chainToFuse[i], chainToFuse.dists[i-1] if self.dists is not None else None)
55 | i += 1
56 | return new
57 |
58 | def simplifyToEnds(self, newDist=None):
59 | if self.dists is not None:
60 | if not newDist: newDist = sum(self.dists)
61 | self.dists.clear()
62 | leftNode = self.popleft()
63 | rightNode = self.pop()
64 | self.clear()
65 | self.append(leftNode)
66 | self.append(rightNode)
67 | if self.dists is not None:
68 | self.dists.append(newDist)
69 |
70 | def setSet(self):
71 | self._set_ = set(self)
72 |
73 |
74 | def getChainsToLeaves(tree, node=None, simplifyDict = None):
75 | if node is None: node = tree.root
76 | children = tree.children(node) #this syntax for tskit trees
77 | #children = tree.node_children[node]
78 | if len(children) == 0:
79 | #if it has no children it is a leaf
80 | #if it's in the simplifyDict, or there is no simplifyDict,
81 | #just record a weight for the node and return it as a new 1-node chain
82 | if simplifyDict is None or node in simplifyDict:
83 | chain = NodeChain([node])
84 | setattr(chain, "weight", 1)
85 | return [chain]
86 | else:
87 | return []
88 | #otherwise get chains for all children
89 | childrenChains = [getChainsToLeaves(tree, child, simplifyDict) for child in children]
90 | #now we have the chains from all children, we need to add the current node
91 | for childChains in childrenChains:
92 | for chain in childChains: chain.addNodeLeft(node)
93 |
94 | #if collapsing, check groups for each node
95 | if simplifyDict:
96 | nodeGroupsAll = np.array([simplifyDict[chain[-1]] for childChains in childrenChains for chain in childChains])
97 | nodeGroups = list(set(nodeGroupsAll))
98 | nGroups = len(nodeGroups)
99 |
100 | if (nGroups == 1 and len(nodeGroupsAll) > 1) or (nGroups == 2 and len(nodeGroupsAll) > 2):
101 | #all chains end in a leaf from one or two groups, so we can simplify.
102 | #first list all chains
103 | chains = [chain for childChains in childrenChains for chain in childChains]
104 | #Start by getting index of each chain for each group
105 | indices = [(nodeGroupsAll == group).nonzero()[0] for group in nodeGroups]
106 | #the new weight for each chain we keep will be the total node weight of all from each group
107 | newWeights = [sum([chains[i].weight for i in idx]) for idx in indices]
108 | #now reduce to just a chain for each group
109 | chains = [chains[idx[0]] for idx in indices]
110 | for j in range(nGroups):
111 | chains[j].simplifyToEnds()
112 | chains[j].weight = newWeights[j]
113 |
114 | #if we couldn't simply collapse completely, we might still be able to merge down a side branch
115 | #Side branches are child chains ending in a single leaf
116 | #If there is a lower level child branch that is itself a side branch, we can merge to it
117 | elif (len(childrenChains) == 2 and
118 | ((len(childrenChains[0]) == 1 and len(childrenChains[1]) > 1) or
119 | (len(childrenChains[1]) == 1 and len(childrenChains[0]) > 1))):
120 | chains,sideChain = (childrenChains[1],childrenChains[0][0]) if len(childrenChains[0]) == 1 else (childrenChains[0],childrenChains[1][0])
121 | #now check if any main chain is suitable (it should be length 3, the only one of that length, and end in a leaf from the correct group)
122 | targets = (np.array([len(chain) for chain in chains]) == 3).nonzero()[0]
123 | if len(targets) == 1 and simplifyDict[chains[targets[0]][-1]] == simplifyDict[sideChain[-1]]:
124 | #we have found a suitable internal chain to merge to
125 | targetChain = chains[targets[0]]
126 | newWeight = targetChain.weight + sideChain.weight
127 | targetChain.simplifyToEnds()
128 | targetChain.weight = newWeight
129 | else:
130 | #if we didn't find a suitable match, just add side chain
131 | chains.append(sideChain)
132 | else:
133 | #if there was no side chain, just list all chains
134 | chains = [chain for childChains in childrenChains for chain in childChains]
135 | #otherwise we are not collapsing, so just list all chains
136 | else:
137 | chains = [chain for childChains in childrenChains for chain in childChains]
138 |
139 |
140 | return chains
141 |
142 |
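# Yield combinations with one leaf per group. If the total number of
# combinations is no more than max_subtrees, enumerate them all exhaustively;
# otherwise draw max_subtrees combinations at random.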
143 | def get_leaf_combos(leaf_groups, max_subtrees):
144 | total = np.prod([len(t) for t in leaf_groups])
145 | if total <= max_subtrees:
146 | for combo in itertools.product(*leaf_groups):
147 | yield combo
148 | else:
149 | for i in range(max_subtrees):
150 | yield tuple(random.choice(group) for group in leaf_groups)
151 |
152 | def pairsDisjoint(pairOfPairs):
153 | if pairOfPairs[0][0] in pairOfPairs[1] or pairOfPairs[0][1] in pairOfPairs[1]: return False
154 | return True
155 |
156 |
157 | pairs_generic = dict([(n, tuple(itertools.combinations(range(n),2))) for n in range(4,8)])
158 |
159 | pairPairs_generic = dict([(n, [pairPair for pairPair in itertools.combinations(pairs_generic[n],2) if pairsDisjoint(pairPair)],) for n in range(4,8)])
160 |
161 |
162 | class TopologySummary:
163 | #Summarises a tree as a set of NodeChain objects
164 | #Provides a convenient way to summarise the topology as a unique ID
165 | #If group data is provided, tree is only summarised in terms of chains
166 | # between nodes from distinct groups, and redundant chains (e.g. monophyletic clades)
167 | # are simplified and can be weighted accordingly with recorded leaf weights
168 |
169 | def __init__(self, tree, simplifyDict=None):
170 | self.root = tree.root
171 | self.simplifyDict = simplifyDict
172 | chains = getChainsToLeaves(tree, simplifyDict=self.simplifyDict)
173 | self.leafWeights = {}
174 | self.chains = {}
175 | #record each root-to-leaf chain, keyed by its end nodes, along with the leaf's weight
176 | for chain in chains:
177 | self.chains[(chain[0], chain[-1])] = chain
178 | self.leafWeights[chain[-1]] = chain.weight
179 |
180 | self.leavesRetained = set(self.leafWeights.keys())
181 | #now make chains for all pairs of leaves (or all pairs in separate groups if defined)
182 | if self.simplifyDict:
183 | leafPairs = [pair for pair in itertools.combinations(self.leavesRetained,2) if self.simplifyDict[pair[0]] != self.simplifyDict[pair[1]]]
184 | else:
185 | leafPairs = [pair for pair in itertools.combinations(self.leavesRetained,2)]
186 |
187 | for l0,l1 in leafPairs:
188 | self.chains[(l0,l1)] = self.chains[(tree.root,l0)].fuseLeft(self.chains[(tree.root,l1)])
189 |
190 | #add a reversed entry for each chain, and set the set for each one
191 | for pair in list(self.chains.keys()):
192 | self.chains[pair].setSet()
193 | self.chains[pair[::-1]] = self.chains[pair]
194 |
195 | #add weight 1 for root (do this now only because we don't want to include the root in the leavesRetained set)
196 | self.leafWeights[tree.root] = 1
197 |
198 | #def get_topology_ID(self, leaves, unrooted=False):
199 | #if leaves is None: leaves = sorted(self.leavesRetained)
200 | #if not unrooted: leaves = list(leaves) + [self.root]
201 | #n = len(leaves)
202 | #return tuple([self.chains[(leaves[pairs[0][0]],
203 | #leaves[pairs[0][1]],)]._set_.isdisjoint(self.chains[(leaves[pairs[1][0]],
204 | #leaves[pairs[1][1]],)]._set_)
205 | #for pairs in pairPairs_generic[n]])
206 |
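# Quartet topology IDs: for leaves (q0,q1,q2,q3), ID 0 means the q0-q1 path is
# disjoint from the q2-q3 path (i.e. the split q0,q1 | q2,q3), ID 1 means the
# split q0,q2 | q1,q3, ID 2 means the split q0,q3 | q1,q2, and ID 3 means no
# pair of paths is disjoint (the four leaves meet at a polytomy).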
207 | def get_quartet_ID(self, quartet):
208 | if self.chains[(quartet[0],quartet[1],)]._set_.isdisjoint(self.chains[(quartet[2],quartet[3],)]._set_):
209 | return 0
210 | elif self.chains[(quartet[0],quartet[2],)]._set_.isdisjoint(self.chains[(quartet[1],quartet[3],)]._set_):
211 | return 1
212 | elif self.chains[(quartet[0],quartet[3],)]._set_.isdisjoint(self.chains[(quartet[1],quartet[2],)]._set_):
213 | return 2
214 | return 3
215 |
216 | #def get_all_quartet_IDs(self, leaves=None, unrooted=False):
217 | #if leaves is None: leaves = sorted(self.leavesRetained)
218 | #if not unrooted: leaves = list(leaves) + [self.root]
219 | #return [self.get_quartet_ID(quartet) for quartet in itertools.combinations(leaves, 4)]
220 |
221 | def get_all_quartet_IDs(self, leaves=None, unrooted=False):
222 | if leaves is None: leaves = sorted(self.leavesRetained)
223 | if unrooted:
224 | return [self.get_quartet_ID(quartet) for quartet in itertools.combinations(leaves, 4)]
225 | else:
226 | return [self.get_quartet_ID(trio + (self.root,)) for trio in itertools.combinations(leaves, 3)]
227 |
228 | def get_topology_ID(self, leaves=None, unrooted=False):
229 | return tuple(self.get_all_quartet_IDs(leaves, unrooted))
230 |
231 | def get_topology_counts(self, leaf_groups, max_subtrees, unrooted=False):
232 | _leaf_groups = [[leaf for leaf in group if leaf in self.leavesRetained] for group in leaf_groups]
233 |
234 | if self.simplifyDict is not None:
235 | total = np.prod([len(g) for g in _leaf_groups])
236 | assert total <= max_subtrees, f"With groups {_leaf_groups}, there will be {total} subtrees, but max_subtrees is set to {max_subtrees}, so you need to turn off tree simplification or increase max_subtrees."
237 |
238 | leaf_combos = get_leaf_combos(_leaf_groups, max_subtrees)
239 | counts = defaultdict(int)
240 | for combo in leaf_combos:
241 | comboWeight = np.prod([self.leafWeights[leaf] for leaf in combo])
242 | ID = self.get_topology_ID(combo, unrooted)
243 | counts[ID] += comboWeight
244 |
245 | return counts
246 |
247 | def get_quartet_dist(tree1, tree2, unrooted=False, approximation_subset_size=None):
248 | topoSummary1 = TopologySummary(tree1)
249 | topoSummary2 = TopologySummary(tree2)
250 | leaves = list(topoSummary1.leavesRetained)
251 |
252 | if approximation_subset_size is None:
253 | quartetIDs1 = np.array(topoSummary1.get_all_quartet_IDs(unrooted=unrooted))
254 | quartetIDs2 = np.array(topoSummary2.get_all_quartet_IDs(unrooted=unrooted))
255 | else:
256 | #do approximate distance with random sets of quartets
257 | quartetIDs1 = np.zeros(approximation_subset_size, dtype=int)
258 | quartetIDs2 = np.zeros(approximation_subset_size, dtype=int)
259 |
260 | for i in range(approximation_subset_size):
261 | if unrooted:
262 | quartet = random.sample(leaves, 4)
263 | quartetIDs1[i] = topoSummary1.get_quartet_ID(quartet)
264 | quartetIDs2[i] = topoSummary2.get_quartet_ID(quartet)
265 |
266 | else:
267 | trio = random.sample(leaves, 3)
268 | quartetIDs1[i] = topoSummary1.get_quartet_ID(trio + [tree1.root])
269 | quartetIDs2[i] = topoSummary2.get_quartet_ID(trio + [tree2.root])
270 |
271 | dif = quartetIDs1 - quartetIDs2
272 | return np.mean(dif != 0)
273 |
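# Illustrative usage of get_quartet_dist: the returned distance is the
# proportion of quartets (or rooted trios) whose topology differs between the
# two trees, e.g.
#   d = get_quartet_dist(tree1, tree2)                                  # exact
#   d = get_quartet_dist(tree1, tree2, approximation_subset_size=1000)  # random subset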
274 |
275 | def get_min_quartet_dist(tree1, tree2, inds, max_itr=10, unrooted=False):
276 | topoSummary1 = TopologySummary(tree1)
277 | topoSummary2 = TopologySummary(tree2)
278 |
279 | #quartet IDs for tree 1 are unchanging
280 | quartetIDs1 = np.array(topoSummary1.get_all_quartet_IDs(unrooted=unrooted))
281 |
282 | #for tree 2 we try with different permutations for each individual
283 | new_inds = inds[:]
284 |
285 | for itr in range(max_itr):
286 |
287 | for i in range(len(inds)):
288 |
289 | ind_orderings = list(itertools.permutations(inds[i]))
290 |
291 | dists = []
292 |
293 | for ind_ordering in ind_orderings:
294 | current_inds = new_inds[:]
295 | current_inds[i] = ind_ordering
296 | quartetIDs2 = np.array(topoSummary2.get_all_quartet_IDs(leaves=[i for ind in current_inds for i in ind], unrooted=unrooted))
297 | dif = quartetIDs1 - quartetIDs2
298 | dists.append(np.mean(dif != 0))
299 |
300 | new_inds[i] = ind_orderings[np.argmin(dists)]
301 |
302 | if itr > 0 and new_inds == previous_new_inds: break
303 | else:
304 | previous_new_inds = new_inds[:]
305 |
306 | #get final dist
307 | quartetIDs2 = np.array(topoSummary2.get_all_quartet_IDs(leaves=[i for ind in new_inds for i in ind], unrooted=unrooted))
308 | dif = quartetIDs1 - quartetIDs2
309 |
310 | return np.mean(dif != 0)
311 |
312 |
--------------------------------------------------------------------------------
/plot_twisst/plot_twisst.R:
--------------------------------------------------------------------------------
1 | simple.loess.predict <- function(x, y, span, new_x=NULL, weights = NULL, max = NULL, min = NULL, family=NULL){
2 | y.loess <- loess(y ~ x, span = span, weights = weights, family=family)
3 | if (is.null(new_x)) {y.predict <- predict(y.loess,x)}
4 | else {y.predict <- predict(y.loess,new_x)}
5 | if (is.null(min) == FALSE) {y.predict = ifelse(y.predict > min, y.predict, min)}
6 | if (is.null(max) == FALSE) {y.predict = ifelse(y.predict < max, y.predict, max)}
7 | y.predict
8 | }
9 |
10 | smooth.df <- function(x, df, span, new_x = NULL, col.names=NULL, weights=NULL, min=NULL, max=NULL, family=NULL){
11 | if (is.null(new_x)) {smoothed <- df}
12 | else smoothed = df[1:length(new_x),]
13 | if (is.null(col.names)){col.names=colnames(df)}
14 | for (col.name in col.names){
15 | print(paste("smoothing",col.name))
16 | smoothed[,col.name] <- simple.loess.predict(x,df[,col.name],span = span, new_x = new_x, max = max, min = min, weights = weights, family=family)
17 | }
18 | smoothed
19 | }
20 |
21 | smooth.weights <- function(interval_positions, weights_dataframe, span, new_positions=NULL, interval_lengths=NULL){
22 | weights_smooth <- smooth.df(x=interval_positions,df=weights_dataframe, weights = interval_lengths,
23 | span=span, new_x=new_positions, min=0, max=1)
24 |
25 | #return rescaled to sum to 1
26 | weights_smooth <- weights_smooth / apply(weights_smooth, 1, sum)
27 |
28 | weights_smooth[is.na(weights_smooth)] <- 0
29 |
30 | weights_smooth
31 | }
32 |
33 |
34 | stack <- function(mat){
35 | upper <- t(apply(mat, 1, cumsum))
36 | lower <- upper - mat
37 | list(upper=upper,lower=lower)
38 | }
39 |
40 | interleave <- function(x1,x2){
41 | output <- vector(length= length(x1) + length(x2))
42 | output[seq(1,length(output),2)] <- x1
43 | output[seq(2,length(output),2)] <- x2
44 | output
45 | }
46 |
47 |
48 | sum_df_columns <- function(df, columns_list){
49 | new_df <- df[,0]
50 | for (x in 1:length(columns_list)){
51 | if (length(columns_list[[x]]) > 1) new_df[,x] <- apply(df[,columns_list[[x]]], 1, sum, na.rm=T)
52 | else new_df[,x] <- df[,columns_list[[x]]]
53 | if (is.null(names(columns_list)[x]) == FALSE) names(new_df)[x] <- names(columns_list)[x]
54 | }
55 | new_df
56 | }
57 |
58 |
59 | plot.weights <- function(weights_dataframe,positions=NULL,line_cols=NULL,fill_cols=NULL,density=NULL,lwd=1,xlim=NULL,ylim=c(0,1),stacked=FALSE,
60 | ylab="Weighting", xlab = "Position", main="",xaxt=NULL,yaxt=NULL,bty="n", add=FALSE){
61 | #get x axis
62 | x = positions
63 | #if a two-column matrix is given - plot step-like weights with start and end of each interval
64 | if (dim(as.matrix(x))[2]==2) {
65 | x = interleave(positions[,1],positions[,2])
66 | yreps=2
67 | }
68 | else {
69 | if (is.null(x)==FALSE) x = positions
70 | else x = 1:nrow(weights_dataframe)
71 | yreps=1
72 | }
73 |
74 | #set x limits
75 | if(is.null(xlim)) xlim = c(min(x), max(x))
76 |
77 | #if not adding to an old plot, make a new plot
78 | if (add==FALSE) plot(0, pch = "", xlim = xlim, ylim=ylim, ylab=ylab, xlab=xlab, main=main,xaxt=xaxt,yaxt=yaxt,bty=bty)
79 |
80 | if (stacked == TRUE){
81 | y_stacked <- stack(weights_dataframe)
82 | for (n in 1:ncol(weights_dataframe)){
83 | y_upper = rep(y_stacked[["upper"]][,n],each=yreps)
84 | y_lower = rep(y_stacked[["lower"]][,n],each = yreps)
85 | polygon(c(x,rev(x)),c(y_upper, rev(y_lower)), col = fill_cols[n], density=density[n], border=NA)
86 | }
87 | }
88 | else{
89 | for (n in 1:ncol(weights_dataframe)){
90 | y = rep(weights_dataframe[,n],each=yreps)
91 | polygon(c(x,rev(x)),c(y, rep(0,length(y))), col=fill_cols[n], border=NA,density=density[n])
92 | lines(x,y, type = "l", col = line_cols[n],lwd=lwd)
93 | }
94 | }
95 | }
96 |
97 | options(scipen = 7)
98 |
99 | #Here's a set of 15 colourful colours from https://en.wikipedia.org/wiki/Help:Distinguishable_colors
100 | topo_cols <- c(
101 | "#0075DC", #Blue
102 | "#2BCE48", #Green
103 | "#FFA405", #Orpiment
104 | "#5EF1F2", #Sky
105 | "#FF5005", #Zinnia
106 | "#005C31", #Forest
107 | "#00998F", #Turquoise
108 | "#FF0010", #Red
109 | "#9DCC00", #Lime
110 | "#003380", #Navy
111 | "#F0A3FF", #Amethyst
112 | "#740AFF", #Violet
113 | "#426600", #Quagmire
114 | "#C20088", #Mallow
115 | "#94FFB5") #Jade
116 |
117 |
118 |
119 |
120 | ########### Below are some more object-oriented tools for working with standard twisst output files
121 |
122 | library(ape)
123 | library(tools)
124 |
125 | #a function that imports topocounts and computes weights
126 | import.twisst <- function(topocounts_files, intervals_files=NULL, split_by_chrom=TRUE, reorder_by_start=FALSE, na.rm=TRUE, max_interval=Inf,
127 | lengths=NULL, topos_file=NULL, ignore_extra_columns=FALSE, min_subtrees=1, recalculate_mid=FALSE, names=NULL){
128 | l = list()
129 |
130 | if (length(intervals_files) > 1){
131 | print("Reading topocounts and interval data")
132 | l$interval_data <- lapply(intervals_files, read.table ,header=TRUE)
133 | l$topocounts <- lapply(topocounts_files, read.table, header=TRUE)
134 | if (is.null(names) == FALSE) names(l$interval_data) <- names(l$topocounts) <- names
135 | }
136 |
137 | if (length(intervals_files) == 1){
138 | print("Reading topocounts and interval data")
139 | l$interval_data <- list(read.table(intervals_files, header=TRUE))
140 | l$topocounts <- list(read.table(topocounts_files, header=TRUE))
141 | if (split_by_chrom == TRUE){
142 | l$topocounts <- split(l$topocounts[[1]], l$interval_data[[1]][,1])
143 | l$interval_data <- split(l$interval_data[[1]], l$interval_data[[1]][,1])
144 | }
145 | }
146 |
147 | if (is.null(intervals_files) == TRUE) {
148 | print("Reading topocounts")
149 | l$topocounts <- lapply(topocounts_files, read.table, header=TRUE)
150 | n <- nrow(l$topocounts[[1]])
151 | l$interval_data <- list(data.frame(chrom=rep(0,n), start=1:n, end=1:n))
152 | if (is.null(names) == FALSE) names(l$interval_data) <- names
153 | }
154 |
155 | l$n_regions <- length(l$topocounts)
156 |
157 | if (is.null(names(l$interval_data)) == TRUE) {
158 | names(l$interval_data) <- names(l$topocounts) <- paste0("region", 1:l$n_regions)
159 | }
160 |
161 | print(paste("Number of regions:", l$n_regions))
162 |
163 | if (reorder_by_start==TRUE & is.null(intervals_files) == FALSE){
164 | print("Reordering")
165 | orders = sapply(l$interval_data, function(df) order(df[,2]), simplify=FALSE)
166 | l$interval_data <- sapply(names(orders), function(x) l$interval_data[[x]][orders[[x]],], simplify=F)
167 | l$topocounts <- sapply(names(orders), function(x) l$topocounts[[x]][orders[[x]],], simplify=F)
168 | }
169 |
170 | print("Getting topologies")
171 |
172 | #attempt to retrieve topologies
173 | l$topos=NULL
174 | #first, check if a topologies file is provided
175 | if (is.null(topos_file) == FALSE) {
176 | l$topos <- read.tree(file=topos_file)
177 | if (is.null(names(l$topos)) == TRUE) names(l$topos) <- names(l$topocounts[[1]])
178 | }
179 | else{
180 | #otherwise we try to retrieve topologies from the (first) topocounts file
181 | n_topos = ncol(l$topocounts[[1]]) - 1
182 | topos_text <- read.table(topocounts_files[1], nrow=n_topos, comment.char="", sep="\t", as.is=T)[,1]
183 | try(l$topos <- read.tree(text = topos_text))
184 | try(names(l$topos) <- sapply(names(l$topos), substring, 2))
185 | }
186 |
187 | print("Cleaning data")
188 |
189 | if (ignore_extra_columns == TRUE & is.null(l$topos)==FALSE) {
190 | for (i in 1:l$n_regions){
191 | #columns that are unwanted
192 | l$topocounts[[i]] <- l$topocounts[[i]][,1:length(l$topos)]
193 | }
194 | }
195 |
196 | if (na.rm==TRUE){
197 | for (i in 1:l$n_regions){
198 | #remove rows containing NA values
199 | row_sums = apply(l$topocounts[[i]],1,sum)
200 | good_rows = which(is.na(row_sums) == F & row_sums >= min_subtrees &
201 | l$interval_data[[i]]$end - l$interval_data[[i]]$start + 1 <= max_interval)
202 | l$topocounts[[i]] <- l$topocounts[[i]][good_rows,]
203 | l$interval_data[[i]] = l$interval_data[[i]][good_rows,]
204 | }
205 | }
206 |
207 | print("Computing summaries")
208 |
209 | l$weights <- sapply(l$topocounts, function(raw) raw/apply(raw, 1, sum), simplify=FALSE)
210 |
211 | l$weights_mean <- t(sapply(l$weights, apply, 2, mean, na.rm=T))
212 |
213 | #weighting per region as a total. This will be used for getting the overall mean
214 | weights_totals <- apply(t(sapply(l$weights, apply, 2, sum, na.rm=T)), 2, sum)
215 |
216 | l$weights_overall_mean <- weights_totals / sum(weights_totals)
217 |
218 | if (is.null(lengths) == TRUE) l$lengths <- sapply(l$interval_data, function(df) tail(df$end,1), simplify=TRUE)
219 | else l$lengths = lengths
220 |
221 |
222 | for (i in 1:length(l$interval_data)) {
223 | if (is.null(l$interval_data[[i]]$mid) == TRUE | recalculate_mid == TRUE) {
224 | l$interval_data[[i]]$mid <- (l$interval_data[[i]]$start + l$interval_data[[i]]$end)/2
225 | }
226 | }
227 |
228 | l$pos=sapply(l$interval_data, function(df) df$mid, simplify=FALSE)
229 |
230 | l
231 | }
232 |
233 |
234 | smooth.twisst <- function(twisst_object, span=0.05, span_bp=NULL, new_positions = NULL, spacing=NULL) {
235 | l=list()
236 |
237 | l$topos <- twisst_object$topos
238 |
239 | l$n_regions <- twisst_object$n_regions
240 |
241 | l$weights <- list()
242 |
243 | l$lengths = twisst_object$lengths
244 |
245 | l$pos <- list()
246 |
247 | for (i in 1:l$n_regions){
248 | if (is.null(span_bp) == FALSE) span <- span_bp/twisst_object$lengths[[i]]
249 |
250 | if (is.null(new_positions) == TRUE){
251 | if (is.null(spacing) == TRUE) spacing <- twisst_object$lengths[[i]]*span*.1
252 | new_positions_i <- seq(twisst_object$pos[[i]][1], tail(twisst_object$pos[[i]],1), spacing)
253 | }
254 | else new_positions_i <- new_positions
255 | l$pos[[i]] <- new_positions_i
256 |
257 | interval_lengths <- twisst_object$interval_data[[i]][,3] - twisst_object$interval_data[[i]][,2] + 1
258 |
259 | l$weights[[i]] <- smooth.weights(twisst_object$pos[[i]], twisst_object$weights[[i]], new_positions = new_positions_i, span = span, interval_lengths = interval_lengths)
260 | }
261 |
262 | names(l$weights) <- names(twisst_object$weights)
263 |
264 | l
265 | }
266 |
267 | is.hex.col <- function(string){
268 | strvec <- strsplit(string, "")[[1]]
269 | if (strvec[1] != "#") return(FALSE)
270 | if (length(strvec) != 7 & length(strvec) != 9) return(FALSE)
271 | for (character in strvec[-1]){
272 | if (!(character %in% c("0","1","2","3","4","5","6","7","8","9","a","b","c","d","e","f","A","B","C","D","E","F"))) return(FALSE)
273 | }
274 | TRUE
275 | }
276 |
277 | hex.transparency <- function(hex, transstring="88"){
278 | if (is.hex.col(hex)==FALSE){
279 | print("WARNING: colour not hexadecimal. Cannot modify transparency.")
280 | return(hex)
281 | }
282 | if (nchar(hex) == 7) return(paste0(hex, transstring))
283 | else {
284 | substr(hex,8,9) <- transstring
285 | return(hex)
286 | }
287 | }
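
# e.g. hex.transparency("#FF0010", "80") returns "#FF001080",
# a semi-transparent version of the same colour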
288 |
289 |
290 | plot.twisst <- function(twisst_object, show_topos=TRUE, rel_height=3, ncol_topos=NULL, regions=NULL, ncol_weights=1,
291 | cols=topo_cols, tree_type="clad", tree_x_lim=c(0,5), xlim=NULL, ylim=NULL, xlab=NULL, xaxt=NULL,
292 | mode=2, margins = c(4,4,2,2), concatenate=FALSE, gap=0, include_region_names=FALSE){
293 |
294 | #check if there are enough colours
295 | if (length(twisst_object$topos) > length(cols)){
296 | print("Not enough colours provided (option 'cols'), using rainbow instead")
297 | cols = rainbow(length(twisst_object$topos))
298 | }
299 |
300 | if (mode==5) {
301 | stacked=TRUE
302 | weights_order=1:length(twisst_object$topos)
303 | fill_cols = cols
304 | line_cols = NA
305 | lwd = 0
306 | }
307 |
308 | if (mode==4) {
309 | stacked=FALSE
310 | weights_order=order(apply(twisst_object$weights_mean, 2, mean, na.rm=T)[1:length(twisst_object$topos)], decreasing=T) #so that the largest values are at the back
311 | fill_cols = cols
312 | line_cols = NA
313 | lwd=par("lwd")
314 | }
315 |
316 | if (mode==3) {
317 | stacked=FALSE
318 | weights_order=order(apply(twisst_object$weights_mean, 2, mean, na.rm=T)[1:length(twisst_object$topos)], decreasing=T) #so that the largest values are at the back
319 | fill_cols = sapply(cols, hex.transparency, transstring="80")
320 | line_cols = cols
321 | lwd=par("lwd")
322 | }
323 |
324 | if (mode==2) {
325 | stacked=FALSE
326 | weights_order=order(apply(twisst_object$weights_mean, 2, mean, na.rm=T)[1:length(twisst_object$topos)], decreasing=T) #so that the largest values are at the back
327 | fill_cols = sapply(cols, hex.transparency, transstring="80")
328 | line_cols = NA
329 | lwd=par("lwd")
330 | }
331 |
332 | if (mode==1) {
333 | stacked=FALSE
334 | weights_order=1:length(twisst_object$topos)
335 | fill_cols = NA
336 | line_cols = cols
337 | lwd=par("lwd")
338 | }
339 |
340 | if (is.null(regions)==TRUE) regions <- 1:twisst_object$n_regions
341 |
342 | if (concatenate == TRUE) ncol_weights <- 1
343 |
344 | if (show_topos==TRUE){
345 | n_topos <- length(twisst_object$topos)
346 |
347 | if (is.null(ncol_topos)) ncol_topos <- n_topos
348 |
349 | #if we have too few topologies to fill the spaces in the plot, we can pad in the remainder
350 | topos_pad <- (n_topos * ncol_weights) %% (ncol_topos*ncol_weights)
351 |
352 | topos_layout_matrix <- matrix(c(rep(1:n_topos, each=ncol_weights), rep(0, topos_pad)),
353 | ncol=ncol_topos*ncol_weights, byrow=T)
354 | }
355 | else {
356 | ncol_topos <- 1
357 | n_topos <- 0
358 | topos_layout_matrix <- matrix(NA, nrow= 0, ncol=ncol_topos*ncol_weights)
359 | }
360 |
361 | #if we have too few regions to fill the spaces in the plot, we pad in the remainder
362 | data_pad <- (length(regions)*ncol_topos) %% (ncol_topos*ncol_weights)
363 |
364 | if (concatenate==TRUE) weights_layout_matrix <- matrix(rep(n_topos+1,ncol_topos), nrow=1)
365 | else {
366 | weights_layout_matrix <- matrix(c(rep(n_topos+(1:length(regions)), each=ncol_topos),rep(0,data_pad)),
367 | ncol=ncol_topos*ncol_weights, byrow=T)
368 | }
369 |
370 | layout(rbind(topos_layout_matrix, weights_layout_matrix),
371 | height=c(rep(1, nrow(topos_layout_matrix)), rep(rel_height, nrow(weights_layout_matrix))))
372 |
373 | if (show_topos == TRUE){
374 | if (tree_type=="unrooted"){
375 | par(mar=c(1,2,3,2), xpd=NA)
376 |
377 | for (i in 1:n_topos){
378 | plot.phylo(twisst_object$topos[[i]], type = tree_type, edge.color=cols[i],
379 | edge.width=5, label.offset=0.3, cex=1, rotate.tree = 90, x.lim=tree_x_lim)
380 | mtext(side=3,text=paste0("topo",i), cex=0.75)
381 | }
382 | }
383 | else{
384 | par(mar=c(1,1,1,1), xpd=NA)
385 |
386 | for (i in 1:n_topos){
387 | plot.phylo(twisst_object$topos[[i]], type = tree_type, edge.color=cols[i],
388 | edge.width=5, label.offset=0.3, cex=1, rotate.tree = 0, x.lim=tree_x_lim)
389 | mtext(side=3,text=paste0("topo",i," "), cex=0.75)
390 | }
391 | }
392 | }
393 |
394 | if (is.null(ylim)==TRUE) ylim <- c(0,1)
395 |
396 | par(mar=margins, xpd=FALSE)
397 |
398 | if (concatenate == TRUE) {
399 | chrom_offsets = cumsum(twisst_object$lengths + gap) - (twisst_object$lengths + gap)
400 | chrom_ends <- chrom_offsets + twisst_object$lengths
401 |
402 | plot(0, pch = "", xlim = c(chrom_offsets[1],tail(chrom_ends,1)), ylim = ylim,
403 | ylab = "", yaxt = "n", xlab = xlab, xaxt = "n", bty = "n", main = "")
404 |
405 | for (j in regions) {
406 | if (is.null(twisst_object$interval_data[[j]])) positions <- twisst_object$pos[[j]] + chrom_offsets[j]
407 | else positions <- twisst_object$interval_data[[j]][,c("start","end")] + chrom_offsets[j]
408 | plot.weights(twisst_object$weights[[j]][weights_order], positions, xlim=xlim,
409 | fill_cols = fill_cols[weights_order], line_cols=line_cols[weights_order],lwd=lwd,stacked=stacked, add=T)
410 | }
411 | }
412 | else{
413 | for (j in regions){
414 | if (is.null(twisst_object$interval_data[[j]])) positions <- twisst_object$pos[[j]]
415 | else positions <- twisst_object$interval_data[[j]][,c("start","end")]
416 | plot.weights(twisst_object$weights[[j]][weights_order], positions, xlim=xlim, ylim = ylim,
417 | xlab=xlab, fill_cols = fill_cols[weights_order], line_cols=line_cols[weights_order],lwd=lwd,stacked=stacked, xaxt=xaxt)
418 | if (include_region_names==TRUE) mtext(3, text=names(twisst_object$weights)[j], adj=0, cex=0.75)
419 | }
420 | }
421 | }
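    |
    | #Usage sketch (not run): assumes 'twisst_data' is a twisst object built earlier
    | #in this script, with $topos, $weights, $pos/$interval_data and $n_regions.
    | #plot.twisst(twisst_data, mode=3, ncol_topos=3)
    | #plot.twisst(twisst_data, mode=2, concatenate=TRUE, gap=1e5)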
422 |
423 |
424 | #function for plotting a tree, using ape to get node positions
425 | draw.tree <- function(phy, x, y, x_scale=1, y_scale=1, method=1, direction="right",
426 | col="black", col.label="black", add_labels=TRUE, add_symbols=FALSE,
427 | label_offset = 1, symbol_offset=0, col.symbol="black",symbol_bg="NA",
428 | pch=19, cex=NULL, lwd=NULL, label_alias=NULL){
429 |
430 | n_tips = length(phy$tip.label)
431 |
432 | if (direction=="right") {
433 | node_x = (node.depth(phy, method=method) - 1) * x_scale * -1
434 | node_y = node.height(phy) * y_scale
435 | label_x = node_x[1:n_tips] + label_offset
436 | label_y = node_y[1:n_tips]
437 | adj_x = 0
438 | adj_y = .5
439 | symbol_x = node_x[1:n_tips] + symbol_offset
440 | symbol_y = node_y[1:n_tips]
441 | }
442 | if (direction=="down") {
443 | node_y = (node.depth(phy, method=method) - 1) * y_scale * 1
444 | node_x = node.height(phy) * x_scale
445 | label_x = node_x[1:n_tips]
446 | label_y = node_y[1:n_tips] - label_offset
447 | adj_x = .5
448 | adj_y = 1
449 | symbol_x = node_x[1:n_tips]
450 | symbol_y = node_y[1:n_tips] - symbol_offset
451 | }
452 |
453 | #draw edges
454 | segments(x + node_x[phy$edge[,1]], y + node_y[phy$edge[,1]],
455 | x + node_x[phy$edge[,2]], y + node_y[phy$edge[,2]], col=col, lwd=lwd)
456 |
457 | if (is.null(label_alias) == FALSE) tip_labels <- label_alias[phy$tip.label]
458 | else tip_labels <- phy$tip.label
459 |
460 | if (add_labels==TRUE) text(x + label_x, y + label_y, col = col.label, labels=tip_labels, adj=c(adj_x,adj_y),cex=cex)
461 | if (add_symbols==TRUE) points(x + symbol_x, y + symbol_y, pch = pch, col=col.symbol, bg=symbol_bg)
462 |
463 | }
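    |
    | #Usage sketch (not run; hypothetical values): draw a topology onto an open plot.
    | #plot(0, cex=0, xlim=c(0,5), ylim=c(0,1), axes=FALSE, xlab="", ylab="")
    | #draw.tree(twisst_data$topos[[1]], x=1, y=0, x_scale=0.3, y_scale=0.2, col="blue")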
464 |
465 | #code for plotting a summary barplot
466 | plot.twisst.summary <- function(twisst_object, order_by_weights=TRUE, only_best=NULL, cols=topo_cols,
467 | x_scale=0.12, y_scale=0.15, direction="right", col="black", col.label="black",
468 | label_offset = 0.05, lwd=NULL, cex=NULL){
469 |
470 | #check if there are enough colours
471 | if (length(twisst_object$topos) > length(cols)){
472 | print("Not enough colours provided (option 'cols'), using rainbow instead")
473 | cols = rainbow(length(twisst_object$topos))
474 | }
475 |
476 | # Either keep the original topology order, or put the most highly weighted topology first
477 |
478 | if (order_by_weights == TRUE) {
479 | ord <- order(twisst_object$weights_overall_mean[1:length(twisst_object$topos)], decreasing=T)
480 | if (is.null(only_best) == FALSE) ord=ord[1:only_best]
481 | }
482 | else ord <- 1:length(twisst_object$topos)
483 |
484 | N=length(ord)
485 |
486 | #set the plot layout, with the tree panel one third the height of the barplot panel
487 | layout(matrix(c(2,1)), heights=c(1,3))
488 |
489 | par(mar = c(1,4,.5,1))
490 |
491 | #make the barplot
492 | x=barplot(twisst_object$weights_overall_mean[ord], col = cols[ord],
493 | xaxt="n", las=1, ylab="Average weighting", space = 0.2, xlim = c(0.2, 1.2*N))
494 |
495 | #draw the trees
496 | #first make an empty plot for the trees. Ensure left and right margins are the same
497 | par(mar=c(0,4,0,1))
498 | plot(0,cex=0,xlim = c(0.2, 1.2*N), xaxt="n",yaxt="n",xlab="",ylab="",ylim=c(0,1), bty="n")
499 |
500 | #now run the draw.tree function for each topology. You can set x_scale and y_scale to alter the tree width and height.
501 | for (i in 1:length(ord)){
502 | draw.tree(twisst_object$topos[[ord[i]]], x=x[i]+.2, y=0, x_scale=x_scale, y_scale=y_scale,
503 | col=cols[ord[i]], label_offset=label_offset, cex=cex, lwd=lwd)
504 | }
505 |
506 | #add labels for each topology
507 | text(x,.9,names(twisst_object$topos)[ord],col=cols[ord])
508 | }
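    |
    | #Usage sketch (not run, assuming 'twisst_data' as above):
    | #plot.twisst.summary(twisst_data, only_best=5, cex=0.8)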
509 |
510 |
511 | #code for plotting a summary boxplot
512 | plot.twisst.summary.boxplot <- function(twisst_object, order_by_weights=TRUE, only_best=NULL, trees_below=FALSE, cols=topo_cols,
513 | x_scale=0.12, y_scale=0.15, direction="right", col="black", col.label="black",
514 | label_offset = 0.05, lwd=NULL, label_alias=NULL, cex=NULL, outline=FALSE,
515 | cex.outline=NULL, lwd.box=NULL, topo_names=NULL){
516 |
517 | #check if there are enough colours
518 | if (length(twisst_object$topos) > length(cols)){
519 | print("Not enough colours provided (option 'cols'), using rainbow instead")
520 | cols = rainbow(length(twisst_object$topos))
521 | }
522 |
523 | # Either keep the original topology order, or put the most highly weighted topology first
524 |
525 | if (order_by_weights == TRUE) {
526 | ord <- order(twisst_object$weights_overall_mean[1:length(twisst_object$topos)], decreasing=T)
527 | if (is.null(only_best) == FALSE) ord=ord[1:only_best]
528 | }
529 | else ord <- 1:length(twisst_object$topos)
530 |
531 | N=length(ord)
532 |
533 | #set the plot layout, with the tree panel one third the height of the barplot panel
534 | if (trees_below==FALSE) layout(matrix(c(2,1)), heights=c(1,3))
535 | else layout(matrix(c(1,2)), heights=c(3,1))
536 |
537 |
538 | par(mar = c(1,4,.5,1), bty="n")
539 |
540 | #make the boxplot
541 | boxplot(as.data.frame(rbindlist(twisst_object$weights))[,ord], col = cols[ord],
542 | xaxt="n", las=1, xlim = c(.5, N+.5), ylab="Weighting", outline=outline, cex=cex.outline, lwd=lwd.box)
543 |
544 | #draw the trees
545 | #first make an empty plot for the trees. Ensure left and right margins are the same
546 | par(mar=c(0,4,0,1))
547 | plot(0,cex=0, xlim = c(.5, N+.5), xaxt="n",yaxt="n",xlab="",ylab="",ylim=c(0,1), bty="n")
548 |
549 | #now run the draw.tree function for each topology. You can set x_scale and y_scale to alter the tree width and height.
550 | for (i in 1:N){
551 | draw.tree(twisst_object$topos[[ord[i]]], i+.1, y=0, x_scale=x_scale, y_scale=y_scale,
552 | col=cols[ord[i]], label_offset=label_offset, cex=cex, lwd=lwd, label_alias=label_alias)
553 | }
554 |
555 | if (is.null(topo_names)==TRUE) topo_names <- names(twisst_object$topos)
556 |
557 | #add labels for each topology
558 | text(1:N,.9,topo_names[ord],col=cols[ord], cex=cex)
559 | }
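    |
    | #Usage sketch (not run, assuming 'twisst_data' as above):
    | #plot.twisst.summary.boxplot(twisst_data, trees_below=TRUE, outline=FALSE)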
560 |
561 | #function for subsetting the twisst object by a set of topologies
562 | subset.twisst.by.topos <- function(twisst_object, topos){
563 | l <- list()
564 | regions <- names(twisst_object$weights)
565 | l$interval_data <- twisst_object$interval_data
566 | l$n_regions <- twisst_object$n_regions
567 | l$lengths <- twisst_object$lengths
568 | l$pos <- twisst_object$pos
569 | l$topocounts <- sapply(regions, function(region) twisst_object$topocounts[[region]][,topos], simplify=F)
570 | l$weights <- sapply(regions, function(region) twisst_object$weights[[region]][,topos], simplify=F)
571 | l$weights_mean <- twisst_object$weights_mean[,topos]
572 | l$weights_overall_mean <- twisst_object$weights_overall_mean[topos]
573 | l$topos <- twisst_object$topos[topos]
574 | l
575 | }
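    |
    | #Usage sketch (not run): keep only the three most highly weighted topologies.
    | #best3 <- order(twisst_data$weights_overall_mean, decreasing=TRUE)[1:3]
    | #twisst_data_top3 <- subset.twisst.by.topos(twisst_data, best3)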
576 |
577 | #function for subsetting the twisst object by specific regions
578 | subset.twisst.by.regions <- function(twisst_object, regions){
579 | l <- list()
580 | regions <- names(twisst_object$weights[regions])
581 | l$interval_data <- twisst_object$interval_data[regions]
582 | l$n_regions <- length(regions)
583 | l$lengths <- twisst_object$lengths[regions]
584 | l$pos <- twisst_object$pos[regions]
585 | l$topocounts <- twisst_object$topocounts[regions]
586 | l$weights <- twisst_object$weights[regions]
587 | l$weights_mean <- twisst_object$weights_mean[regions]
588 | weights_totals <- apply(t(sapply(l$weights, apply, 2, sum, na.rm=T)), 2, sum)
589 | l$weights_overall_mean <- weights_totals / sum(weights_totals)
590 | l$topos <- twisst_object$topos
591 | l
592 | }
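    |
    | #Usage sketch (not run): e.g. subset.twisst.by.regions(twisst_data, 1:2) keeps
    | #only the first two regions and recomputes the overall mean weights.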
593 |
594 | rbindlist <- function(l){
595 | df <- l[[1]]
596 | if (length(l) > 1){
597 | for (i in 2:length(l)){
598 | df <- rbind(df, l[[i]])
599 | }
600 | }
601 | df
602 | }
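    |
    | #rbindlist(l) is equivalent to do.call(rbind, l), e.g.
    | #rbindlist(list(data.frame(x=1), data.frame(x=2))) gives a two-row data frame.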
603 |
604 |
605 | palettes = list(
606 | sashamaps = c( #https://sashamaps.net/docs/resources/20-colors/
607 | '#4363d8', #azure blue
608 | '#3cb44b', #grass green
609 | '#ffe119', #custard yellow
610 | '#e6194b', #pomegranate red
611 | '#f58231', #chalk orange
612 | '#911eb4', #violet
613 | '#46f0f0', #cyan
614 | '#f032e6', #hot pink
615 | '#000075', #navy blue
616 | '#fabebe', #pale blush
617 | '#008080', #teal
618 | '#e6beff', #mauve
619 | '#9a6324', #coconut brown
620 | '#800000', #red leather
621 | '#aaffc3', #light sage
622 | '#808000', #olive
623 | '#ffd8b1', #white skin
624 | '#bcf60c', #green banana
625 | '#808080', #grey
626 | '#fffac8'), #beach sand
627 |
628 | alphabet = c( #https://en.wikipedia.org/wiki/Help:Distinguishable_colors
629 | "#0075DC", #Blue
630 | "#FFA405", #Orpiment
631 | "#2BCE48", #Green
632 | "#003380", #Navy
633 | "#F0A3FF", #Amethyst
634 | "#990000", #Wine
635 | "#9DCC00", #Lime
636 | "#8F7C00", #Khaki
637 | "#FFE100", #Yellow
638 | "#C20088", #Mallow
639 | "#FFCC99", #Honeydew
640 | "#5EF1F2", #Sky
641 | "#00998F", #Turquoise
642 | "#FF5005", #Zinnia
643 | "#4C005C", #Damson
644 | "#426600", #Quagmire
645 | "#005C31", #Forest
646 | "#FFA8BB", #Pink
647 | "#740AFF", #Violet
648 | "#E0FF66", #Uranium
649 | "#993F00", #Caramel
650 | "#94FFB5", #Jade
651 | "#FFFF80", #Xanthin
652 | "#FF0010", #Red
653 | "#191919", #Ebony
654 | "#808080"), #Iron
655 |
656 |
657 | krzywinski = c( #https://mk.bcgsc.ca/colorblind/palettes.mhtml#15-color-palette-for-colorbliness
658 | "#68023F", #104 2 63 imperial purple, tyrian purple, nightclub, pompadour, ribbon, deep cerise, mulberry, mulberry wood, pansy purple, merlot
659 | "#008169", # 0 129 105 deep sea, generic viridian, observatory, deep sea, elf green, deep green cyan turquoise, tropical rain forest, tropical rain forest, blue green, elf green
660 | "#EF0096", #239 0 150 vivid cerise, persian rose, fashion fuchsia, hollywood cerise, neon pink, luminous vivid cerise, shocking pink, deep pink, deep pink, fluorescent pink
661 | "#00DCB5", # 0 220 181 aquamarine, aqua marine, caribbean green, aqua, eucalyptus, vivid opal, brilliant turquoise, shamrock, shamrock, caribbean green
662 | "#FFCFE2", #255 207 226 light pink, azalea, classic rose, pale pink, classic rose, pastel pink, pink lace, pig pink, orchid pink, chantilly
663 | "#003C86", # 0 60 134 flat medium blue, royal blue, bay of many, bondi blue, congress blue, cobalt, elvis, darkish blue, submerge, yale blue
664 | "#9400E6", #148 0 230 vivid purple, violet, purple, veronica, purple, electric purple, purple, vivid mulberry, vivid purple, vivid violet
665 | "#009FFA", # 0 159 250 azure, luminous vivid cornflower blue, vivid cornflower blue, brilliant azure, bleu de france, dark sky blue, brilliant azure, united nations blue, light brilliant cobalt blue, cornflower
666 | "#FF71FD", #255 113 253 blush pink, shocking pink, ultra pink, pink flamingo, fuchsia pink, light brilliant orchid, pink flamingo, candy pink, light brilliant magenta, light magenta
667 | "#7CFFFA", #124 255 250 light brilliant cyan, electric blue, dark slate grey, very light cyan, very light opal, bright cyan, dark slate grey, brilliant cyan, light brilliant opal, bright light blue
668 | "#6A0213", #106 2 19 burnt crimson, rosewood, claret, venetian red, dark tan, persian plum, prune, deep amaranth, deep reddish brown, crown of thorns
669 | "#008607", # 0 134 7 india green, ao, green, office green, web green, green, emerald green, islamic green, forest green, forest green
670 | "#F60239", #246 2 57 american rose, red, carmine, electric crimson, luminous vivid crimson, neon red, carmine red, scarlet, torch red, tractor red
671 | "#00E307", # 0 227 7 vivid sap green, vivid emerald green, vivid green, vibrant green, vivid green, vivid harlequin, green, radioactive green, vivid malachite green, vivid pistachio
672 | "#FFDC3D"), #255 220 61 gargoyle gas, filmpro lemon yellow, banana yellow, golden dream, bright sun, broom, banana split, golden dream, twentyfourseven, wild thing
673 |
674 |
675 | safe = c( #from cartomap package
676 | "#88CCEE",
677 | "#CC6677",
678 | "#DDCC77",
679 | "#117733",
680 | "#332288",
681 | "#AA4499",
682 | "#44AA99",
683 | "#999933",
684 | "#882255",
685 | "#661100",
686 | "#6699CC",
687 | "#888888"),
688 |
689 | cud = c(
690 | "#E69F00",
691 | "#56B4E9",
692 | "#009E73",
693 | "#F0E442",
694 | "#0072B2",
695 | "#D55E00",
696 | "#CC79A7",
697 | "#000000",
698 | "#999999"))
699 |
700 | show_palettes <- function(){
701 | par(mfrow = c(length(palettes), 1), mar = c(2,0,2,0))
702 |
703 | for (name in names(palettes)){
704 | n = length(palettes[[name]])
705 | plot(0, cex=0, xlim = c(1,n+1), ylim = c(0,1), xaxt="n", yaxt="n", bty="n", main=name)
706 | rect(1:n, 0, 2:(n+1), 1, col=palettes[[name]])
707 | axis(1, at=(1:n)+0.5, labels = palettes[[name]], tick=F)
708 | }
709 | }
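    |
    | #Usage sketch (not run): show_palettes() previews all palettes; any of them can be
    | #passed to the plotting functions above, e.g. plot.twisst(twisst_data, cols=palettes$safe)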
710 |
711 |
--------------------------------------------------------------------------------
/twisst2/twisst2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import random
4 | import argparse
5 | import gzip
6 | import sys
7 | import itertools
8 | import numpy as np
9 | from twisst2.TopologySummary import *
10 | from twisst2.parse_newick import *
11 | from sticcs import sticcs
12 | import cyvcf2
13 |
14 | try: import tskit
15 | except ImportError:
16 | pass
17 |
18 | def get_combos(groups, max_subtrees, ploidies, random_seed=None):
19 | total_subtrees = np.prod([ploidies[group].sum() for group in groups])
20 | if total_subtrees <= max_subtrees:
21 | for combo in itertools.product(*groups):
22 | yield combo
23 | else:
24 | #doing only a subset, so need to keep track of how many subtrees have been done
25 | subtrees_done = 0
26 | #random number generator for repeatability
27 | rng = random.Random(random_seed)
28 | while True:
29 | combo = [rng.choice(group) for group in groups]
30 | yield tuple(combo)
31 | subtrees_done += np.prod(ploidies[combo])
32 | if subtrees_done >= max_subtrees: break
33 |
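    | # Usage sketch (not run; hypothetical values): three groups of two diploid samples.
    | # ploidies = np.array([2]*6)
    | # groups = [[0, 1], [2, 3], [4, 5]]
    | # for combo in get_combos(groups, max_subtrees=8, ploidies=ploidies, random_seed=42):
    | #     print(combo)  # one sample index per group, e.g. (0, 2, 4)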
34 |
35 | ##older code from when we were specifying the number of combinations (which results in different total subtrees depending on ploidy)
36 | #else:
37 | #assert max_combos, "Please specify either maximum number of combos or subtrees"
38 | #total_combos = np.prod([len(t) for t in groups])
39 | #if total_combos <= max_combos:
40 | #for combo in itertools.product(*groups):
41 | #yield combo
42 | #else:
43 | #for i in range(max_combos):
44 | #yield tuple(random.choice(group) for group in groups)
45 |
46 |
47 | ### The code below has been retired. It was used for stacking topocounts with non-identical sets of intervals, but thanks to Claude.ai, a faster way to do this was devised.
48 |
49 | #def contains_interval(intervals, interval):
50 | #return (intervals[:,0] <= interval[0]) & (interval[1] <= intervals[:,1])
51 |
52 |
53 | #def get_unique_intervals(intervals, verbose=True):
54 |
55 | #intervals.sort(axis=1)
56 | #intervals = intervals[intervals[:,1].argsort()]
57 | #intervals = intervals[intervals[:,0].argsort()]
58 |
59 | #intervals = np.unique(intervals, axis=0)
60 |
61 | #starts = intervals[:,0]
62 | #starts = np.unique(starts)
63 | #n_starts = starts.shape[0]
64 |
65 | #ends = intervals[:,1]
66 | #ends.sort()
67 | #ends = np.unique(ends)
68 | #n_ends = ends.shape[0]
69 |
70 | #current_start = starts[0]
71 | #next_start_idx = 1
72 | #next_end_idx = 0
73 | #next_start = starts[next_start_idx]
74 | #next_end = ends[next_end_idx]
75 |
76 | #output = []
77 |
78 | #while True:
79 | #if next_start and next_start <= next_end:
80 | #if next_start == current_start:
81 | ##if we get here, the previous end was exactly 1 before the next start, so we jump to the next start
82 | #next_start_idx += 1
83 | #next_start = starts[next_start_idx] if next_start_idx < n_starts else None
84 | #else:
85 | ##we have to cut before this next start and start again
86 | #output.append((current_start, next_start-1))
87 | #current_start = next_start
88 | #next_start_idx += 1
89 | #next_start = starts[next_start_idx] if next_start_idx < n_starts else None
90 | #else:
91 | ##otherwise we use the end and start at next position
92 | #output.append((current_start, next_end))
93 | #current_start = next_end + 1
94 | #next_end_idx += 1
95 | #if next_end_idx == n_ends: break
96 | #next_end = ends[next_end_idx]
97 |
98 | ##retain only those that are contained within the starting intervals
99 | #output = [interval for interval in output if np.any(contains_interval(intervals, interval))]
100 |
101 | #return np.array(output)
102 |
103 |
104 | def show_tskit_tree(tree, node_labels=None): print(tree.draw(format="unicode", node_labels=node_labels))
105 |
106 | def is_bifurcating(tree):
107 | for node in tree.nodes():
108 | if len(tree.children(node)) > 2:
109 | return False
110 | return True
111 |
112 | def number_of_possible_trees(n_tips, include_polytomies=False):
113 | i = 0
114 | for tree in tskit.all_trees(n_tips):
115 | if include_polytomies or is_bifurcating(tree): i += 1
116 | return i
117 |
118 | def list_all_topologies_tskit(n, include_polytomies = False):
119 | assert n <= 7, "There are 135135 rooted topologies with 8 tips (660032 including polytomies). I doubt you want to list all of those."
120 | if include_polytomies: return list(tskit.all_trees(n))
121 | return [tree for tree in tskit.all_trees(n) if is_bifurcating(tree)]
122 |
123 | def all_possible_patterns(n, lower=2, upper=None):
124 | if upper: assert upper <= n
125 | else: upper=n
126 | for k in range(lower,upper+1):
127 | for indices in list(itertools.combinations(range(n), k)):
128 | pattern = np.zeros(n, dtype=int)
129 | pattern[(indices,)] = 1
130 | yield pattern
131 |
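    | # Usage sketch: for n=3, all_possible_patterns(3) yields, in order,
    | # [1,1,0], [1,0,1], [0,1,1] (k=2) and then [1,1,1] (k=3).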
132 |
133 | def contains_array(arrayA, arrayB):
134 | assert arrayA.shape[1] == arrayB.shape[1]
135 | if arrayA.shape[0] < arrayB.shape[0]: return False
136 | for b in arrayB:
137 | contained = False
138 | for a in arrayA:
139 | if np.all(a == b): contained = True
140 | if not contained:
141 | return False
142 | return True
143 |
144 | def all_possible_clusters(n, lower=2, upper=None):
145 | if upper: assert upper <= n
146 | else: upper=n-1
147 | ploidies = np.array([1]*n)
148 | #make first set of cluster pairs
149 | clusters = []
150 | size = 0
151 | while True:
152 | size += 1
153 | found_one = False
154 | for new_cluster in itertools.combinations(all_possible_patterns(n, lower, upper), size):
155 | fail=False
156 | new_cluster = np.array(new_cluster)
157 | for i in range(1,size):
158 | for j in range(i):
159 | if not sticcs.passes_derived_gamete_test(new_cluster[i],new_cluster[j], ploidies):
160 | fail = True
161 | break
162 | if fail:
163 | break
164 |
165 | if not fail:
166 | found_one = True
167 | #check if an old cluster is contained within the new one
168 | clusters = [cluster for cluster in clusters if not contains_array(new_cluster, cluster)] + [new_cluster]
169 |
170 | if not found_one:
171 | return clusters
172 |
173 | def list_all_topologies(n):
174 | assert n <= 7, "There are 135135 rooted topologies with 8 tips (660032 including polytomies). I doubt you want to list all of those."
175 | clusters = all_possible_clusters(n)
176 | topos = [sticcs.patterns_to_tree(cluster) for cluster in clusters]
177 | return topos
178 |
179 | def make_topoDict(n, unrooted=False):
180 | topos = []
181 | topoIDs = []
182 | for topo in list_all_topologies(n):
183 | topoSummary = TopologySummary(topo)
184 | ID = topoSummary.get_topology_ID(list(range(n)), unrooted=unrooted)
185 | if ID not in topoIDs: #this is because multiple topos are the same when unrooted, so we just use the first
186 | topoIDs.append(ID)
187 | topos.append(topo)
188 |
189 | return {"topos":topos, "topoIDs":topoIDs}
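    |
    | # Usage sketch: for four groups, make_topoDict(4)["topos"] should contain the 15
    | # rooted topologies (or 3 when unrooted=True, since rooted duplicates collapse).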
190 |
191 | def make_topoDict_tskit(n, include_polytomies=False):
192 | topos = list_all_topologies_tskit(n, include_polytomies=include_polytomies)
193 | ranks = [t.rank() for t in topos]
194 | return {"topos":topos, "ranks":ranks}
195 |
196 |
197 | class Topocounts:
198 | def __init__(self, topos, counts, totals=None, intervals=None, label_dict=None, rooted=None):
199 | assert counts.shape[1] == len(topos)
200 | if intervals is not None:
201 | assert intervals.shape[1] == 2
202 | assert counts.shape[0] == intervals.shape[0], f"There are {counts.shape[0]} rows of counts but {intervals.shape[0]} intervals."
203 | self.topos = topos
204 | self.counts = counts
205 | if totals is not None:
206 | assert np.all(totals >= self.counts.sum(axis=1).round(3))
207 | self.totals = totals
208 | else: self.totals = counts.sum(axis=1)
209 | self.intervals = intervals
210 | self.label_dict = label_dict
211 | self.rooted = rooted
212 |
213 | def unroot(self):
214 | #returns a Topocounts instance reduced to unrooted version of each topology
215 | topoSummaries = [TopologySummary(topo) for topo in self.topos]
216 |
217 | topoIDs_unrooted = [topoSummary.get_topology_ID(unrooted=True) for topoSummary in topoSummaries]
218 |
219 | indices = defaultdict(list)
220 | for i,ID in enumerate(topoIDs_unrooted): indices[ID].append(i)
221 |
222 | counts_unrooted = np.column_stack([self.counts[:,idx].sum(axis=1) for idx in indices.values()])
223 |
224 | new_topos = [self.topos[idx[0]] for idx in indices.values()]
225 |
226 | return Topocounts(new_topos, counts_unrooted, totals=self.totals, intervals=self.intervals, label_dict=self.label_dict, rooted=False)
227 |
228 | def fill_gaps(self, seq_start = None, seq_end = None):
229 |
230 | if seq_start is not None:
231 | assert seq_start <= self.intervals[0,0], "First interval starts before the stated start"
232 | self.intervals[0,0] = seq_start
233 |
234 | if seq_end is not None:
235 | assert seq_end >= self.intervals[-1,1], "Last interval ends after the stated end"
236 | self.intervals[-1,1] = seq_end
237 |
238 | for i in range(1, self.intervals.shape[0]):
239 | new_start = np.ceil((self.intervals[i-1,1] + self.intervals[i,0])/2)
240 | self.intervals[i,0] = new_start
241 | self.intervals[i-1,1] = new_start-1
242 |
243 | def simplify(self, fill_gaps=True, seq_start = None, seq_end = None):
244 | new_counts_list = [self.counts[0]]
245 | new_intervals_list = [self.intervals[0][:]]
246 | new_totals_list = [self.totals[0]]
247 |
248 | for j in range(1, self.intervals.shape[0]):
249 |
250 | if np.all(new_counts_list[-1] == self.counts[j]):
251 | new_intervals_list[-1][1] = self.intervals[j][-1]
252 | else:
253 | new_intervals_list.append(self.intervals[j][:])
254 | new_counts_list.append(self.counts[j])
255 | new_totals_list.append(self.totals[j])
256 |
257 | simplified = Topocounts(self.topos, np.array(new_counts_list, dtype=int),
258 | totals=np.array(new_totals_list, dtype=int),
259 | intervals=np.array(new_intervals_list), label_dict=self.label_dict, rooted=self.rooted)
260 |
261 | if fill_gaps:
262 | simplified.fill_gaps(seq_start=seq_start, seq_end=seq_end)
263 |
264 | return simplified
265 |
266 | def split_intervals(self, split_len=None, new_intervals=None):
267 | #function to split topocounts into a narrower set of intervals
268 | #if new intervals are given, it will be split at those points and return any splits between those points
269 |
270 | if split_len is not None:
271 | new_intervals = np.array([(x, x+split_len-1) for x in range(1, int(self.intervals[-1][1]), split_len)])
272 | else: assert new_intervals is not None, "Either provide a split length or new intervals to define the split locations"
273 |
274 | ncol = self.counts.shape[1]
275 | nrow = new_intervals.shape[0]
276 |
277 | dummy_topocounts = Topocounts(topos =self.topos, counts=np.zeros(shape=(nrow,ncol), dtype=int), intervals = new_intervals)
278 |
279 | return stack_topocounts([self, dummy_topocounts], silent=True)
280 |
281 | #def split_intervals(self, split_len=None, new_intervals=None):
282 | ##function to split topocounts into a narrower set of intervals
283 | ##if new intervals are given, it will be split at those points and return any splits between those points
284 |
285 | #if split_len != None:
286 | #new_intervals = np.array([(x, x+split_len-1) for x in range(1, int(self.intervals[-1][1]), split_len)])
287 | #else: assert new_intervals is not None, "Either provide a split length or new intervals to define the split locations"
288 |
289 | #unique_intervals = get_unique_intervals(np.row_stack([self.intervals, new_intervals]))
290 |
291 | #n_unique_intervals = len(unique_intervals)
292 |
293 | #_counts_ = np.zeros((n_unique_intervals, self.counts.shape[1]), dtype=int)
294 | #_totals_ = np.zeros(n_unique_intervals, dtype=int)
295 |
296 | ##for each iteration we check each of its intervals
297 | #k=0
298 | #max_k = self.intervals.shape[0]
299 | #for j in range(n_unique_intervals):
300 | ##for each of the unique intervals - if nested within the one of the starting intervals add it
301 | #if self.intervals[k,0] <= unique_intervals[j,0] and self.intervals[k,1] >= unique_intervals[j,1]:
302 | ##k contains j
303 | #_counts_[j,:] = self.counts[k]
304 | #_totals_[j] = self.totals[k]
305 |
306 | #if self.intervals[k,1] == unique_intervals[j,1]:
307 | ##both are ending, so advance k to the next interval in this iteration
308 | #k += 1
309 | #if k == max_k: break
310 |
311 | #return Topocounts(self.topos, _counts_, _totals_, unique_intervals, self.label_dict)
312 |
313 | def recast(self, interval_len=None, new_intervals=None):
314 | #function to project topocounts to a new set of intervals
315 | #only possible if the totals are the same for all intervals (i.e. no subsampling of combinations)
316 | #if you just want to split to smaller intervals use split_intervals
317 | #the resulting counts will be floats because they represent a weighted average
318 |
319 | assert len(set(self.totals)) == 1, "Cannot merge intervals with different total number of combinations."
320 |
321 | if interval_len is not None:
322 | new_intervals = np.array([(x, x+interval_len-1) for x in range(1, int(self.intervals[-1][1]), interval_len)])
323 | else: assert new_intervals is not None, "Either provide a split length or new intervals to define the split locations"
324 |
325 | n_new_intervals = len(new_intervals)
326 |
327 | #first make a split set of topocounts
328 | topocounts_split = self.split_intervals(new_intervals=new_intervals)
329 |
330 | _totals_ = np.array([self.totals[0]]*n_new_intervals)
331 |
332 | #now begin merging
333 | _counts_ = np.zeros((n_new_intervals, self.counts.shape[1]), dtype=float)
334 |
335 | k=0
336 | max_k = n_new_intervals
337 | current_indices = []
338 | for j in range(topocounts_split.intervals.shape[0]):
339 | #for each of the unique intervals - if nested within the new interval add the count to a bin
340 | if new_intervals[k,0] <= topocounts_split.intervals[j,0] and topocounts_split.intervals[j,1] <= new_intervals[k,1]:
341 | current_indices.append(j)
342 |
343 | if topocounts_split.intervals[j,1] == new_intervals[k,1]:
344 | #end of new interval record mean
345 | interval_lengths = np.diff(topocounts_split.intervals[current_indices], axis=1)[:,0] + 1 # weights will be the lengths
346 | _counts_[k] = np.average(topocounts_split.counts[current_indices], axis=0, weights=interval_lengths)
347 | current_indices = []
348 | k += 1
349 | if k == max_k: break
350 |
351 | return Topocounts(self.topos, _counts_, _totals_, new_intervals, self.label_dict)
352 |
353 |
354 | def get_interval_lengths(self):
355 | return np.diff(self.intervals, axis=1)[:,0] + 1
356 |
357 | def get_weights(self):
358 | return self.counts / self.totals.reshape((len(self.totals),1))
359 |
360 | def write(self, outfile, include_topologies=True, include_header=True):
361 | nTopos = len(self.topos)
362 | if include_topologies:
363 | for x in range(nTopos): outfile.write("#topo" + str(x+1) + " " + self.topos[x].as_newick(node_labels=self.label_dict) + "\n")
364 | if include_header:
365 | outfile.write("\t".join(["topo" + str(x+1) for x in range(nTopos)]) + "\tother\n")
366 | #write counts
367 | output_array = np.column_stack([self.counts, self.totals - self.counts.sum(axis=1)]).astype(str)
368 | outfile.write("\n".join(["\t".join(row) for row in output_array]) + "\n")
369 |
370 | def write_intervals(self, outfile, chrom="chr1", include_header=True):
371 | if include_header: outfile.write("chrom\tstart\tend\n")
372 | outfile.write("\n".join(["\t".join([chrom, str(self.intervals[i,0]), str(self.intervals[i,1])]) for i in range(len(self.totals))]) + "\n")
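    |
    | # Usage sketch (not run; 't1', 't2', 't3' stand for hypothetical tree objects):
    | # tc = Topocounts(topos=[t1, t2, t3],
    | #                 counts=np.array([[4, 0, 0], [1, 2, 1]]),
    | #                 intervals=np.array([[1, 100], [101, 250]]))
    | # tc.get_weights()  # rows sum to 1: [[1, 0, 0], [0.25, 0.5, 0.25]]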
373 |
374 | # this was the older, slower version. Now retired.
375 |
376 | #def stack_topocounts(topocounts_list, silent=False):
377 | ##get all unique intervals
378 | #unique_intervals = get_unique_intervals(np.row_stack([tc.intervals for tc in topocounts_list]))
379 |
380 | #n_unique_intervals = len(unique_intervals)
381 |
382 | #_counts_ = np.zeros((n_unique_intervals, topocounts_list[0].counts.shape[1]), dtype=int)
383 | #_totals_ = np.zeros(n_unique_intervals, dtype=int)
384 |
385 | ##for each iteration we check each of its intervals
386 | #for i in range(len(topocounts_list)):
387 | #if not silent: print(".", end="", file=sys.stderr, flush=True)
388 | #k=0
389 | #intervals = topocounts_list[i].intervals
390 | #max_k = intervals.shape[0]
391 | #for j in range(n_unique_intervals):
392 | ##for each of the unique intervals - if nested within the iteration's intervals, add to the stack
393 | #if intervals[k,0] <= unique_intervals[j,0] and intervals[k,1] >= unique_intervals[j,1]:
394 | ##k contains j
395 | #_counts_[j,:] += topocounts_list[i].counts[k]
396 | #_totals_[j] += topocounts_list[i].totals[k]
397 |
398 | #if intervals[k,1] == unique_intervals[j,1]:
399 | ##both are ending, so advance k to the next interval in this iteration
400 | #k += 1
401 | #if k == max_k: break
402 |
403 | #if not silent: print("\n", file=sys.stderr, flush=True)
404 |
405 | #return Topocounts(topocounts_list[0].topos, _counts_, _totals_, unique_intervals, topocounts_list[i].label_dict)
406 |
407 | #more efficient function from claude.ai
408 | def stack_topocounts(topocounts_list, silent=False):
409 | if len(topocounts_list) == 1:
410 | return topocounts_list[0]
411 |
412 | # Get number of columns from first object
413 | n_cols = len(topocounts_list[0].topos)
414 |
415 | # Step 1: Collect all breakpoints and create events
416 | events = [] # (position, event_type, tc_idx, interval_idx)
417 | # event_type: 0 = start, 1 = end
418 |
419 | for tc_idx, tc in enumerate(topocounts_list):
420 | for interval_idx, (start, end) in enumerate(tc.intervals):
421 | events.append((start, 0, tc_idx, interval_idx)) # start event
422 | events.append((end + 1, 1, tc_idx, interval_idx)) # end event (exclusive)
423 |
424 | # Sort events by position, then by type (starts before ends)
425 | events.sort(key=lambda x: (x[0], x[1]))
426 |
427 | # Step 2: Sweep through events to build merged intervals
428 | active_intervals = set() # Set of (tc_idx, interval_idx) tuples
429 | merged_intervals = []
430 | merged_counts = []
431 | merged_totals = []
432 |
433 | nevents = len(events)
434 |
435 | if not silent and nevents >= 100:
436 | report = True
437 | onePercent = int(np.ceil(nevents/100))
438 | else:
439 | report = False
440 |
441 | i = 0
442 | while i < nevents:
443 | current_pos = events[i][0]
444 |
445 | # Process all events at current position
446 | while i < nevents and events[i][0] == current_pos:
447 | pos, event_type, tc_idx, interval_idx = events[i]
448 |
449 | if event_type == 0: # start event
450 | active_intervals.add((tc_idx, interval_idx))
451 | else: # end event
452 | active_intervals.discard((tc_idx, interval_idx))
453 | i += 1
454 | if report and i % onePercent == 0: print(".", end="", file=sys.stderr, flush=True)
455 |
456 | # If we have active intervals, create a merged interval
457 | if active_intervals:
458 | # Find the next position where the active set changes
459 | next_pos = current_pos
460 | if i < nevents:
461 | next_pos = events[i][0]
462 | else:
463 | # No more events, extend to the last end position
464 | # Find the maximum end position among active intervals
465 | max_end = current_pos
466 | for tc_idx, interval_idx in active_intervals:
467 | end_pos = topocounts_list[tc_idx].intervals[interval_idx][1]
468 | max_end = max(max_end, end_pos)
469 | next_pos = max_end + 1
470 |
471 | # Create interval from current_pos to next_pos-1
472 | if next_pos > current_pos:
473 | interval_start = current_pos
474 | interval_end = next_pos - 1
475 |
476 | # Sum counts from all active intervals
477 | summed_counts = np.zeros(n_cols, dtype=topocounts_list[0].counts.dtype)
478 | summed_total = 0
479 | for tc_idx, interval_idx in active_intervals:
480 | summed_counts += topocounts_list[tc_idx].counts[interval_idx]
481 | summed_total += topocounts_list[tc_idx].totals[interval_idx]
482 |
483 | merged_intervals.append([interval_start, interval_end])
484 | merged_counts.append(summed_counts)
485 | merged_totals.append(summed_total)
486 |
487 | if report: print("", file=sys.stderr, flush=True)
488 |
489 | return Topocounts(topos=topocounts_list[0].topos, counts=np.array(merged_counts), totals=np.array(merged_totals), intervals=np.array(merged_intervals), label_dict=topocounts_list[0].label_dict)
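    |
    | # Usage sketch: stacking two Topocounts with intervals [1,100] and [51,150] yields
    | # merged intervals [1,50], [51,100], [101,150], with counts and totals summed over
    | # whichever inputs cover each merged interval.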
490 |
491 |
492 |
493 | def get_topocounts_tskit(ts, leaf_groups=None, group_names=None, topoDict=None, include_polytomies=False):
494 |
495 | if leaf_groups is None:
496 | #use populations from ts
497 | assert list(ts.populations()) != [], "Either specify groups or provide a treesequence with embedded population data."
498 | if group_names is None: group_names = [str(pop.id) for pop in ts.populations()]
499 | leaf_groups = [[s for s in ts.samples() if str(ts.get_population(s)) == t] for t in group_names]
500 |
501 | ngroups = len(leaf_groups)
502 | label_dict = dict(zip(range(ngroups), group_names if group_names else ("group"+str(i) for i in range(1, ngroups+1))))
503 |
504 | if not topoDict:
505 | topoDict = make_topoDict_tskit(ngroups, include_polytomies=include_polytomies)
506 |
507 | topos = topoDict["topos"]
508 | ranks = topoDict["ranks"]
509 |
510 | intervals = np.array([tree.interval for tree in ts.trees()], dtype= int if ts.discrete_genome else float)
511 |
512 | #intervals[:,0] +=1 #convert back to 1-based - I decided to keep this out of the function, because it should be optional to adjust that
513 |
514 | counts = np.zeros((ts.num_trees, len(ranks)), dtype=int)
515 |
516 | totals = np.zeros(ts.num_trees, dtype=int)
517 |
518 | counter_generator = ts.count_topologies(leaf_groups)
519 | #counter_generator = tskit.combinatorics.treeseq_count_topologies(ts, leaf_groups) #this seems no faster when tested Jan 2025
520 |
521 | key = tuple(range(ngroups)) #we only want to count subtree topologies with a tip for each of the leaf_groups
522 |
523 | for i, counter in enumerate(counter_generator):
524 | counts[i] = [counter[key][rank] for rank in ranks]
525 | totals[i] = sum(counter[key].values())
526 |
527 | return Topocounts(topos, counts, totals, intervals, label_dict)
528 |
529 |
530 | def get_topocounts(trees, leaf_groups, max_subtrees, simplify=True, group_names=None, topoDict=None, unrooted=False):
531 |
532 | ngroups = len(leaf_groups)
533 | label_dict = dict(zip(range(ngroups), group_names if group_names else ("group"+str(i) for i in range(1, ngroups+1))))
534 |
535 | if not topoDict:
536 | topoDict = make_topoDict(ngroups, unrooted)
537 |
538 | topos = topoDict["topos"]
539 | topoIDs = topoDict["topoIDs"]
540 |
541 | counts = []
542 |
543 | totals = []
544 |
545 | intervals = []
546 |
547 | leafGroupDict = makeGroupDict(leaf_groups) if simplify else None
548 |
549 | for i,tree in enumerate(trees):
550 | topoSummary = TopologySummary(tree, leafGroupDict)
551 | counts_dict = topoSummary.get_topology_counts(leaf_groups, max_subtrees=max_subtrees, unrooted=unrooted)
552 | counts.append([counts_dict[ID] for ID in topoIDs])
553 | totals.append(sum(counts_dict.values()))
554 | intervals.append(tree.interval)
555 |
556 | return Topocounts(topos, np.array(counts, dtype=int), np.array(totals, dtype=int), np.array(intervals, dtype=float), label_dict)
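    |
    | # Usage sketch (not run): assumes a tskit tree sequence 'ts' with six haploid samples.
    | # topocounts = get_topocounts(ts.trees(), leaf_groups=[[0, 1], [2, 3], [4, 5]],
    | #                             max_subtrees=8)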
557 |
558 |
559 | def get_topocounts_stacking_sticcs(der_counts, positions, ploidies, groups, max_subtrees, group_names=None,
560 | unrooted=False, second_chances=False, multi_pass=True,
561 | chrom_start=None, chrom_len=None, random_seed=None, silent=True):
562 |
563 | comboGenerator = get_combos(groups, max_subtrees = max_subtrees, ploidies=ploidies, random_seed=random_seed)
564 |
565 | topocounts_iterations = []
566 |
567 | for iteration,combo in enumerate(comboGenerator):
568 |
569 | if not silent:
570 | print(f"\nSample combo {iteration+1} indices: {', '.join([str(idx) for idx in combo])}", file=sys.stderr, flush=True)
571 | print(f"\nSample combo {iteration+1} will contribute {np.prod(ploidies[list(combo)])} subtree(s)", file=sys.stderr, flush=True)
572 | print(f"\nInferring tree sequence for combo {iteration+1}.", file=sys.stderr, flush=True)
573 |
574 | der_counts_sub = der_counts[:, combo]
575 |
576 | #ploidies
577 | ploidies_sub = ploidies[(combo,)]
578 |
579 | #Find non-missing genotypes for these individuals
580 | no_missing = np.all(der_counts_sub >= 0, axis=1)
581 |
582 | site_sum = der_counts_sub.sum(axis=1)
583 | variable = (1 < site_sum) & (site_sum < sum(ploidies_sub))
584 |
585 | usable_sites = np.where(no_missing & variable)
586 |
587 | #der counts and positions
588 | der_counts_sub = der_counts_sub[usable_sites]
589 |
590 | positions_sub = positions[usable_sites]
591 |
592 | patterns, matches, n_matches = sticcs.get_patterns_and_matches(der_counts_sub)
593 |
594 | clusters = sticcs.get_clusters(patterns, matches, positions_sub, ploidies=ploidies_sub, second_chances=second_chances,
595 | seq_start=chrom_start, seq_len=chrom_len, silent=True)
596 |
597 | trees = sticcs.infer_trees(patterns, ploidies_sub, clusters, multi_pass = multi_pass, silent=True)
598 |
599 | if not silent:
600 | print(f"\nCounting topologies for combo {iteration+1}.", file=sys.stderr, flush=True)
601 |
602 | topocounts = get_topocounts(trees, leaf_groups = make_numeric_groups(ploidies_sub),
603 | group_names=group_names, max_subtrees=max_subtrees, unrooted=unrooted) #here we specify max subtrees, but really each combination will usually have a much smaller number of subtrees than the overall max requested
604 |
605 | topocounts_iterations.append(topocounts.simplify())
606 |
607 | if not silent:
608 | print(f"\nStacking", file=sys.stderr)
609 |
610 | return stack_topocounts(topocounts_iterations, silent=silent)
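    |
    | # Usage sketch (not run): 'der_counts' is a sites x samples array of derived allele
    | # counts and 'positions' the matching positions, as parsed from a VCF below.
    | # tc = get_topocounts_stacking_sticcs(der_counts, positions, ploidies=np.array([2]*6),
    | #                                     groups=make_numeric_groups([2, 2, 2]),
    | #                                     max_subtrees=8, group_names=["A", "B", "C"])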
611 |
612 |
613 | def makeGroupDict(groups, names=None):
614 | groupDict = {}
615 | for x in range(len(groups)):
616 | for y in groups[x]: groupDict[y] = x if not names else names[x]
617 | return groupDict
618 |
619 |
620 | def make_numeric_groups(sizes):
621 | partitions = np.cumsum(sizes)[:-1]
622 | return np.split(np.arange(sum(sizes)), partitions)
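    |
    | # Usage sketches:
    | # makeGroupDict([["a", "b"], ["c"]]) -> {"a": 0, "b": 0, "c": 1}
    | # make_numeric_groups([2, 2, 3]) -> [array([0, 1]), array([2, 3]), array([4, 5, 6])]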
623 |
624 |
625 | ###############################################################################
626 | ###############################################################################
627 | # Functions below this point are specifically for command line stuff.
628 | # Probably not useful for the API
629 | # Could maybe be in a separate script
630 |
631 |
632 | def parse_groups_command_line(args):
633 | group_names = args.group_names
634 |
635 | ngroups = len(group_names)
636 |
637 | assert ngroups >= 3, "Please specify at least three groups."
638 |
639 | if args.groups:
640 | assert len(args.groups) == ngroups, "Number of groups does not match number of group names"
641 | groups = [g.split(",") for g in args.groups]
642 |
643 | elif args.groups_file:
644 | groups = [[] for i in range(ngroups)]
645 | with open(args.groups_file, "rt") as gf: groupDict = dict([ln.split() for ln in gf.readlines()])
646 | for sample in groupDict.keys():
647 | try: groups[group_names.index(groupDict[sample])].append(sample)
648 | except ValueError: pass
649 |
650 | if args.verbose:
651 | for i in range(ngroups):
652 | print(f"{group_names[i]}: {', '.join(groups[i])}\n", file=sys.stderr)
653 |
654 | assert min([len(g) for g in groups]) >= 1, "Please specify at least one sample ID per group."
655 |
656 | sampleIDs = [ID for group in groups for ID in group]
657 | assert len(sampleIDs) == len(set(sampleIDs)), "Each sample should only be in one group."
658 |
659 | label_dict = dict(zip(range(ngroups), group_names))
660 |
661 | topoDict = make_topoDict(ngroups, unrooted=args.unrooted)
662 |
663 | if args.output_topos:
664 | with open(args.output_topos, "wt") as tf:
665 | tf.write("\n".join([t.as_newick(node_labels=label_dict) for t in topoDict["topos"]]) + "\n")
666 |
667 | sys.stderr.write("\n".join([t.as_newick(node_labels=label_dict) for t in topoDict["topos"]]) + "\n")
668 |
669 | return (group_names, groups)
670 |
671 |
672 | def sticcstack_command_line(args):
673 |
674 | group_names, groups = parse_groups_command_line(args)
675 | groups_numeric = make_numeric_groups([len(g) for g in groups])
676 |
677 | sampleIDs = [ID for group in groups for ID in group]
678 |
679 | #VCF file
680 | vcf = cyvcf2.VCF(args.input_vcf, samples=sampleIDs)
681 |
682 | for ID in sampleIDs: assert ID in vcf.samples, f"ID {ID} not found in vcf file header line."
683 |
684 | #the dac vcf reader does not return the dac values in the requested sample order, so we record the sample indices
685 | sample_indices = [vcf.samples.index(ID) for ID in sampleIDs]
686 |
687 | ploidies, ploidyDict = sticcs.parsePloidyArgs(args, sampleIDs)
688 |
689 | ploidies = np.array(ploidies)
690 |
691 | dac_generator = sticcs.parse_vcf_with_DC_field(vcf)
692 |
693 | chromLenDict = dict(zip(vcf.seqnames, vcf.seqlens))
694 |
695 | print(f"\nReading first chromosome...", file=sys.stderr)
696 |
697 | for chrom, positions, der_counts in dac_generator:
698 |
699 | assert positions.max() <= chromLenDict[chrom], f"\tSNP at position {positions.max()} exceeds chromosome length {chromLenDict[chrom]} for {chrom}."
700 |
701 | if args.variant_range_only:
702 | chrom_start = positions[0]
703 | chrom_len = positions[-1]
704 | else:
705 | chrom_start = 1
706 | chrom_len = chromLenDict[chrom]
707 |
708 | print(f"\nAnalysing {chrom}. {positions.shape[0]} usable SNPs found.", file=sys.stderr)
709 |
710 | #correct sample order so that groups are together
711 | der_counts = der_counts[:,sample_indices]
712 |
713 | topocounts_stacked = get_topocounts_stacking_sticcs(der_counts, positions, ploidies=ploidies, groups=groups_numeric,
714 | group_names=group_names, max_subtrees=args.max_subtrees,
715 | unrooted=args.unrooted, multi_pass=not args.single_pass,
716 | second_chances = not args.no_second_chances,
717 | chrom_start=chrom_start, chrom_len=chrom_len,
718 | silent= not args.verbose)
719 |
720 | #could potentially add a step merging identical intervals here
721 | with gzip.open(args.out_prefix + "." + chrom + ".topocounts.tsv.gz", "wt") as outfile:
722 | topocounts_stacked.write(outfile)
723 |
724 | with gzip.open(args.out_prefix + "." + chrom + ".intervals.tsv.gz", "wt") as outfile:
725 | topocounts_stacked.write_intervals(outfile, chrom=chrom)
726 |
727 | print(f"\nEnd of file reached. Looks like my work is done.", file=sys.stderr)
728 |
729 |
730 | def trees_command_line(args):
731 |
732 | group_names, groups = parse_groups_command_line(args)
733 | groups_numeric = make_numeric_groups([len(g) for g in groups])
734 | leaf_names = [ID for group in groups for ID in group]
735 |
736 | max_subtrees = np.prod([len(g) for g in groups])
737 |
738 | if args.input_format == "tskit":
739 | ts = tskit.load(args.input_file)
740 |
741 | elif args.input_format == "argweaver":
742 | with gzip.open(args.input_file, "rt") if args.input_file.endswith(".gz") else open(args.input_file, "rt") as treesfile:
743 | ts = parse_argweaver_file(treesfile, leaf_names=leaf_names)
744 |
745 | elif args.input_format == "newick":
746 | with gzip.open(args.input_file, "rt") if args.input_file.endswith(".gz") else open(args.input_file, "rt") as treesfile:
747 | ts = parse_newick_file(treesfile, leaf_names=leaf_names)
748 |
749 | topocounts = get_topocounts(ts.trees(), leaf_groups=groups_numeric, max_subtrees=max_subtrees, unrooted=args.unrooted)
750 |
751 | with gzip.open(args.out_prefix + ".topocounts.tsv.gz", "wt") as outfile:
752 | topocounts.write(outfile)
753 |
754 | if (args.input_format == "tskit" or args.input_format == "argweaver"):
755 | with gzip.open(args.out_prefix + ".intervals.tsv.gz", "wt") as outfile:
756 | topocounts.write_intervals(outfile, chrom=args.chrom_name)
757 |
758 |
759 |
760 | def main():
761 | parser = argparse.ArgumentParser(prog="twisst2")
762 |
763 | subparsers = parser.add_subparsers(title = "subcommands", dest="mode")
764 | subparsers.required = True
765 |
766 | sticcstack_parser = subparsers.add_parser("sticcstack", help="Compute topology weights from input VCF using sticcstack method")
767 |
768 | sticcstack_parser.add_argument("-i", "--input_vcf", help="Input VCF file with DC field (make thsi with sticcs prep)", action = "store", required=True)
769 | sticcstack_parser.add_argument("-o", "--out_prefix", help="Output file prefix", action = "store", required=True)
770 | sticcstack_parser.add_argument("--unrooted", help="Unroot topologies (results in fewer topologies)", action = "store_true")
771 | sticcstack_parser.add_argument("--ploidy", help="Sample ploidy if all the same. Use --ploidy_file if samples differ.", action = "store", type=int)
772 | sticcstack_parser.add_argument("--ploidy_file", help="File with samples names and ploidy as columns", action = "store")
773 | sticcstack_parser.add_argument("--max_subtrees", help="Maximum number of subtrees to consider (note that each combination of diploids represents multiple subtrees)", action = "store", type=int, required=True)
774 | sticcstack_parser.add_argument("--no_second_chances", help="Do not consider SNPs that are separated by incompatible SNPs", action='store_true')
775 | sticcstack_parser.add_argument("--single_pass", help="Single pass when building trees (only relevant for ploidy > 1, but not recommended)", action='store_true')
776 | #sticcstack_parser.add_argument("--inputTopos", help="Input file for user-defined topologies (optional)", action = "store", required = False)
777 | sticcstack_parser.add_argument("--output_topos", help="Output file for topologies used", action = "store", required = False)
778 | sticcstack_parser.add_argument("--group_names", help="Name for each group (separated by spaces)", action='store', nargs="+", required = True)
779 | sticcstack_parser.add_argument("--groups", help="Sample IDs for each individual (separated by commas), for each group (separated by spaces)", action='store', nargs="+")
780 | sticcstack_parser.add_argument("--groups_file", help="Optional file with a column for sample ID and group", action = "store", required = False)
781 | sticcstack_parser.add_argument("--variant_range_only", help="Verbose output", action="store_true")
782 | sticcstack_parser.add_argument("--verbose", help="Verbose output", action="store_true")
783 |
784 | inputtrees_parser = subparsers.add_parser("trees", help="Input trees in tskit, argweaver or newick format")
785 |
786 | inputtrees_parser.add_argument("-i", "--input_file", help="Input trees file", action = "store", required=True)
787 | inputtrees_parser.add_argument("-f", "--input_format", help="Input file format (tskit, argweaver or newick)", choices=["tskit", "argweaver", "newick"], action = "store", required=True)
788 | inputtrees_parser.add_argument("-o", "--out_prefix", help="Output file prefix", action = "store", required=True)
789 | inputtrees_parser.add_argument("--chrom_name", help="Chromosome name for output intervals file", action = "store", default="unknown_chrom")
790 | inputtrees_parser.add_argument("--unrooted", help="Unroot topologies (results in fewer topologies)", action = "store_true")
791 | #inputtrees_parser.add_argument("--inputTopos", help="Input file for user-defined topologies (optional)", action = "store", required = False)
792 | inputtrees_parser.add_argument("--output_topos", help="Output file for topologies used", action = "store", required = False)
793 | inputtrees_parser.add_argument("--group_names", help="Name for each group (separated by spaces)", action='store', nargs="+", required = True)
794 | inputtrees_parser.add_argument("--groups", help="Sample IDs for each individual (separated by commas), for each group (separated by spaces)", action='store', nargs="+")
795 | inputtrees_parser.add_argument("--groups_file", help="Optional file with a column for sample ID and group", action = "store", required = False)
796 | inputtrees_parser.add_argument("--verbose", help="Verbose output", action="store_true")
797 |
798 | ### parse arguments
799 | args = parser.parse_args()
800 |
801 | if args.mode == "sticcstack":
802 | sticcstack_command_line(args)
803 |
804 | if args.mode == "trees":
805 | trees_command_line(args)
806 |
807 |
808 |
809 |
810 | if __name__ == '__main__':
811 | main()
812 |
--------------------------------------------------------------------------------