, ...], ]
16 | """
17 |
18 | # Padding units on the left/right sides of each cell
19 | padding = 1
20 |
    def __init__(self, variable: Variable, parents: List[str], table_rows: List):
        """
        Constructor for a ConditionalProbabilityTable.
        @param variable: The Variable on the LHS of the table (single-variable only)
        @param parents: A list of names of the parent variables forming the RHS/body of the table
        @param table_rows: Raw table data; each row is formatted as
            [outcome of variable, parent_1 outcome, parent_2 outcome, ..., probability]
        """
        self.variable = variable    # The LHS of the table, single-variable only
        self.parents = parents      # The RHS/body of the table

        self.table_rows = []

        # Number of trailing parents that have no corresponding column in the row data.
        # NOTE(review): this is computed from len(table_rows) (the number of ROWS), though the
        # intent appears to be the WIDTH of a row; also, when latent == 0, parents[:-latent]
        # below evaluates to parents[:0] == [] rather than the full parent list — confirm both.
        latent = len(parents) - (len(table_rows) - 2)

        # Clean up the rows; Each is formatted as: [outcome of variable, parent_1, parent_2, ..., probability]
        for row in table_rows:
            outcome = Outcome(variable.name, row[0])
            p = row[1:-1]  # the parent-outcome columns, between the variable outcome and the probability

            self.table_rows.append([outcome, [Outcome(v, x) for v, x in zip(parents[:-latent], p)], row[-1]])
35 |
36 | def __str__(self) -> str:
37 | """
38 | String builtin for a ConditionalProbabilityTable
39 | @return: A string representation of the table.
40 | """
41 |
42 | # Create a snazzy numpy table
43 | # Rows: 1 for a header + 1 for each row; Columns: 1 for variable, 1 for each given var, 1 for the probability
44 | rows = 1 + len(self.table_rows)
45 | columns = 1 + len(self.parents) + 1
46 |
47 | # dtype declaration is better than "str", as str only allows one character in each cell
48 | table = empty((rows, columns), dtype=' float:
87 | """
88 | Directly lookup the probability for the row corresponding to the queried outcome and given data
89 | @param outcome: The specific outcome to lookup
90 | @param given: A list of Outcome objects
91 | @return: A probability corresponding to the respective row. Raises an Exception otherwise.
92 | """
93 | for row_outcome, row_given, row_p in self.table_rows:
94 | # If the outcome for this row matches, and each outcome for the given data matches...
95 | if outcome == row_outcome and set(row_given) == set(given):
96 | return row_p # We have our answer
97 |
98 | # Iterated over all the rows and didn't find the correct one
99 | print(f"Couldn't find row: {outcome} | {', '.join(map(str, given))}")
100 | raise MissingTableRow
101 |
--------------------------------------------------------------------------------
/do/deconfounding/Do.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 | from loguru import logger
3 | from typing import Collection
4 |
5 | from ..core.Expression import Expression
6 | from ..core.Inference import inference
7 | from ..core.Model import Model
8 | from ..core.Variables import Outcome, Intervention
9 |
10 | from .Backdoor import backdoors, deconfound
11 | from .Exceptions import NoDeconfoundingSet
12 |
13 |
def treat(expression: Expression, interventions: Collection[Intervention], model: Model) -> float:
    """
    Compute the probability of a query under a collection of treatments / interventions.
    @param expression: The query to evaluate, P(head | body)
    @param interventions: A collection of Intervention objects, do(X=x)
    @param model: The Model supplying the causal graph and distributions
    @return: The probability of the query under the given interventions
    @raise NoDeconfoundingSet: If backdoor paths exist, but every deconfounding set overlaps the query body
    """

    head = set(expression.head())
    body = set(expression.body())

    # If there are no Interventions, we can compute a standard query
    if len(interventions) == 0:
        return inference(expression, model)

    # There are interventions; may need to find some valid Z to compute
    paths = backdoors(interventions, head, model.graph(), body)

    # No backdoor paths; augment graph space (interventions become roots) and compute directly
    if len(paths) == 0:
        logger.info("no backdoor paths; translating into standard inference query")
        expression_transform = Expression(expression.head(), set(expression.body()) | set(Outcome(x.name, x.outcome) for x in interventions))
        logger.info(f"translated expression: {expression_transform}")
        logger.info(f"disabling incoming edges on graph: {[x.name for x in interventions]}")
        model.graph().disable_incoming(*interventions)
        p = inference(expression_transform, model)
        logger.info("resetting edge transformations")
        model.graph().reset_disabled()
        return p

    # Backdoor paths found; find all possible deconfounding sets to compute with
    logger.info("computing deconfounding sets")
    deconfounding_sets = deconfound(interventions, head, model.graph())
    logger.info(f"resulting deconfounding sets: {deconfounding_sets}")

    # Filter out the deconfounding sets overlapping with our query body
    vertex_dcf = list(filter(lambda s: len(set(s) & {x.name for x in body}) == 0, deconfounding_sets))
    if len(vertex_dcf) == 0:
        raise NoDeconfoundingSet

    # Compute with every possible deconfounding set as a safety measure; ensuring they all match
    probability = None  # Sentinel value
    for z_set in vertex_dcf:

        result = _marginalize_query(expression, interventions, z_set, model)
        if probability is None:  # Storing first result
            probability = result

        # If results do NOT match; error
        assert abs(result - probability) < 0.00000001, f"Error: Distinct results: {probability} vs {result}"

    logger.info("{0} = {1:.5f}".format(Expression(head, set(body) | set(interventions)), probability))
    return probability
66 |
67 |
def _marginalize_query(expression: Expression, interventions: Collection[Intervention], deconfound: Collection[str], model: Model) -> float:
    """
    Handle the modified query where we require a deconfounding set due to Interventions / treatments.
    @param expression: The query to evaluate, P(head | body)
    @param interventions: A collection of Intervention objects applied to the query
    @param deconfound: A collection of (string) names of variables to serve as a deconfounding set,
        blocking all backdoor paths between the head and the interventions
    @param model: The Model supplying the causal graph and distributions
    @return: The probability of the query, marginalized over every combination of outcomes of the
        deconfounding set: sum over z of P(head | do(X), Z=z) * P(Z=z)
    """

    head = set(expression.head())
    body = set(expression.body())

    # Augment graph (isolating interventions as roots) and create engine
    model.graph().disable_incoming(*interventions)
    as_outcomes = {Outcome(x.name, x.outcome) for x in interventions}

    probability = 0.0

    # We take every possible combination of outcomes of Z and compute each probability separately
    for cross in product(*[model.variable(var).outcomes for var in deconfound]):

        # Construct the respective Outcome list of each Z outcome cross product
        z_outcomes = {Outcome(x, cross[i]) for i, x in enumerate(deconfound)}

        # First, we do P(Y | do(X), Z)
        ex1 = Expression(head, body | as_outcomes | z_outcomes)
        logger.info(f"computing sub-query: {ex1}")
        p_y_x_z = inference(ex1, model)

        # Second, P(Z)
        ex2 = Expression(z_outcomes, body | as_outcomes)
        logger.info(f"computing sub-query: {ex2}")
        p_z = inference(ex2, model)

        probability += p_y_x_z * p_z

    # Restore the graph before returning
    model.graph().reset_disabled()
    return probability
107 |
--------------------------------------------------------------------------------
/do/core/Variables.py:
--------------------------------------------------------------------------------
1 | from re import findall, sub
2 |
3 |
class Outcome:
    """
    A basic "Outcome" of a variable, representing a specific outcome such as "X = x".
    This does essentially act as a Pair-like.
    """

    def __init__(self, name: str, outcome: str):
        """
        Constructor for an Outcome
        @param name: The name of the variable. Ex: "X"
        @param outcome: The specific outcome of the variable. Ex: "x" or "~x"
        """
        self.name = name.strip()
        self.outcome = outcome.strip()

    def __str__(self) -> str:
        return f"{self.name} = {self.outcome}"

    def __hash__(self) -> int:
        # Hash on the concatenated pair, so equal Outcomes hash identically
        return hash(self.name + self.outcome)

    def __copy__(self):
        return Outcome(self.name, self.outcome)

    def copy(self):
        return self.__copy__()

    def __eq__(self, other) -> bool:
        # An Outcome may be compared against a bare string, matching on the variable name alone
        if isinstance(other, str):
            return self.name == other

        # Otherwise both halves of the pair must match, and the classes must match exactly
        names_match = self.name == other.name
        return names_match and self.outcome == other.outcome and type(self) == type(other)
35 |
36 |
class Variable:
    """
    Represents a basic "Variable", as part of a Conditional Probability Table or the like.
    Has a name, list of potential outcomes, and some list of parent variables.
    """

    def __init__(self, name: str, outcomes: list, parents: list, descendants=None, topological_order=0):
        """
        A basic Variable for use in a CPT or Causal Graph
        @param name: The name of the Variable, "X"
        @param outcomes: A list of all potential outcomes of the variable: ["x", "~x"]
        @param parents: A list of strings representing the names of all the parents of this Variable
        @param descendants: An optional set of Variables which are reachable from this Variable
        @param topological_order: Used in the ordering of Variables as defined by a topological sort
        """
        self.name = name.strip()
        self.outcomes = [each.strip() for each in outcomes]
        self.parents = [each.strip() for each in parents]
        self.topological_order = topological_order
        self.descendants = set() if descendants is None else descendants

    def __str__(self) -> str:
        outcomes = ",".join(self.outcomes)
        parents = ",".join(self.parents)
        return f"{self.name}: <{outcomes}>, <-- {parents}"

    def __hash__(self) -> int:
        return hash(self.name + str(self.outcomes) + str(self.parents))

    def __eq__(self, other) -> bool:
        # A Variable may be compared against a bare string, matching on the name alone
        if isinstance(other, str):
            return self.name == other

        # Outcome/parent ORDER is irrelevant to equality; only membership matters
        names_match = self.name == other.name
        return names_match and \
            set(self.outcomes) == set(other.outcomes) and \
            set(self.parents) == set(other.parents)

    def __copy__(self):
        return Variable(self.name, self.outcomes.copy(), self.parents.copy(), descendants=self.descendants.copy())

    def copy(self):
        return self.__copy__()
80 |
81 |
class Intervention(Outcome):
    """
    Represents an intervention; do(X).
    """

    def __init__(self, name: str, fixed_outcome: str):
        """
        Constructor for an Intervention
        @param name: The name of the variable being intervened upon. Ex: "X"
        @param fixed_outcome: The outcome the variable is fixed to. Ex: "x"
        """
        super().__init__(name, fixed_outcome)

    def __str__(self) -> str:
        return f"do({self.name}={self.outcome})"

    def __hash__(self):
        # Same hash scheme as Outcome: equal name/outcome pairs hash identically
        return hash(self.name + self.outcome)

    def __copy__(self):
        return Intervention(self.name, self.outcome)

    def copy(self):
        return self.__copy__()
101 |
102 |
def parse_outcomes_and_interventions(line: str) -> set:
    """
    Take one string line and parse it into a set of Outcomes and Interventions
    @param line: A string representing the query, such as "Y = y, do(X=x)"
    @return: A set of Outcomes and/or Interventions
    """
    # "do(X=x)", "do(X=x, Y=y)", "do(X=x), do(Y=y)" are all valid ways to write interventions.
    # Capture everything up to the closing parenthesis: the previous pattern, do\([^do]*\),
    # excluded the characters 'd' and 'o' from the interior and so silently truncated any
    # variable/outcome containing them (e.g. "do(X=dog)").
    interventions_preprocessed = findall(r'do\(([^)]*)\)', line)
    interventions = []
    for string in interventions_preprocessed:
        interventions.extend([item.strip(", ") for item in string.split(", ")])

    # Remove all the interventions, leaving only specific Outcomes
    outcomes_preprocessed = sub(r'do\([^)]*\)', '', line).strip(", ").split(",")
    outcomes_preprocessed = [item.strip(", ") for item in outcomes_preprocessed]
    outcomes = [string for string in outcomes_preprocessed if string]

    # Convert the outcome and intervention strings into the specific Outcome and Intervention classes
    outcomes = [Outcome(item.split("=")[0].strip(), item.split("=")[1].strip()) for item in outcomes]
    interventions = [Intervention(item.split("=")[0].strip(), item.split("=")[1].strip()) for item in interventions]

    together = []
    together.extend(outcomes)
    together.extend(interventions)

    return set(together)
130 |
--------------------------------------------------------------------------------
/tests/core/test_Graph.py:
--------------------------------------------------------------------------------
1 | from do.core.Variables import Outcome, Intervention, Variable
2 | from do.core.Graph import to_label
3 |
from ..source import models

# Module-level fixture: the internal graph of the "pearl-3.4" example model, shared by all tests below
graph = models["pearl-3.4.yml"]._g
6 |
7 |
def test_roots():
    # A root has, by definition, no parents
    for root in graph.roots():
        assert len(graph.parents(root)) == 0
10 |
11 |
def test_descendants():
    # A sink has, by definition, no children
    for sink in graph.sinks():
        assert len(graph.children(sink)) == 0
14 |
15 |
def test_parents():
    graph.reset_disabled()
    roots = graph.roots()

    for vertex in graph.v:
        parents = graph.parents(vertex)

        # Every claimed parent must correspond to a real edge into this vertex
        assert all((parent, vertex) in graph.e for parent in parents)

        # The roots are exactly the vertices without parents
        assert (len(parents) == 0) == (vertex in roots)
28 |
29 |
def test_children():
    graph.reset_disabled()

    for vertex in graph.v:
        for child in graph.children(vertex):
            # Each child corresponds to a real edge, and sees this vertex among its parents
            assert (vertex, child) in graph.e
            assert vertex in graph.parents(child)
39 |
40 |
def test_ancestors():
    graph.reset_disabled()

    # Every ancestor of a vertex must, in turn, count that vertex among its descendants
    for vertex in graph.v:
        for ancestor in graph.ancestors(vertex):
            assert vertex in graph.descendants(ancestor)
47 |
48 |
def test_reach():
    graph.reset_disabled()

    # Every descendant of a vertex must, in turn, count that vertex among its ancestors
    for vertex in graph.v:
        for descendant in graph.descendants(vertex):
            assert vertex in graph.ancestors(descendant)
55 |
56 |
def test_disable_outgoing():

    graph.reset_disabled()

    for vertex in graph.v:
        # Snapshot reachability before severing this vertex's outgoing edges
        pre_children = graph.children(vertex)
        pre_descendants = graph.descendants(vertex)

        graph.disable_outgoing(vertex)

        # The vertex can no longer reach anything...
        assert len(graph.children(vertex)) == 0
        assert len(graph.descendants(vertex)) == 0

        # ...and nothing previously reachable sees it as a parent or ancestor
        assert all(vertex not in graph.parents(child) for child in pre_children)
        assert all(vertex not in graph.ancestors(descendant) for descendant in pre_descendants)

    graph.reset_disabled()
73 |
74 |
def test_disable_incoming():

    graph.reset_disabled()

    for vertex in graph.v:
        # Snapshot reachability before severing this vertex's incoming edges
        pre_parents = graph.parents(vertex)
        pre_ancestors = graph.ancestors(vertex)

        graph.disable_incoming(vertex)

        # Nothing can reach the vertex any longer...
        assert len(graph.parents(vertex)) == 0
        assert len(graph.ancestors(vertex)) == 0

        # ...and nothing that previously reached it sees it as a child or descendant
        assert all(vertex not in graph.children(parent) for parent in pre_parents)
        assert all(vertex not in graph.descendants(ancestor) for ancestor in pre_ancestors)

    graph.reset_disabled()
91 |
92 |
def test_topology_sort():
    """Every vertex in a topological ordering comes after its ancestors and before its descendants."""

    topology = graph.topology_sort()

    for i, v in enumerate(topology):

        # Nothing earlier in the ordering may be reachable FROM v
        for before in topology[:i]:
            assert before not in graph.descendants(v)

        # Nothing later in the ordering may reach v
        for after in topology[i:]:
            assert after not in graph.ancestors(v)
105 |
106 |
def test_graph_copy():

    duplicate = graph.copy()

    # Same sizes, but distinct underlying containers
    assert len(graph.v) == len(duplicate.v)
    assert len(graph.e) == len(duplicate.e)
    assert graph.v is not duplicate.v
    assert graph.e is not duplicate.e

    # Identical membership, checked in both directions
    assert all(v in duplicate.v for v in graph.v)
    assert all(v in graph.v for v in duplicate.v)
    assert all(e in duplicate.e for e in graph.e)
    assert all(e in graph.e for e in duplicate.e)
128 |
129 |
def test_without_incoming_edges():

    g = graph.copy()

    roots = g.roots()
    root_children = set().union(*[g.children(x) for x in roots])

    # Severing incoming edges of roots is a no-op; severing those of their children is not
    unchanged = g.without_incoming_edges(roots)
    severed = g.without_incoming_edges(root_children)

    assert g.v == unchanged.v and g.e == unchanged.e  # ensure no change

    assert g.v == severed.v
    assert g.e != severed.e
    assert len(g.e) > len(severed.e)
    assert len(severed.roots()) > len(g.roots())

    # The severed children join the original roots as the new root set
    assert severed.roots() == set(g.roots()) | root_children
147 |
148 |
def test_without_outgoing_edges():

    g = graph.copy()

    sinks = g.sinks()
    sink_parents = set().union(*[g.parents(x) for x in sinks])

    # Severing outgoing edges of sinks is a no-op; severing those of their parents is not
    unchanged = g.without_outgoing_edges(sinks)
    severed = g.without_outgoing_edges(sink_parents)

    assert g.v == unchanged.v and g.e == unchanged.e  # ensure no change

    assert g.v == severed.v
    assert g.e != severed.e
    assert len(g.e) > len(severed.e)
    assert len(severed.sinks()) > len(g.sinks())

    # The severed parents join the original sinks as the new sink set
    assert severed.sinks() == set(g.sinks()) | sink_parents
166 |
167 |
def test_to_label():
    # to_label should reduce an Outcome, Intervention, or Variable to its bare name
    items = (Outcome("Xj", "xj"), Intervention("Xj", "xj"), Variable("Xj", [], []))
    for item in items:
        assert to_label(item) == item.name
176 |
--------------------------------------------------------------------------------
/.github/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | braden.dubois@usask.ca.
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120 |
121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
122 | enforcement ladder](https://github.com/mozilla/diversity).
123 |
124 | [homepage]: https://www.contributor-covenant.org
125 |
126 | For answers to common questions about this code of conduct, see the FAQ at
127 | https://www.contributor-covenant.org/faq. Translations are available at
128 | https://www.contributor-covenant.org/translations.
129 |
--------------------------------------------------------------------------------
/do/core/Graph.py:
--------------------------------------------------------------------------------
1 | from typing import Collection, Optional, Sequence, Set, Tuple, Union
2 |
3 | from .Types import VClass, Vertex
4 |
5 |
class Graph:

    """A basic directed graph, with support for temporarily disabling edges."""

    def __init__(self, v: Set[str], e: Set[Tuple[str, str]], topology: Optional[Sequence[Union[str, VClass]]] = None):
        """
        Initializer for a basic Graph.
        @param v: A set of vertices
        @param e: A set of edges, each edge being (source, target)
        @param topology: An optional sequence of vertices defining the topological ordering of the graph
        """

        self.v = v
        self.e = {(s.strip(), t.strip()) for s, t in e}

        # Map each vertex to its direct sources (incoming) and targets (outgoing)
        self.incoming = {vertex.strip(): set() for vertex in v}
        self.outgoing = {vertex.strip(): set() for vertex in v}

        for s, t in e:
            self.outgoing[s].add(t)
            self.incoming[t].add(s)

        # Vertices whose outgoing / incoming edges are currently disabled
        self.outgoing_disabled = set()
        self.incoming_disabled = set()

        if not topology:
            topology = self.topology_sort()
        else:
            # Drop any vertices in the provided topology that are not in this graph (supports subgraphs)
            topology = list(filter(lambda x: x in v, topology))

        self.topology_map = {vertex: index for index, vertex in enumerate(topology, start=1)}

    def __str__(self) -> str:
        """
        String builtin for the Graph class
        @return: A string representation of the given Graph instance
        """
        msg = "Vertices: " + ", ".join(sorted(i for i in self.v)) + "\n"
        msg += "Edges:\n" + "\n".join(" -> ".join(i for i in edge) for edge in self.e)
        return msg

    def roots(self) -> Collection[str]:
        """
        Get the roots of the graph G.
        @return: A set of vertices (strings) in G that have no parents.
        """
        return {x for x in self.v if len(self.parents(x)) == 0}

    def sinks(self) -> Collection[str]:
        """
        Get the sinks of the graph G.
        @return: A collection of string vertices in G that have no children.
        """
        return {x for x in self.v if len(self.children(x)) == 0}

    def parents(self, v: Vertex) -> Collection[Vertex]:
        """
        Get the parents of v, which may actually be currently controlled
        @param v: A variable in our graph
        @return: All parents reachable (which would be none if being controlled)
        """
        label = to_label(v)
        if label in self.incoming_disabled:
            return set()

        # NOTE(review): this also excludes a parent p when a reverse edge v -> p exists
        # (a two-cycle); confirm that exclusion is intended
        return {p for p in self.incoming[label] if p not in self.outgoing_disabled and p not in self.outgoing[label]}

    def children(self, v: Vertex) -> Collection[Vertex]:
        """
        Get the children of v, which may actually be currently controlled
        @param v: A variable in our graph
        @return: All children reachable (which would be none if being controlled)
        """
        label = to_label(v)
        if label in self.outgoing_disabled:
            return set()

        # NOTE(review): this also excludes a child c when a reverse edge c -> v exists
        # (a two-cycle); confirm that exclusion is intended
        return {c for c in self.outgoing[label] if c not in self.incoming_disabled and c not in self.incoming[label]}

    def ancestors(self, v: Vertex) -> Collection[Vertex]:
        """
        Get the ancestors of v, accounting for disabled vertices
        @param v: The vertex to find all ancestors of
        @return: A set of reachable ancestors of v
        """

        ancestors = set()
        queue = list(self.parents(v))

        # Breadth-first walk up the graph; skip already-expanded vertices so shared
        # ancestry is not re-walked (the previous version could re-enqueue them repeatedly)
        while queue:
            current = queue.pop(0)
            if current in ancestors:
                continue
            ancestors.add(current)
            queue.extend(self.parents(current))

        return ancestors

    def descendants(self, v: Vertex) -> Collection[Vertex]:
        """
        Get the reach of v, accounting for disabled vertices
        @param v: The vertex to find all descendants of
        @return: A set of reachable descendants of v
        """

        descendants = set()
        queue = list(self.children(v))

        # Breadth-first walk down the graph; skip already-expanded vertices so shared
        # descendants are not re-walked (the previous version could re-enqueue them repeatedly)
        while queue:
            current = queue.pop(0)
            if current in descendants:
                continue
            descendants.add(current)
            queue.extend(self.children(current))

        return descendants

    def disable_outgoing(self, *disable: Vertex):
        """
        Disable the given vertices' outgoing edges
        @param disable: Any number of vertices to disable
        """
        for v in disable:
            self.outgoing_disabled.add(to_label(v))

    def disable_incoming(self, *disable: Vertex):
        """
        Disable the given vertices' incoming edges
        @param disable: Any number of vertices to disable
        """
        for v in disable:
            self.incoming_disabled.add(to_label(v))

    def reset_disabled(self):
        """
        Clear and reset all the disabled edges, restoring the graph
        """
        self.outgoing_disabled.clear()
        self.incoming_disabled.clear()

    def get_topology(self, v: Vertex) -> int:
        """
        Determine the "depth" a given Variable is at in a topological sort of the graph
        @param v: The variable to determine the depth of
        @return: Some non-negative integer representing the depth of this variable
        """
        return self.topology_map[to_label(v)]

    def copy(self):
        """
        Public copy method; copies v, e, and the disabled sets
        @return: A copied Graph
        """
        return self.__copy__()

    def __copy__(self):
        """
        Copy builtin allowing the Graph to be copied
        @return: A copied Graph
        """
        copied = Graph(self.v.copy(), set(self.e.copy()))
        copied.incoming_disabled = self.incoming_disabled.copy()
        copied.outgoing_disabled = self.outgoing_disabled.copy()
        return copied

    def __getitem__(self, v: set):
        """
        Compute a subset V of some Graph G.
        :param v: A set of variables in G.
        :return: A Graph representing the subgraph G[V].
        """
        return Graph({s for s in self.v if s in v}, {s for s in self.e if s[0] in v and s[1] in v})

    def descendant_first_sort(self, variables: Collection[Vertex]) -> Sequence[Vertex]:
        """
        A helper function to "sort" a list of Variables/Outcomes/Interventions such that no element has a
        "parent"/"ancestor" to its left
        @param variables: A list of any number of Variable/Outcome/Intervention instances
        @return: A sorted list, such that any instance has no ancestor earlier in the list
        """
        return sorted(variables, key=lambda v: self.get_topology(v))

    def topology_sort(self) -> Sequence[str]:
        """
        Compute a topological ordering of the graph's vertices.
        @return: A list of vertices such that every vertex appears after all of the vertices with
            an edge into it; ties are broken alphabetically for determinism
        """

        topology = []
        v = self.v.copy()
        e = self.e.copy()

        # Repeatedly peel off the current roots; fails (assert) if the graph contains a cycle
        while len(v) > 0:

            roots = set(filter(lambda t: not any((s, t) in e for s in v), v))
            assert len(roots) > 0

            topology.extend(sorted(list(roots)))
            v -= roots
            e -= set(filter(lambda edge: edge[0] in roots, e))

        return topology

    def without_incoming_edges(self, x: Collection[Vertex]):
        """
        Construct a new Graph with all edges INTO the vertices of x removed.
        @param x: A collection of vertices whose incoming edges should be severed
        @return: A new Graph instance; this graph is unmodified
        """

        v = self.v.copy()
        e = {(s, t) for (s, t) in self.e if t not in x}

        return Graph(v, e)

    def without_outgoing_edges(self, x: Collection[Vertex]):
        """
        Construct a new Graph with all edges OUT OF the vertices of x removed.
        @param x: A collection of vertices whose outgoing edges should be severed
        @return: A new Graph instance; this graph is unmodified
        """

        v = self.v.copy()
        e = {(s, t) for (s, t) in self.e if s not in x}

        return Graph(v, e)
217 |
218 |
def to_label(item: VClass) -> str:
    """
    Convert a variable to its string name, if not already provided as such
    @param item: The item to convert, either a string (done) or some Variable
    @return: A string name of the given item, if not already provided as a string
    """
    if isinstance(item, str):
        return item.strip("'")
    return item.name.strip("'")
226 |
--------------------------------------------------------------------------------
/do/identification/LatentGraph.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 | from typing import List, Iterable, Set, Tuple
3 |
4 | from ..core.Graph import Graph
5 |
6 |
class LatentGraph(Graph):
    """
    A Graph extended with bidirected edges, used by the identification algorithm to
    represent latent (unobservable) confounding: a bidirected edge between X and Y
    stands for an unobserved common cause of both.
    """

    def __init__(self, vertices: Set[str], edges: Set[Tuple[str, str]], e_bidirected: Set[Tuple[str, str]], fixed_topology: List[str] = None):
        """
        @param vertices: A set of (string) vertex names.
        @param edges: A set of directed edges, each a (source, target) tuple.
        @param e_bidirected: A set of bidirected edges; tuple orientation is not meaningful (see __eq__).
        @param fixed_topology: An optional topological ordering to reuse rather than recompute,
            e.g. when constructing a subgraph of an existing LatentGraph.
        """
        super().__init__(vertices, edges, fixed_topology)
        self.e_bidirected = e_bidirected.copy()
        # NOTE(review): stores the caller's set without copying, unlike e_bidirected - confirm intended
        self.V = vertices
        # Confounded components: vertex sets connected through bidirected edges
        self.C = self.make_components()

        # Allows passing a topology down to a subgraph
        if fixed_topology:

            # filter any vertices from the given topology that don't exist as vertices in the graph
            filtered_topology = [x for x in fixed_topology if x in vertices]

            # ensure topology fully represents the graph
            # NOTE(review): the first assert is vacuous - filtered_topology was built only from vertices
            assert all(x in vertices for x in filtered_topology), "vertex in the given topology is not in the graph!"
            assert all(x in filtered_topology for x in vertices), "vertex in the graph is not in the given topology!"

            self.v_Pi = filtered_topology

        # Otherwise, generate it
        else:
            self.v_Pi = self.__kahns()

    def __str__(self):
        """String representation listing V, E, and the bidirected edge set."""
        return f"Graph: V = {', '.join(self.v)}, E = {', '.join(list(map(str, self.e)))}, E (Bidirected) = {', '.join(list(map(str, self.e_bidirected)))}"

    def __getitem__(self, v: Set[str]):
        """Vertex-induced subgraph: keep only edges with both endpoints in v, passing the topology down."""
        e = {(s, t) for (s, t) in self.e if s in v and t in v}
        e_bidirected = {(s, t) for (s, t) in self.e_bidirected if s in v and t in v}
        return LatentGraph(self.v & v, e, e_bidirected, self.v_Pi)

    def __eq__(self, other):
        """Structural equality; bidirected edges are compared orientation-insensitively."""
        # NOTE(review): defining __eq__ without __hash__ makes instances unhashable - confirm acceptable
        if not isinstance(other, LatentGraph):
            return False

        return self.v == other.v and self.e == other.e and \
            all([(e[0], e[1]) in other.e_bidirected or (e[1], e[0]) in other.e_bidirected for e in self.e_bidirected]) and \
            all([(e[0], e[1]) in self.e_bidirected or (e[1], e[0]) in self.e_bidirected for e in other.e_bidirected])

    def biadjacent(self, v: str):
        """Return the set of vertices joined to v by a bidirected edge (in either position)."""
        return {e[0] if e[0] != v else e[1] for e in self.e_bidirected if v in e}

    def ancestors(self, y: Set[str]):
        """Return y together with every ancestor (transitive parent) of any vertex in y."""
        ans = y.copy()
        for v in y:
            for p in self.parents(v):
                # Recursive ascent; revisits shared ancestors (fine for small graphs)
                ans |= self.ancestors({p})
        return ans

    # puts nodes in topological ordering
    def __kahns(self):
        """Kahn's algorithm over copies of V and E; returns a topological order w.r.t. directed edges."""

        edges = self.e.copy()
        vertices = self.v.copy()
        v_Pi = []

        # Seed the ready-list with isolated vertices (touching no edge) ...
        s = vertices - ({e[0] for e in edges} | {e[1] for e in edges})
        # ... plus every source (vertex with no incoming edge)
        s |= set([e[0] for e in edges if e[0] not in {g[1] for g in edges}])
        s = list(s)

        while s:
            n = s.pop()
            v_Pi.append(n)

            # Remove n's outgoing edges; a child left with no incoming edges becomes ready
            ms = {e[1] for e in edges if e[0] == n}
            for m in ms:
                edges.remove((n, m))
                if {e for e in edges if e[1] == m} == set():
                    s.append(m)

        return v_Pi

    def make_components(self):
        """Group vertices into confounded components via BFS over bidirected edges only."""

        ans = []
        all_v = self.v.copy()
        visited = set()

        while all_v:
            start = all_v.pop()
            component = []
            q = [start]

            while q:
                v = q.pop(0)
                if v not in visited:
                    visited.add(v)
                    component.append(v)
                    q.extend([vs for vs in self.biadjacent(v) if vs not in visited])

            # A start vertex already claimed by an earlier component yields an empty list; skip it
            if component:
                ans.append(set(component))

        return ans

    def without_incoming(self, x: Iterable[str]):
        """Return a copy of this graph with every directed edge INTO any vertex of x removed."""
        return LatentGraph(self.v, self.e - {e for e in self.e if e[1] in x}, self.e_bidirected, self.v_Pi)

    def collider(self, v1, v2, v3):
        """
        Test the triple (v1, v2, v3) at middle vertex v2.
        NOTE(review): this checks v1 -> v2 -> v3 (v3 among v2's CHILDREN), i.e. a chain,
        not the conventional collider v1 -> v2 <- v3 - confirm the intended semantics.
        """
        return v1 in self.V and v2 in self.V and v3 in self.V and v1 in self.parents(v2) and v3 in self.children(v2)

    def all_paths(self, x: Iterable[str], y: Iterable[str]):
        """
        For each (s, t) in x X y, enumerate every path from s to t, ignoring edge direction.
        Bidirected hops are recorded by inserting a synthetic placeholder vertex mid-path.
        @return: A list (one entry per (s, t) pair) of lists of paths.
        """

        def path_list(s, t):  # returns all paths from X to Y regardless of direction of link (no bd links)

            # generate a fake variable to represent unobservable variables
            # (grow "U" until it collides with no real vertex name)
            UNOBSERVABLE = "U"
            while UNOBSERVABLE in self.V:
                UNOBSERVABLE += "U"

            from_s_s = [[s]]  # BFS frontier of partial paths anchored at s
            ans = []
            while from_s_s:
                from_s = from_s_s.pop(0)

                # Directed links (either orientation)
                for each in set(self.parents(from_s[-1])) | set(self.children(from_s[-1])):
                    if each == t:
                        path = from_s.copy()
                        path.append(t)
                        ans.append(path)
                    elif each not in from_s:
                        r = from_s.copy()
                        r.append(each)
                        from_s_s.append(r)

                # Bidirected links: the placeholder marks the hidden intermediate
                for each in self.biadjacent(from_s[-1]):
                    if each == t:
                        path = from_s.copy()
                        path.append(UNOBSERVABLE)
                        path.append(t)
                        ans.append(path)
                    elif each not in from_s:
                        r = from_s.copy()
                        r.append(UNOBSERVABLE)
                        r.append(each)
                        from_s_s.append(r)
            return ans

        return [path_list(q, w) for q, w in product(x, y)]

    def ci(self, x: Set[str], y: Set[str], z: Set[str]):
        """
        Test conditional independence of x and y given z over the paths from all_paths.
        Returns True only when every path is blocked by some member of z; any path whose
        interior contains a triple satisfying self.collider immediately yields False.
        NOTE(review): this differs from textbook d-separation (where conditioning on a
        collider OPENS a path) - confirm against the intended definition.
        """
        paths = self.all_paths(x, y)
        for path_pair in paths:
            for path in path_pair:
                broke = False
                for idx, element in enumerate(path):
                    if 0 < idx < len(path) - 1:
                        if self.collider(path[idx - 1], element, path[idx + 1]):
                            return False
                        if element in z:
                            broke = True
                            break
                if not broke:
                    return False
        return True
