From 3064fbdd357f0c7f22d6eb5d8c18d53ccdd4198d Mon Sep 17 00:00:00 2001 From: ZhengLi Date: Fri, 5 Sep 2025 15:58:09 +0800 Subject: [PATCH] Fix bug: rule9-uncovered p.d. paths Signed-off-by: Zheng Li --- causallearn/search/ConstraintBased/FCI.py | 51 ++++++++++++++++++++++- tests/TestDAG2PAG.py | 37 ++++++++++++++++ 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/causallearn/search/ConstraintBased/FCI.py b/causallearn/search/ConstraintBased/FCI.py index 43d1e29..869b8de 100644 --- a/causallearn/search/ConstraintBased/FCI.py +++ b/causallearn/search/ConstraintBased/FCI.py @@ -82,6 +82,54 @@ def existsSemiDirectedPath(node_from: Node, node_to: Node, G: Graph) -> bool: ## return False + + +def traversePotentiallyDirected(node: Node, edge: Edge) -> Node | None: + if node == edge.get_node1(): + if (edge.get_endpoint1() == Endpoint.TAIL or edge.get_endpoint1() == Endpoint.CIRCLE) and \ + (edge.get_endpoint2() == Endpoint.ARROW or edge.get_endpoint2() == Endpoint.CIRCLE): + return edge.get_node2() + elif node == edge.get_node2(): + if (edge.get_endpoint2() == Endpoint.TAIL or edge.get_endpoint2() == Endpoint.CIRCLE) and \ + (edge.get_endpoint1() == Endpoint.ARROW or edge.get_endpoint1() == Endpoint.CIRCLE): + return edge.get_node1() + return None + + +def existsUncoveredPdPath(node_from: Node, node_next: Node, node_to: Node, G: Graph) -> bool: + Q = Queue() + V = set([node_from, node_next]) + + for node_u in G.get_adjacent_nodes(node_next): + edge = G.get_edge(node_next, node_u) + node_c = traversePotentiallyDirected(node_next, edge) + + if node_c is None: + continue + + if not V.__contains__(node_c): + V.add(node_c) + Q.put((node_c, [node_from, node_next, node_c])) + + while not Q.empty(): + node_t, path = Q.get_nowait() + if node_t == node_to and is_uncovered_path(path, G): + # print(f"Found uncovered pd path: {[node.get_name() for node in path]}") + return True + + for node_u in G.get_adjacent_nodes(node_t): + edge = G.get_edge(node_t, node_u) + node_c = traversePotentiallyDirected(node_t, edge) + + if node_c is None: + continue + + if not V.__contains__(node_c): + V.add(node_c) + Q.put((node_c, path + [node_c])) + + return False + def GetUncoveredCirclePath(node_from: Node, node_to: Node, G: Graph, exclude_node: List[Node]) -> Generator[Node] | None: Q = Queue() V = set() @@ -802,7 +850,8 @@ def rule9(graph: Graph, nodes: List[Node], changeFlag): for node_B in possible_children: if graph.is_adjacent_to(node_B, node_C): continue - if existsSemiDirectedPath(node_from=node_B, node_to=node_C, G=graph): + + if existsUncoveredPdPath(node_from=node_A, node_next=node_B, node_to=node_C, G=graph): edge1 = graph.get_edge(node_A, node_C) graph.remove_edge(edge1) graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW)) diff --git a/tests/TestDAG2PAG.py b/tests/TestDAG2PAG.py index 004ea89..495231e 100644 --- a/tests/TestDAG2PAG.py +++ b/tests/TestDAG2PAG.py @@ -86,3 +86,40 @@ def test_case_selection(self): dag.add_directed_edge(nodes[0], nodes[4]) pag = dag2pag(dag, islatent=[], isselection=[nodes[4]]) print(pag) + + def test_case_orient_rules(self): + nodes = [] + X = {} + L = {} + for i in range(7): + node_name = f"X{i + 1}" + if i + 1 == 2: + node_name = f"L{i + 1}" + node = GraphNode(node_name) + nodes.append(node) + if i + 1 == 2: + L[2] = node + else: + X[i + 1] = node + dag = Dag(nodes) + dag.add_directed_edge(L[2], X[4]) + dag.add_directed_edge(L[2], X[5]) + dag.add_directed_edge(L[2], X[6]) + + dag.add_directed_edge(X[5], X[7]) + dag.add_directed_edge(X[1], X[4]) + dag.add_directed_edge(X[1], X[7]) + dag.add_directed_edge(X[3], X[7]) + pag = dag2pag(dag, [L[2]]) + print(pag) + graphviz_pag = GraphUtils.to_pgv(pag) + graphviz_pag.draw("pag.png", prog='dot', format='png') + + +if __name__ == "__main__": + test_model = TestDAG2PAG() + test_model.test_case1() + test_model.test_case2() + test_model.test_case3() + test_model.test_case_selection() + test_model.test_case_orient_rules() \ No newline at end of file