Skip to content

Commit 3064fbd

Browse files
committed
Fix bug: rule9-uncovered p.d. paths
Signed-off-by: Zheng Li <[email protected]>
1 parent 474437c commit 3064fbd

File tree

2 files changed

+87
-1
lines changed

2 files changed

+87
-1
lines changed

causallearn/search/ConstraintBased/FCI.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,54 @@ def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool: ##
8282

8383
return False
8484

85+
86+
87+
def traversePotentiallyDirected(node: Node, edge: Edge) -> Node | None:
88+
if node == edge.get_node1():
89+
if (edge.get_endpoint1() == Endpoint.TAIL or edge.get_endpoint1() == Endpoint.CIRCLE) and \
90+
(edge.get_endpoint2() == Endpoint.ARROW or edge.get_endpoint2() == Endpoint.CIRCLE):
91+
return edge.get_node2()
92+
elif node == edge.get_node2():
93+
if (edge.get_endpoint2() == Endpoint.TAIL or edge.get_endpoint2() == Endpoint.CIRCLE) and \
94+
(edge.get_endpoint1() == Endpoint.ARROW or edge.get_endpoint1() == Endpoint.CIRCLE):
95+
return edge.get_node1()
96+
return None
97+
98+
99+
def existsUncoveredPdPath(node_from: Node, node_next: Node, node_to: Node, G: Graph) -> bool:
100+
Q = Queue()
101+
V = set([node_from, node_next])
102+
103+
for node_u in G.get_adjacent_nodes(node_next):
104+
edge = G.get_edge(node_next, node_u)
105+
node_c = traversePotentiallyDirected(node_next, edge)
106+
107+
if node_c is None:
108+
continue
109+
110+
if not V.__contains__(node_c):
111+
V.add(node_c)
112+
Q.put((node_c, [node_from, node_next, node_c]))
113+
114+
while not Q.empty():
115+
node_t, path = Q.get_nowait()
116+
if node_t == node_to and is_uncovered_path(path, G):
117+
# print(f"Found uncovered pd path: {[node.get_name() for node in path]}")
118+
return True
119+
120+
for node_u in G.get_adjacent_nodes(node_t):
121+
edge = G.get_edge(node_t, node_u)
122+
node_c = traversePotentiallyDirected(node_t, edge)
123+
124+
if node_c is None:
125+
continue
126+
127+
if not V.__contains__(node_c):
128+
V.add(node_c)
129+
Q.put((node_c, path + [node_c]))
130+
131+
return False
132+
85133
def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph, exclude_node: List[Node]) -> Generator[Node] | None:
86134
Q = Queue()
87135
V = set()
@@ -802,7 +850,8 @@ def rule9(graph: Graph, nodes: List[Node], changeFlag):
802850
for node_B in possible_children:
803851
if graph.is_adjacent_to(node_B, node_C):
804852
continue
805-
if existsSemiDirectedPath(node_from=node_B, node_to=node_C, G=graph):
853+
854+
if existsUncoveredPdPath(node_from=node_A, node_next=node_B, node_to=node_C, G=graph):
806855
edge1 = graph.get_edge(node_A, node_C)
807856
graph.remove_edge(edge1)
808857
graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))

tests/TestDAG2PAG.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,40 @@ def test_case_selection(self):
8686
dag.add_directed_edge(nodes[0], nodes[4])
8787
pag = dag2pag(dag, islatent=[], isselection=[nodes[4]])
8888
print(pag)
89+
90+
def test_case_orient_rules(self):
91+
nodes = []
92+
X = {}
93+
L = {}
94+
for i in range(7):
95+
node_name = f"X{i + 1}"
96+
if i + 1 == 2:
97+
node_name = f"L{i + 1}"
98+
node = GraphNode(node_name)
99+
nodes.append(node)
100+
if i + 1 == 2:
101+
L[2] = node
102+
else:
103+
X[i + 1] = node
104+
dag = Dag(nodes)
105+
dag.add_directed_edge(L[2], X[4])
106+
dag.add_directed_edge(L[2], X[5])
107+
dag.add_directed_edge(L[2], X[6])
108+
109+
dag.add_directed_edge(X[5], X[7])
110+
dag.add_directed_edge(X[1], X[4])
111+
dag.add_directed_edge(X[1], X[7])
112+
dag.add_directed_edge(X[3], X[7])
113+
pag = dag2pag(dag, [L[2]])
114+
print(pag)
115+
graphviz_pag = GraphUtils.to_pgv(pag)
116+
graphviz_pag.draw("pag.png", prog='dot', format='png')
117+
118+
119+
if __name__ == "__main__":
120+
test_model = TestDAG2PAG()
121+
test_model.test_case1()
122+
test_model.test_case2()
123+
test_model.test_case3()
124+
test_model.test_case_selection()
125+
test_model.test_case_orient_rules()

0 commit comments

Comments
 (0)