166 |
167 |
def latent_transform(g: Graph, u: Set[str]):
    """
    Convert a Graph containing unobservable variables into a LatentGraph over the
    observable vertices only: chains of unobservables are collapsed, and each remaining
    unobservable is replaced by bidirected arcs linking its (observable) children.
    @param g: The source Graph.
    @param u: The set of unobservable variable names within g.
    @return: A LatentGraph over g's observable vertices, with a topology inherited
        from g's topological sort restricted to the observable vertices.
    """

    V = g.v.copy()
    E = set(g.e.copy())
    E_Bidirected = set()

    Un = u.copy()

    # Collapse unobservable variables, such as U1 -> U2 -> V ==> U1 -> V
    reduction = True
    while reduction:
        reduction = False

        remove = set()
        for un in Un:

            parents = [edge[0] for edge in E if edge[1] == un]   # Edges : parent -> u
            children = [edge[1] for edge in E if edge[0] == un]  # Edges : u -> child

            # All parents are unobservable, all children are observable, at least one parent
            if all(x in u for x in parents) and len(parents) > 0 and all(x not in u for x in children):
                reduction = True

                # Remove edges from parents to u
                for parent in parents:
                    E.remove((parent, un))

                # Remove edges from u to children
                for child in children:
                    E.remove((un, child))

                # Replace with an edge from each parent to each child
                for cr in product(parents, children):
                    E.add((cr[0], cr[1]))

                # U can be removed entirely from graph
                remove.add(un)

        V -= remove
        Un -= remove

    # Convert all remaining unobservable to a list to iterate through
    Un = list(Un)

    # Replace each remaining unobservable with bi-directed arcs between its children
    while len(Un) > 0:

        # Take one "current" unobservable to remove, and remove it from the graph entirely
        cur = Un.pop()
        V.remove(cur)

        assert len([edge for edge in E if edge[1] == cur]) == 0, \
            "Unobservable still had parent left."

        # All outgoing edges of this unobservable
        child_edges = {edge for edge in E if edge[0] == cur}
        E -= child_edges

        # Link consecutive children in a ring of bidirected arcs
        # NOTE(review): a single child yields a (c, c) self-loop bidirected arc - confirm intended
        child_edges = list(child_edges)
        for i in range(len(child_edges)):
            a, b = child_edges[i], child_edges[(i + 1) % len(child_edges)]
            E_Bidirected.add((a[1], b[1]))

    # Fixed: removed leftover debug print(V, u)
    return LatentGraph(V, E, E_Bidirected, [x for x in g.topology_sort() if x in V - u])
234 |
--------------------------------------------------------------------------------
/do/deconfounding/Backdoor.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 | from typing import Collection, List, Optional
3 |
4 | from ..core.Graph import Graph
5 | from ..core.Types import Path, Vertex
6 | from ..core.Exceptions import IntersectingSets
7 |
8 | from ..core.helpers import disjoint, minimal_sets, power_set
9 |
10 |
def backdoors(src: Collection[Vertex], dst: Collection[Vertex], graph: Graph, dcf: Optional[Collection[Vertex]] = None) -> Collection[Path]:
    """
    Find every backdoor path from a set of source vertices to a set of destination
    vertices in the given graph. A (possibly empty) deconfounding set may block -
    or open - individual backdoor paths.
    @param src: The source set of vertices to search for paths from.
    @param dst: The destination set of vertices to search toward.
    @param graph: The Graph to search within.
    @param dcf: An optional set of vertices that may serve as a sufficient
        deconfounding set to block or open backdoor paths.
    @return: A list of paths; each path is a list of string vertices whose first and
        last elements belong to src and dst respectively.
    @raise IntersectingSets: if src, dst, and dcf are not pairwise disjoint.
    """

    source_names = str_map(src)
    destination_names = str_map(dst)
    blocking_names = str_map(dcf) if dcf else set()

    if not disjoint(source_names, destination_names, blocking_names):
        raise IntersectingSets

    # Try every (source, destination) pairing and pool the discovered paths
    return [
        path
        for s, t in product(source_names, destination_names)
        for path in _backdoor_paths_pair(s, t, graph, blocking_names)
    ]
39 |
40 |
def deconfound(src: Collection[Vertex], dst: Collection[Vertex], graph: Graph) -> Collection[Collection[Vertex]]:
    """
    Compute all minimal sufficient deconfounding sets between src and dst in graph.
    @param src: The source (treatment) vertices.
    @param dst: The destination (outcome) vertices.
    @param graph: The Graph to search within.
    @return: A list of the minimal vertex sets that leave no open backdoor path
        between any pairing of src and dst.
    """

    source_names = str_map(src)
    destination_names = str_map(dst)

    # Endpoints - and any descendant of a source vertex - can never serve as blockers
    unusable = source_names | destination_names
    for s in source_names:
        unusable |= set(graph.descendants(s))

    def blocks_everything(candidate) -> bool:
        # True when no (s, t) pairing is left with an open backdoor path
        return all(
            len(_backdoor_paths_pair(s, t, graph, set(candidate))) == 0
            for s, t in product(source_names, destination_names)
        )

    # Candidates are the power set of the remaining vertices; keep the sufficient ones
    sufficient = [candidate for candidate in power_set(graph.v - unusable) if blocks_everything(candidate)]

    return list(minimal_sets(*sufficient))
71 |
72 |
def all_paths_cumulative(s: str, t: str, path: list, path_list: list, graph: Graph) -> Collection[Path]:
    """
    Collect every directed (parent-to-child) path from a source vertex to a target
    vertex, accumulating results across recursive calls.
    This is a modified version of the graph-traversal algorithm provided by Dr. Eric Neufeld.
    @param s: A source (string) vertex defined in the graph.
    @param t: A target (string) destination vertex defined in the graph.
    @param path: The path walked so far in the current traversal.
    @param path_list: The accumulator of complete s-to-t paths found so far.
    @param graph: The Graph being traversed.
    @return: path_list, extended with every path from s to t found below this call.
    """
    # Arrived: record the completed path
    if s == t:
        return path_list + [path + [t]]

    # Revisiting a vertex already on the path would loop forever; prune this branch
    if s in path:
        return path_list

    walked = path + [s]
    for successor in graph.children(s):
        path_list = all_paths_cumulative(successor, t, walked, path_list, graph)
    return path_list
90 |
91 |
def independent(src: Collection[Vertex], dst: Collection[Vertex], dcf: Optional[Collection[Vertex]], graph: Graph) -> bool:
    """
    Determine whether two sets of vertices are independent in the graph, given some
    (possibly absent) deconfounding set.
    @param src: A source set X, to be independent from Y.
    @param dst: A destination set Y, to be independent from X.
    @param dcf: An optional deconfounding set Z, blocking paths between X and Y.
    @param graph: The Graph to test within.
    @return: True when no open backdoor path and no directed path connects X and Y.
    """

    source_names = str_map(src)
    destination_names = str_map(dst)
    blocking_names = str_map(dcf) if dcf else set()

    # Any remaining open backdoor path defeats independence
    if backdoors(source_names, destination_names, graph, blocking_names):
        return False

    # So does any straight-line (directed) path, in either direction
    for a, b in product(source_names, destination_names):
        if all_paths_cumulative(a, b, [], [], graph):
            return False  # a -> b
        if all_paths_cumulative(b, a, [], [], graph):
            return False  # b -> a

    # Nothing connects the two sets; they are independent
    return True
119 |
120 |
121 | def _backdoor_paths_pair(s: Collection[str], t: Collection[str], graph: Graph, dcf: Collection[str]) -> List[Path]:
122 | """
123 | Find all backdoor paths between any particular pair of vertices in the loaded graph
124 | @param s: A source (string) vertex in the graph
125 | @param t: A destination (string) vertex in the graph
126 | @param dcf: A set of (string) variables, by which movement through any variable is controlled. This can serve
127 | as a sufficient "blocking" set, or may open additional backdoor paths
128 | @return Return a list of lists, where each sublist is a path of string vertices connecting s and t.
129 | Endpoints s and t are the first and last elements of any sublist.
130 | """
131 |
132 | def get_backdoor_paths(cur: str, path: list, path_list: list, previous="up") -> list:
133 | """
134 | Return a list of lists of all paths from a source to a target, with conditional movement of either
135 | child to parent or parent to child. This may include an edge case that is not a backdoor path, which
136 | is filtered in the parent function, otherwise all paths will be backdoor paths.
137 | This is a heavily modified version of the graph-traversal algorithm provided by Dr. Eric Neufeld.
138 | @param cur: The current (string) vertex we are at in a traversal.
139 | @param path: The current path from s, our source.
140 | @param path_list: A list of lists, each sublist being a path discovered so far.
141 | @param previous: Whether moving from the previous variable to current we moved "up" (child to parent) or
142 | "down" (from parent to child); this movement restriction is involved in backdoor path detection
143 | @return: A list of lists, where each sublist is a path from s to t.
144 | """
145 |
146 | # Reached target
147 | if cur == t:
148 | return path_list + [path + [t]]
149 |
150 | # No infinite loops
151 | if cur not in path:
152 |
153 | if previous == "down":
154 |
155 | # We can ascend on a controlled collider, OR an ancestor of a controlled collider
156 | if cur in dcf or any(map(lambda v: v in dcf, graph.descendants(cur))):
157 | for parent in graph.parents(cur):
158 | path_list = get_backdoor_paths(parent, path + [cur], path_list, "up")
159 |
160 | # We can *continue* to descend on a non-controlled variable
161 | if cur not in dcf:
162 | for child in graph.children(cur):
163 | path_list = get_backdoor_paths(child, path + [cur], path_list, "down")
164 |
165 | if previous == "up" and cur not in dcf:
166 |
167 | # We can ascend on a non-controlled variable
168 | for parent in graph.parents(cur):
169 | path_list = get_backdoor_paths(parent, path + [cur], path_list, "up")
170 |
171 | # We can descend on a non-controlled reverse-collider
172 | for child in graph.children(cur):
173 | path_list = get_backdoor_paths(child, path + [cur], path_list, "down")
174 |
175 | return path_list
176 |
177 | # Get all possible backdoor paths
178 | backdoor_paths = get_backdoor_paths(s, [], [])
179 |
180 | # Filter out the paths that don't "enter" x; see the definition of a backdoor path
181 | return list(filter(lambda l: len(l) > 2 and l[0] in graph.children(l[1]) and l[1] != t, backdoor_paths))
182 |
183 |
def str_map(to_filter: Collection[Vertex]):
    """Normalize a collection of vertices (strings or objects with a .name) into a set of string names."""
    return {vertex if isinstance(vertex, str) else vertex.name for vertex in to_filter}
186 |
--------------------------------------------------------------------------------
/tests/identification/test_LatentGraph.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 |
3 | from do.identification.Identification import Identification, simplify_expression
4 | from do.identification.LatentGraph import LatentGraph
5 | from do.identification.PExpression import PExpression, TemplateExpression
6 |
7 |
def parse_graph_string(graph_string: str) -> LatentGraph:
    """
    Parse a compact graph-description string into a LatentGraph.

    Grammar: statements are separated by "."; each statement is a chain of
    comma-separated vertex lists joined by "->", "<-", or "<->" arrows, e.g.
    "X,Z1->Z2->Z3." or "B<->C.". A statement with no arrow declares bare vertices.
    @param graph_string: The graph description, e.g. "M->S.S,M->C."
    @return: A LatentGraph containing the parsed vertices, directed edges, and
        bidirected edges.
    """

    from re import split

    arrows = ["<->", "<-", "->"]

    # The trailing "." produces one empty trailing piece; drop it
    statements = graph_string.strip().split(".")[:-1]
    # Fixed: removed leftover debug print of the split statements

    e = set()
    v = set()
    e_b = set()

    for item in statements:

        # An item like "X" or "X,Y" declares vertices that have no edges
        if not any(arrow in item for arrow in arrows):
            v.update(item.split(","))
            continue  # nothing further to parse for a vertex-only statement

        # Split while capturing the arrows, e.g. "S->T->C" -> ["S", "->", "T", "->", "C"]
        parse = split(f'({"|".join(arrows)})', item)

        for i in range(1, len(parse), 2):

            # Left and Right are comma-separated lists of values, arrow being the "->" -style arrow joining them.
            left, arrow, right = parse[i-1].split(","), parse[i], parse[i+1].split(",")

            # Add all vertices into V
            v.update(left)
            v.update(right)

            for s, t in product(left, right):

                if arrow == "<-":
                    e.add((t, s))

                elif arrow == "->":
                    e.add((s, t))

                elif arrow == "<->":
                    e_b.add((s, t))

                else:
                    print("Invalid Arrow Type:", arrow)

    return LatentGraph(v, e, e_b)
53 |
54 |
# Test fixtures: eight problem sets, each a LatentGraph, its string form, a list of
# (outcome-set, treatment-set) queries, and the expected rendering of each answer.

#########################################
# graph 1
#########################################

g_1 = LatentGraph({'C', 'S', 'M'}, {('M', 'S'), ('S', 'C'), ('M', 'C')}, set())
g_1_string = "M->S.S,M->C."

g1_q1 = ({'C'}, {'S'})
g1_q2 = ({'C'}, {'M'})
g1_q3 = ({'C'}, {'S', 'M'})

g1_a1 = "C | S = "
g1_a2 = "C | M = "
g1_a3 = "C | S, M = [C|M,S]"

g1_queries = [g1_q1, g1_q2, g1_q3]
g1_answers = [g1_a1, g1_a2, g1_a3]



#########################################
# queries - graph 2
#########################################

g_2 = LatentGraph({'A', 'B', 'C', 'D'}, {('A', 'B'), ('A', 'C'), ('B', 'D'), ('C', 'D')}, set())
g_2_string = "A->B,C.B,C->D."

g2_q1 = ({'D'}, {'A'})
g2_q2 = ({'D'}, {'B'})
g2_q3 = ({'D'}, {'C'})
g2_q4 = ({'D'}, {'B', 'C'})

g2_a1 = "D | A = "
g2_a2 = "D | B = "
g2_a3 = "D | C = "
g2_a4 = "D | C, B = [D|A,B,C]"

g2_queries = [g2_q1, g2_q2, g2_q3, g2_q4]
g2_answers = [g2_a1, g2_a2, g2_a3, g2_a4]

#########################################
# queries - graph 3
#########################################

g_3 = LatentGraph({'B', 'C', 'D'}, {('B', 'D'), ('C', 'D')}, {('B', 'C')})
g_3_string = "B<->C.B,C->D."

g3_q1 = ({'D'}, {'B'})
g3_q2 = ({'D'}, {'C'})
g3_q3 = ({'D'}, {'B', 'C'})

g3_a1 = "D | B = "
g3_a2 = "D | C = "
g3_a3 = "D | C, B = [D|B,C]"

g3_queries = [g3_q1, g3_q2, g3_q3]
g3_answers = [g3_a1, g3_a2, g3_a3]

#########################################
# queries - graph 4
#########################################

g_4 = LatentGraph({'S', 'T', 'C'}, {('S', 'T'), ('T', 'C')}, {('S', 'C')})
g_4_string = "S->T->C.S<->C."

g4_q1 = ({'C'}, {'S'})

g4_a1 = "C | S = "

g4_queries = [g4_q1]
g4_answers = [g4_a1]

#########################################
# queries - graph 5
#########################################

g_5 = LatentGraph({"X", "Y", "Z1", "Z2", "Z3"},
                  {("Z1", "Z2"), ("X", "Z2"), ("Z2", "Z3"), ("X", "Y"), ("Z2", "Y"), ("Z3", "Y")}, {("X", "Z1"), ("Z1", "Z3")}
                  )
g_5_string = "X<->Z1<->Z3.X,Z1->Z2->Z3.X,Z2,Z3->Y."

g5_q1 = ({"Y"}, {"X"})  # paper: Sum_{X} [P(Z3 | Z2, Z1, X), P(Z1 | X), P(X)]
g5_q2 = ({"Y"}, {"X", "Z1", "Z2", "Z3"})

g5_a1 = "Y | X = "
g5_a2 = "Y | Z2, Z1, X, Z3 = [Y|X,Z2,Z3]"

g5_queries = [g5_q1, g5_q2]
g5_answers = [g5_a1, g5_a2]

#########################################
# queries - graph 6
#########################################

g_6 = LatentGraph({"X", "Y1", "Y2", "W1", "W2"},
                  {("W1", "X"), ("X", "Y1"), ("W2", "Y2")}, {("W1", "W2"), ("W1", "Y1"), ("W2", "X")}
                  )
g_6_string = "Y1<->W1<->W2<->X.W1->X->Y1.W2->Y2."

g6_q1 = ({"Y1", "Y2"}, {"X"})
g6_a1 = "Y2, Y1 | X = [W2] "

g6_queries = [g6_q1]
g6_answers = [g6_a1]

#########################################
# queries - graph 7
#########################################

g_7 = LatentGraph({"Z1", "Z2", "W", "X", "Y"},
                  {("Z2", "X"), ("X", "W"), ("W", "Y"), ("Z1", "Y")}, {("Z1", "Z2"), ("Z2", "W"), ("Z1", "W")}
                  )
g_7_string = "Z1<->Z2<->W<->Z1.Z2->X->W->Y<-Z1."

g7_q1 = ({"Y"}, {"X"})
g7_a1 = "Y | X = "

g7_queries = [g7_q1]
g7_answers = [g7_a1]

#########################################
# queries - graph 8
#########################################

g_8 = LatentGraph({'S1', 'T1', 'C1', 'S2', 'T2', 'C2', 'C'}, {('S1', 'T1'), ('T1', 'C1'), ('S2', 'T2'), ('T2', 'C2'), ('C2', 'C'), ('C1', 'C')}, {('S1', 'C1'), ('S2', 'C2')})

g_8_string = "S1->T1->C1<->S1.S2->T2->C2<->S2.C2->C<-C1."

g8_q1 = ({'C1'}, {'S1'})
g8_q2 = ({'C2'}, {'S2'})
g8_q3 = ({'C1', 'C2'}, {'S1', 'S2'})
g8_q4 = ({'C'}, {'S1', 'S2'})

# NOTE(review): these ">" answers look like placeholders - confirm the expected values
g8_a1 = ">"
g8_a2 = ">"
g8_a3 = ">"
g8_a4 = ">"

g8_queries = [g8_q1, g8_q2, g8_q3, g8_q4]
g8_answers = [g8_a1, g8_a2, g8_a3, g8_a4]

#########################################

# One entry per problem set, consumed by tests() below
all_tests = [
    {
        "queries": g1_queries,
        "answers": g1_answers,
        "g": g_1,
        "as_string": g_1_string,
    }, {
        "queries": g2_queries,
        "answers": g2_answers,
        "g": g_2,
        "as_string": g_2_string,
    }, {
        "queries": g3_queries,
        "answers": g3_answers,
        "g": g_3,
        "as_string": g_3_string,
    }, {
        "queries": g4_queries,
        "answers": g4_answers,
        "g": g_4,
        "as_string": g_4_string,
    }, {
        "queries": g5_queries,
        "answers": g5_answers,
        "g": g_5,
        "as_string": g_5_string,
    }, {
        "queries": g6_queries,
        "answers": g6_answers,
        "g": g_6,
        "as_string": g_6_string,
    }, {
        "queries": g7_queries,
        "answers": g7_answers,
        "g": g_7,
        "as_string": g_7_string,
    }, {
        "queries": g8_queries,
        "answers": g8_answers,
        "g": g_8,
        "as_string": g_8_string,
    }
]
241 |
242 |
def test_GraphParse1():
    # Round-trip: parsing g_1's string form must equal g_1 (LatentGraph.__eq__)
    assert g_1 == parse_graph_string(g_1_string)
245 |
def test_GraphParse2():
    # Round-trip: parsing g_2's string form must equal g_2
    assert g_2 == parse_graph_string(g_2_string)
248 |
def test_GraphParse3():
    # Round-trip: parsing g_3's string form must equal g_3 (includes bidirected edges)
    assert g_3 == parse_graph_string(g_3_string)
251 |
def test_GraphParse4():
    # Round-trip: parsing g_4's string form must equal g_4
    assert g_4 == parse_graph_string(g_4_string)
254 |
def test_GraphParse5():
    # Round-trip: parsing g_5's string form must equal g_5 (chained arrow statements)
    assert g_5 == parse_graph_string(g_5_string)
257 |
def test_GraphParse6():
    # Round-trip: parsing g_6's string form must equal g_6
    assert g_6 == parse_graph_string(g_6_string)
260 |
def test_GraphParse7():
    # Round-trip: parsing g_7's string form must equal g_7
    assert g_7 == parse_graph_string(g_7_string)
263 |
def test_GraphParse8():
    # Round-trip: parsing g_8's string form must equal g_8 (mixed <- and <-> arrows)
    assert g_8 == parse_graph_string(g_8_string)
266 |
267 |
def tests():
    """
    Manual end-to-end driver: for each stored problem set, verify that the graph's
    string form parses back to an equal graph, then run Identification over every
    query and print the resulting (and simplified) proofs.
    """

    for index, problem_set in enumerate(all_tests, start=1):

        banner = "*" * 20
        print(banner, f"Beginning Graph {index}", banner)

        graph = problem_set["g"]
        distribution = PExpression([], [TemplateExpression(v, list(graph.parents(v))) for v in graph.V])

        # Round-trip the string form through the parser and compare
        graph_string = problem_set["as_string"]

        print(f"Graph String: {index}")
        print(graph_string)
        reparsed = parse_graph_string(graph_string)

        print("Original:", graph)
        print(" Parsed:", reparsed)
        assert graph == reparsed

        # Identify each (outcomes, treatments) query and print its proof
        for problem_number, (query, answer) in enumerate(zip(problem_set["queries"], problem_set["answers"]), start=1):

            outcomes, treatments = query

            print(f"Beginning problem ({problem_number}): {', '.join(outcomes)} | {', '.join(treatments)}")
            identified = Identification(outcomes, treatments, distribution, graph, True)
            simplified = simplify_expression(identified, graph)

            print("*********** Proof")
            print(identified.proof())

            print("*********** Proof (Simplified)")
            print(simplified.proof())
303 |
--------------------------------------------------------------------------------
/do/core/Inference.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 | from loguru import logger
3 | from typing import Collection
4 |
5 | from .Exceptions import ExogenousNonRoot, ProbabilityIndeterminableException
6 | from .Expression import Expression
7 | from .Model import Model
8 | from .Variables import Outcome, Intervention
9 |
10 | from .helpers import within_precision
11 |
12 |
def inference(expression: Expression, model: Model):
    """
    Evaluate the probability of `expression` (head outcomes given body outcomes)
    under `model` by recursively applying classical inference rules.
    @param expression: The Expression to evaluate; head and body are Outcome collections.
    @param model: The Model supplying the graph, variable definitions, and probability tables.
    @return: A probability in [0.0, 1.0].
    @raise ProbabilityIndeterminableException: if no rule can resolve the query.
    """

    def _compute(head: Collection[Outcome], body: Collection[Intervention], depth=0) -> float:
        """
        Compute the probability of some head given some body
        @param head: A list of some number of Outcome objects
        @param body: A list of some number of Outcome objects
        @param depth: Recursion depth; currently informational only
        @return: A probability between [0.0, 1.0]
        @raise ProbabilityIndeterminableException if the result cannot be computed for any reason
        """

        ###############################################
        #   Begin with bookkeeping / error-checking   #
        ###############################################

        current_expression = Expression(head, body)
        logger.info(f"query: {current_expression}")

        # If the calculation for this contains two separate outcomes for a variable (Y = y | Y = ~y), 0
        if contradictory_outcome_set(head + body):
            logger.error("two separate outcomes for one variable, P = 0.0")
            return 0.0

        ###############################################
        #            Reverse product rule             #
        #  P(y, x | ~z) = P(y | x, ~z) * P(x | ~z)    #
        ###############################################

        if len(head) > 1:
            logger.info(f"applying reverse product rule to {current_expression}")

            result_1 = _compute(head[:-1], [head[-1]] + body, depth+1)
            result_2 = _compute([head[-1]], body, depth+1)
            result = result_1 * result_2

            logger.success(f"{current_expression} = {result}")
            return result

        ###############################################
        #           Attempt direct lookup             #
        ###############################################

        # A table applies only when the body is exactly the head variable's parent set
        if set(model.variable(head[0].name).parents) == set(v.name for v in body):
            logger.info(f"querying table for: {current_expression}")
            table = model.table(head[0].name)                       # Get table
            probability = table.probability_lookup(head[0], body)   # Get specific row
            logger.success(f"{current_expression} = {probability}")

            return probability
        else:
            logger.info("no direct table found")

        ##################################################################
        #   Easy identity rule; P(X | X) = 1, so if LHS ⊆ RHS, P = 1.0   #
        ##################################################################

        if set(head).issubset(set(body)):
            logger.success(f"identity rule: X|X = 1.0, therefore {current_expression} = 1.0")
            return 1.0

        #################################################
        #                 Bayes' Rule                   #
        #     Detect children of the LHS in the RHS     #
        #     p(a|Cd) = p(d|aC) * p(a|C) / p(d|C)       #
        #################################################

        # NOTE(review): descendants() is passed an Outcome object here, while string names
        # are used elsewhere - confirm Graph.descendants accepts Outcomes
        reachable_from_head = set().union(*[model.graph().descendants(outcome) for outcome in head])
        descendants_in_rhs = set([var.name for var in body]) & reachable_from_head

        if descendants_in_rhs:
            logger.info(f"Children of the LHS in the RHS: {','.join(descendants_in_rhs)}")
            logger.info("Applying Bayes' rule.")

            # Not elegant, but simply take one of the children from the body out and recurse
            child_name = list(descendants_in_rhs)[0]
            child = [outcome for outcome in body if outcome.name == child_name]
            new_body = list(set(body) - set(child))

            logger.info(f"{Expression(child, head + new_body)} * {Expression(head, new_body)} / {Expression(child, new_body)}")

            result_1 = _compute(child, head + new_body, depth+1)
            result_2 = _compute(head, new_body, depth+1)
            result_3 = _compute(child, new_body, depth+1)
            if result_3 == 0:       # Avoid dividing by 0! coverage: skip
                # Fixed: child is already a list of Outcomes; the previous `[child]`
                # passed a nested list into Expression
                logger.success(f"{Expression(child, new_body)} = 0, therefore the result is 0.")
                return 0

            # flip flop flippy flop
            result = result_1 * result_2 / result_3
            logger.success(f"{current_expression} = {result}")
            return result

        #######################################################################################################
        #                                Jeffrey's Rule / Distributive Rule                                   #
        #   P(y | x) = P(y | z, x) * P(z | x) + P(y | ~z, x) * P(~z | x) === sigma_Z P(y | z, x) * P(z | x)   #
        #######################################################################################################

        missing_parents = set()
        for outcome in head:
            missing_parents.update(set(model.variable(outcome.name).parents) - set([parent.name for parent in head + body]))

        if missing_parents:
            logger.info("Attempting application of Jeffrey's Rule")

            for missing_parent in missing_parents:

                # Add one parent back in and recurse
                parent_outcomes = model.variable(missing_parent).outcomes

                # Consider the missing parent and sum every probability involving it
                total = 0.0
                for parent_outcome in parent_outcomes:

                    as_outcome = Outcome(missing_parent, parent_outcome)

                    logger.info(f"{Expression(head, [as_outcome] + body)}, * {Expression([as_outcome], body)}")

                    result_1 = _compute(head, [as_outcome] + body, depth+1)
                    result_2 = _compute([as_outcome], body, depth+1)
                    outcome_result = result_1 * result_2

                    total += outcome_result

                # NOTE(review): returns after summing over the FIRST missing parent only;
                # the recursion resolves any remaining ones - confirm intended
                logger.success(f"{current_expression} = {total}")
                return total

        ###############################################
        #           Single element on LHS             #
        #             Drop non-parents                #
        ###############################################

        if len(head) == 1 and not missing_parents and not descendants_in_rhs:

            head_variable = head[0].name
            can_drop = [v for v in body if v.name not in model.variable(head_variable).parents]

            if can_drop:
                logger.info(f"can drop: {[str(item) for item in can_drop]}")
                result = _compute(head, list(set(body) - set(can_drop)), depth+1)
                logger.success(f"{current_expression} = {result}")
                return result

        ###############################################
        #               Cannot compute                #
        ###############################################

        raise ProbabilityIndeterminableException

    head = set(expression.head())
    body = set(expression.body())

    # Validate every referenced variable/outcome, and reject Interventions outright
    for out in head | body:
        assert out.name in model.graph().v, f"Error: Unknown variable {out}"
        assert out.outcome in model.variable(out.name).outcomes, f"Error: Unknown outcome {out.outcome} for {out.name}"
        assert not isinstance(out, Intervention), \
            f"Error: basic inference engine does not handle Interventions ({out.name} is an Intervention)"

    return _compute(list(head), list(body))
172 |
173 |
def contradictory_outcome_set(outcomes: Collection[Outcome]) -> bool:
    """
    Check whether a list of outcomes contains any contradictory values, such as Y = y and Y = ~y
    @param outcomes: A list of Outcome objects
    @return: True if there is a contradiction/implausibility, False otherwise
    """
    # Single pass, O(n): remember the outcome seen for each variable name and flag any
    # disagreement (the original compared every ordered pair, O(n^2)).
    seen = {}
    for candidate in outcomes:
        if candidate.name in seen and seen[candidate.name] != candidate.outcome:
            return True
        seen[candidate.name] = candidate.outcome
    return False
184 |
185 |
def validate(model: Model) -> bool:
    """
    Ensures a model is 'valid' and 'consistent'.
    1. Ensures the graph is a DAG (contains no cycles)
    2. Ensures all variables denoted as exogenous are roots.
    3. Ensures all distributions are consistent (the sum of probability of each outcome is 1.0)

    Returns True on success (indicating a valid model), or raises an appropriate Exception indicating a failure.
    """
    # no cycles
    # NOTE(review): the acyclicity check is not implemented - this `...` is a no-op placeholder
    ...

    # exogenous variables are all roots (exogenous = graph vertices with no variable definition)
    exogenous = model._g.v - set(model._v.keys())
    roots = model._g.roots()
    for variable in exogenous:
        if variable not in roots:
            raise ExogenousNonRoot(variable)

    # consistent distributions: each variable's outcome probabilities must sum to 1.0
    for name, variable in model._v.items():
        t = 0
        for outcome in variable.outcomes:
            t += inference(Expression(Outcome(name, outcome)), model)

        # NOTE(review): assert is stripped under `python -O`; consider raising explicitly
        assert within_precision(t, 1)

    # all checks passed -> valid model
    return True
215 |
--------------------------------------------------------------------------------
/do/identification/Identification.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Set, Tuple, Union
2 |
3 | from .Exceptions import Fail as FAIL
4 | from .LatentGraph import LatentGraph as Graph
5 | from .PExpression import PExpression, TemplateExpression
6 |
7 |
def Identification(y: Set[str], x: Set[str], p: PExpression, g: Graph, prove: bool = True):
    """
    The Identification algorithm presented in Shpitser & Pearl, 2007.

    Args:
        y (Set[str]): A set of (outcome) variable names, corresponding to vertices present in graph G.
        x (Set[str]): A set of (treatment) variable names, corresponding to vertices present in graph G.
        p (PExpression): A custom data structure representing a distribution as a summation of variables
            (which can be empty) and collection of 'tables' (TemplateExpressions) represented as a variable
            name "given" some set of prior variables.
        g (Graph): A LatentGraph which has undergone augmentation to remove any exogenous variables, replacing
            them with bidirected arcs connecting their children.
        prove (bool, optional): Controls whether or not an additional process of proof generation should be
            undertaken when identifying the resulting expression. Defaults to True.

    Returns:
        PExpression: A resulting PExpression containing any number of nested PExpressions or (terminal)
            TemplateExpressions. This is not particularly useful on its own, but instead, can be evaluated
            through the main API.
    """

    # Recursive worker: `i` is the recursion depth (used to tag proof entries for indenting)
    # and `passdown_proof` threads an in-progress proof chain through recursive calls.
    def _identification(_y: Set[str], _x: Set[str], _p: PExpression, _g: Graph, _prove: bool = True, i=0, passdown_proof: Optional[List[Tuple[int, List[str]]]] = None) -> PExpression:

        # Format a set of variable names for proof output; the empty set renders as "Ø".
        def s(a_set):
            if len(a_set) == 0:
                return "Ø"
            return "{" + ', '.join(a_set) + "}"

        # The continuation of a proof that is ongoing if this is a recursive ID call, or a 'fresh' new proof sequence otherwise
        proof_chain = passdown_proof if passdown_proof else []

        # Shorthand for the ancestors of a set of vertices in the current graph _g.
        # noinspection PyPep8Naming
        def An(vertices):
            return _g.ancestors(vertices)

        if _prove:
            proof_chain.append((i, [f"ID Begin: Y = {s(_y)}, X = {s(_x)}"]))

        # 1: no interventions left — marginalize the joint down to Y
        if _x == set():
            if _prove:
                proof_chain.append((i, [
                    "1: if X == Ø, return Σ_{V \\ Y} P(V)",
                    f" --> Σ_{s(_g.V - _y)} P({s(_g.V)})",
                    "",
                    f"[***** Standard Probability Rules *****]"
                ]))

            return p_operator(_g.V - _y, _p, proof_chain)

        # 2: G contains non-ancestors of Y — restrict to the ancestral graph of Y
        if _g.V != An(_y):
            w = _g.V - An(_y)
            if _prove:
                proof_chain.append((i, [
                    "2: if V != An(Y)",
                    f"--> {s(_g.V)} != {s(An(_y))}",
                    "   return ID(y, x ∩ An(y), P(An(Y)), An(Y)_G)",
                    f"   --> ID({s(_y)}, {s(_x)} ∩ {s(An(_y))}, P({s(An(_y))}), An({s(An(_y))})_G)",
                    "",
                    f"  [***** Do-Calculus: Rule 3 *****]",
                    "  let W = V \\ An(Y)_G",
                    f"  W = {s(_g.V)} \\ {s(An(_y))}",
                    f"  W = {s(w)}",
                    f"  G \\ W = An(Y)_G",
                    f"  {s(_g.V)} \\ {s(w)} = {s(An(_y))}",
                    "  P_{x,z} (y | w) = P_{x} (y | w) if (Y ⊥⊥ Z | X, W) _G_X,Z(W)",
                    f"  let y = y ({s(_y)}), x = x ∩ An(Y) ({s(_x & An(_y))}), z = w ({s(w)})" ", w = Ø",
                    "  P_{" f"{s((_x & An(_y)) | w)}" "} " f"({s(_y)}) = P_{s(_x & An(_y))} ({s(_y)}) if ({s(_y)} ⊥⊥ {s(w)} | {s(_x & An(_y))}) _G_{s(_x)}",
                ]))

            return _identification(_y, _x & An(_y), p_operator(_g.V - _g[An(_y)].V, _p), _g[An(_y)], _prove, i+1, proof_chain)


        # 3: W — vertices that can safely be added to the intervention set X
        w = (_g.V - _x) - _g.without_incoming(_x).ancestors(_y)

        if _prove:
            proof_chain.append((i, [
                "let W = (V \\ X) \\ An(Y)_G_X",
                f"--> W = ({s(_g.V)} \\ {s(_x)}) \\ An({s(_y)})_G_{s(_x)}",
                f"--> W = {s(_g.V - _x)} \\ {s(_g.without_incoming(_x).ancestors(_y))}",
                f"--> W = {s(w)}"
            ]))

        if w != set():
            if _prove:
                proof_chain.append((i, [
                    "3: W != Ø",
                    "   return ID(y, x ∪ w, P, G)",
                    f"   --> ID({s(_y)}, {s(_x)} ∪ {s(w)}, P, G)",
                    "",
                    "   [***** Do-Calculus: Rule 3 *****]",
                    "   P_{x, z} (y | w) = P_{x} if (Y ⊥⊥ Z | X, W)_G_X_Z(W)",
                    "   let y = y, x = x, z = w, w = Ø",
                    "   P_{x} (y | w) = P_{x,z} (y | w) if (Y ⊥⊥ Z | X, W) _G_X,Z(W)",
                    f"   P_{s(_x)} ({s(_y)}) = P_" "{" f"{s(_x)[1:-1]}, {s(w)[1:-1]}" "}" f" ({s(_y)}) if ({s(_y)} ⊥⊥ {s(w)} | {s(_x)})_G_{s(_x)}"
                ]))

            return _identification(_y, _x | w, _p, _g, _prove, i+1, proof_chain)

        # The c-components of the subgraph induced by V \ X
        C_V_minus_X = _g[_g.V - _x].C

        # Line 4: multiple c-components — factor into a product of subproblems, one per component
        if len(C_V_minus_X) > 1:
            if _prove:
                proof_chain.append((i, [
                    "4: C(G \\ X) = {S_1, ..., S_k}",
                    f"--> C(G \\ X) = C({s(_g.V)} \\ {s(_x)}) = {', '.join(list(map(s, C_V_minus_X)))}",
                    "   return Σ_{V \\ y ∪ x} Π_i ID(Si, v \\ Si, P, G)",
                    "   --> Σ_{" f"{s(_g.V)} \\ {s(_y)} ∪ {s(_x)}" "} Π [",
                    *[f"   --> ID({s(Si)}, {s(_g.V - Si)}, P, G)" for Si in C_V_minus_X],
                    "   ]",
                    "",
                    "   [***** Proof *****]",
                    "   P_{x} (y) = Σ_{v \\ (y ∪ x)} Π_i P_{v \\ S_i} (S_i)",
                    "   1. [***** Do-Calculus: Rule 3 *****]",
                    "   Π_i P_{v \\ S_i} (S_i) = Π_i P_{A_i} (S_i), where A_i = An(S_i)_G \\ S_i",
                    "   Π [",
                    *[f"   P_{s(_g.V - si)} ({s(si)[1:-1]})" for si in C_V_minus_X],
                    "   ] = Π [",
                    *[f"   P_{s(_g.ancestors(si)-si)} ({s(si)[1:-1]})" for si in C_V_minus_X],
                    "   ]",

                    "   2. [***** Chain Rule *****]",
                    "   Π_i P_{A_i} (S_i) = Π_i Π_{V_j ∈ S_i} P_{A_i} (V_j | V_π^(j-1) \\ A_i)",

                    "   Π [",
                    *[f"   P_{s(_g.ancestors(si)-si)} ({s(si)[1:-1]})" for si in C_V_minus_X],
                    "   ] = Π [",
                    *["   ".join(["   Π ["] + [
                        f"P_{s(_g.ancestors(si)-si)} ({vj} | {s(set(_g.v_Pi[:_g.v_Pi.index(vj)]) - _g.ancestors({vj}))})" for vj in si
                    ] + ["]"]) for si in C_V_minus_X],
                    "   ]",

                    "   3. [***** Rule 2 or Rule 3 *****]",
                    "   Π_i Π_{V_j ∈ S_i} P_{A_i} (V_j | V_π^(j-1) \\ A_i) = Π_i Π_{V_j ∈ S_i} P(V_j | V_π^(j-1))",
                    "   a. if A ∈ A_i ∩ V_π^(j-1), A can be removed as an intervention by Rule 2",
                    "      All backdoor paths from A_i to V_j with a node not in V_π^(j-1) are d-separated.",
                    "      Paths must also be bidirected arcs only.",
                    "      let x = x, y = y, z = {A}, w = Ø",
                    "      P_{x,z} (y | w) = P_{x} (y | z, w) if (Y ⊥⊥ Z | X, W)_X_Z_",
                    "   b. if A ∈ A_i \\ V_π^(j-1), A can be removed as an intervention by Rule 3",
                    "      let x = x, y = V_j, z = {A}, w = Ø",
                    "      P_{x,z} (y | w) = P_{x} (y | w) if (Y ⊥⊥ Z | X, W)_G_X_Z(W)",
                    "      (V_j ⊥⊥ A | V_π^(j-1)) G_{A_i}",

                    "   Π [",
                    *["   ".join(["   Π ["] + [
                        f"P_{s(_g.ancestors(si)-si)} ({vj} | {s(set(_g.v_Pi[:_g.v_Pi.index(vj)]) - _g.ancestors({vj}))})" for vj in si
                    ] + ["]"]) for si in C_V_minus_X],
                    "   ] = Π [",
                    *["   ".join(["   Π ["] + [
                        f"P ({vj} | {s(set(_g.v_Pi[:_g.v_Pi.index(vj)]))})" for vj in si
                    ] + ["]"]) for si in C_V_minus_X],
                    "   ]",

                    "   4. [***** Grouping *****]",
                    "   Π_i Π_{V_j ∈ S_i} P(V_j | V_π^(j-1)) = Π_i P(V_i | V_π^(i-1))",

                    "   Π [",
                    *["   ".join(["   Π ["] + [
                        f"P ({vj} | {s(set(_g.v_Pi[:_g.v_Pi.index(vj)]))})" for vj in si
                    ] + ["]"]) for si in C_V_minus_X],
                    "   ] = Π [",

                    "   ]",

                    "   5. [***** Chain Rule *****]",
                    "   Π_i P(V_i | V_π^(i-1)) = P(v)"
                ]))

            return PExpression(_g.V - (_y | _x), [_identification(s_i, _g.V - s_i, _p, _g, _prove, i+1) for s_i in C_V_minus_X], proof_chain)

        else:

            # At this point we have a single component
            S = C_V_minus_X[0]

            if _prove:
                proof_chain.append((i, [
                    "if C(G \\ X) = {S}",
                    f"--> C({s(_g.V)} \\ {s(_x)}) = {s(S)}"
                ]))

            # Line 5: the component spans all of G — a hedge; the effect is not identifiable
            if set(S) == _g.V:
                if _prove:
                    proof_chain.append((i, [
                        "5: if C(G) = {G}: FAIL(G, S)",
                        f"--> G, S form hedges F, F' for Px(Y) -> {_g}, {S} for P_{_x}({_y})"
                    ]))

                raise FAIL(_g, S, proof_chain)

            # Line 6 - a single c-component
            if S in _g.C:

                dists = []
                dist_str = []
                for vi in S:
                    # Condition each vi on all variables preceding it in the topological order v_Pi
                    given = _g.v_Pi[:_g.v_Pi.index(vi)]
                    if _prove:
                        dist_str.append(f"P({vi})" if len(given) == 0 else f"P({vi} | {', '.join(given)})")
                    dists.append(TemplateExpression(vi, given))

                if _prove:
                    proof_chain.append((i, [
                        f"6: S ∈ C(G)",
                        f"--> {s(S)} ∈ {', '.join(list(map(s, _g.C)))}",
                        "   return Σ_{S-Y} π_{Vi ∈ S} P(Vi | V_π^(i-1))",
                        f"   --> Σ_{s(S - _y)} π [{', '.join(dist_str)}]",
                        "",
                        "   [***** Proof *****]",
                        f"   G has been partitioned into S = {s(S)} and X = {s(_x)} in G = {s(_g.V)}.",
                        "   There are no bidirected arcs between S and X."
                    ]))

                return PExpression(S - _y, dists, proof_chain)

            # 7: S is strictly contained in some larger c-component S' of G
            else:
                # NOTE(review): the generator variable `s` shadows the set-formatting helper s()
                # only inside this generator expression; later s(...) calls are unaffected.
                s_prime = next(s for s in _g.C if set(s) > set(S))
                p = []

                msg = " --> P = "

                for v in s_prime:
                    rhs0 = _g.v_Pi[:_g.v_Pi.index(v)]
                    rhs1 = rhs0.copy()

                    # Split the predecessors of v into those inside and outside S'
                    rhs0 = list(set(rhs0) & s_prime)
                    rhs1 = list(set(rhs1) - s_prime)
                    rhs = rhs0 + rhs1
                    p.append(TemplateExpression(v, rhs))
                    if _prove:
                        msg += f"[{v}{(f' | ' + ', '.join(rhs)) if len(rhs) > 0 else ''}]"

                g_s_prime = _g[s_prime]

                if _prove:
                    proof_chain.append((i, [
                        f"7: if ∃(S') S ⊂ S' ∈ C(G)",
                        f"--> let S = {s(S)}, S' = {s(s_prime)}",
                        f"--> {s(S)} ⊂ {s(s_prime)} ∈ {', '.join(list(map(s, _g.C)))}",
                        "   return ID(y, x ∩ S', π_{V_i ∈ S'} P(V_i | V_π^(i-1) ∩ S', V_π^(i-1) \\ S'), S')",
                        msg,
                        f"   --> ID({s(_y)}, {s(_x)} ∩ {s(s_prime)}, P, G = ({g_s_prime.V}, {g_s_prime.e}, {g_s_prime.e_bidirected}))",
                        "",
                        "   [***** Proof *****]",
                        f"   G is partitioned into X = {s(_x)} and S = {s(S)}, where X ⊂ An(S).",
                        "   M_{X \\ S'} induces G \\ (X \\ S') = S'.",
                        "   P_{x} = P_{x ∩ S', X \\ S'} = P_{x ∩ S'}.",
                    ]))

                return _identification(_y, _x & s_prime, PExpression([], p), g_s_prime, _prove, i+1, proof_chain)

    # Kick off the recursion at depth 0 with a fresh proof chain
    return _identification(y, x, p, g, prove)
266 |
def simplify_expression(original: PExpression, g: Graph, debug=False) -> PExpression:
    """
    Simplify a PExpression without mutating the input.

    Works on a copy of `original`, recursively simplifying nested PExpressions, then:
    removing conditioning variables that `g.ci` reports as conditionally independent of
    a table's head, dropping summed-over tables whose head appears in no other table's
    body, and finally sorting the remaining terms by the graph's topological ordering
    (g.v_Pi). A record of the simplification steps is appended to the copy's proof.

    @param original: The PExpression to simplify.
    @param g: The Graph used for conditional-independence tests and topological ordering.
    @param debug: When True, print each simplification step as it is applied.
    @return: A simplified copy of `original`.
    """

    def _simplify(current,i = 0):

        # The terminal tables (TemplateExpressions) at this level.
        # NOTE(review): the lambda parameter `i` shadows the depth argument `i` of _simplify.
        cpt_list_copy = list(filter(lambda i: isinstance(i, TemplateExpression), current.terms))
        # Recurse into nested PExpressions first; TemplateExpressions are terminal.
        for s in current.terms:

            if isinstance(s, TemplateExpression):
                continue

            c = _simplify(s, i + 1)

            # NOTE(review): the offset is derived from `original`'s proof, not `s`'s,
            # even though the condition checks `s.internal_proof` — confirm intended.
            if s.internal_proof:
                offset = original.internal_proof[-1][0] + 2
            else:
                offset = 1

            s.internal_proof.append((offset, c))

        # Human-readable log of each simplification applied at this level
        steps = []

        # """
        # Remove unnecessary variables from body
        for expression in cpt_list_copy:

            # Fixpoint loop: keep re-scanning until a full pass removes nothing.
            while True:
                removed_one = False
                x = {expression.head}
                for variable in expression.given:
                    y = {variable}
                    z = set(expression.given) - y
                # NOTE(review): this independence test sits AFTER the inner loop, so only
                # the final `variable` of `given` is tested each pass; earlier variables
                # are never removed once the last is kept — confirm whether the test was
                # meant to live inside the loop.
                if g.ci(x, y, z):
                    msg1 = f"{', '.join(x)} is independent of {', '.join(y)} given {', '.join(z)}, and can be removed."
                    msg2 = f"p operator removed {variable} from body of {expression}"
                    if debug:
                        print(msg1)
                        print(msg2)
                    steps.append(msg1)
                    expression.given.remove(variable)
                    removed_one = True

                if not removed_one:
                    break
        # """

        # Remove unnecessary expressions
        # """
        # A summed-over table whose head appears in no other table's body contributes
        # nothing and can be dropped (with its summation variable).
        while True:
            bodies = set().union(*[el.given for el in current.terms if isinstance(el, TemplateExpression)])
            search = filter(lambda el: isinstance(el, TemplateExpression) and el.head in current.sigma, current.terms)
            remove = list(filter(lambda el: el.head not in bodies, search))

            if len(remove) == 0:
                break

            for query in remove:
                current.sigma.remove(query.head)
                current.terms.remove(query)
                msg = f"{query.head} can be removed."
                if debug:
                    print(msg)
                steps.append(msg)
        # """

        # Same criterion as above, re-checked to a fixpoint (removals can enable more removals).
        while True:
            sumout = [cpt for cpt in current.terms if isinstance(cpt, TemplateExpression) and cpt.head in current.sigma and not any([cpt.head in el.given for el in current.terms if isinstance(el, TemplateExpression)])]
            if not sumout:
                break
            for cpt in sumout:
                current.terms.remove(cpt)
                current.sigma.remove(cpt.head)

        if len(steps) > 0:
            tables = ", ".join(f"P({table.head} | {', '.join(table.given)})" if len(table.given) > 0 else f"P({table.head})" for table in cpt_list_copy)
            steps.append(f"After simplification: {tables}")

        # Sort key: a table sorts by its head's topological index; a nested PExpression
        # sorts after all tables.
        def distribution_position(item: Union[PExpression, TemplateExpression]):
            if isinstance(item, PExpression):
                if len(item.sigma) == 0:
                    return len(g.v_Pi)
                # NOTE(review): min(0, *indices) can never exceed 0, so non-empty sigma
                # always yields len(g.v_Pi) here — possibly intended min(indices); confirm.
                return len(g.v_Pi) + min(0, *list(map(lambda v: g.v_Pi.index(v), item.sigma)))
            else:
                return g.v_Pi.index(item.head)

        # Sort remaining expressions by the topological ordering
        current.terms.sort(key=distribution_position)

        if len(steps) > 0:
            steps.insert(0, "[***** Simplification *****]")

        return steps

    # Place this simplification's proof entry just below the existing proof, if any.
    if original.internal_proof:
        depth = original.internal_proof[-1][0] + 1
    else:
        depth = 1

    # Simplify a copy so the caller's expression is left untouched.
    p = original.copy()
    changes = _simplify(p)
    p.internal_proof.append((depth, changes))
    return p
368 |
369 |
def p_operator(v: Set[str], p: PExpression, proof: Optional[List[Tuple[int, List[str]]]] = None):
    """
    Wrap p in a new PExpression that additionally sums over the vertices in v.

    @param v: A set of variable names to sum over, merged with p's existing summation set.
    @param p: The PExpression whose terms are carried (shallow-copied) into the result.
    @param proof: An optional proof chain to attach to the resulting PExpression.
    @return: A new PExpression summing over v ∪ p.sigma with a copy of p's term list.
    """
    # Shallow-copy the terms so later mutation of the result (e.g. during simplification)
    # cannot alter p. The set union already returns a fresh set, so no explicit .copy()
    # of v is needed; the `proof` default is now correctly annotated Optional.
    return PExpression(list(v | set(p.sigma)), list(p.terms), proof)
372 |
--------------------------------------------------------------------------------