Skip to content

Commit 936e9a9

Browse files
committed
Added support for custom CIs
1 parent 081853e commit 936e9a9

File tree

2 files changed

+278
-1
lines changed

2 files changed

+278
-1
lines changed

causallearn/utils/cit.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,47 @@
2222
gsq = "gsq"
2323
d_separation = "d_separation"
2424

25+
# Registry for custom CI tests
26+
_custom_ci_tests = {}
27+
28+
def register_ci_test(name, test_class):
29+
"""
30+
Register a custom CI test implementation.
31+
32+
Parameters
33+
----------
34+
name: str
35+
Name of the CI test, used to identify the test in the CIT function
36+
test_class: class
37+
The CI test class. Must inherit from CIT_Base and implement __call__ method
38+
39+
Returns
40+
-------
41+
test_class: The registered class (for decorator use)
42+
"""
43+
if not issubclass(test_class, CIT_Base):
44+
raise TypeError(f"CI test class must inherit from CIT_Base: {test_class.__name__}")
45+
46+
_custom_ci_tests[name] = test_class
47+
return test_class
48+
2549

2650
def CIT(data, method='fisherz', **kwargs):
2751
'''
2852
Parameters
2953
----------
3054
data: numpy.ndarray of shape (n_samples, n_features)
3155
method: str, in ["fisherz", "mv_fisherz", "mc_fisherz", "kci", "rcit", "fastkci", "chisq", "gsq"]
56+
or a custom method registered via register_ci_test
3257
kwargs: placeholder for future arguments, or for KCI, FastKCI or RCIT specific arguments now
3358
TODO: utimately kwargs should be replaced by explicit named parameters.
3459
check https://github.com/cmu-phil/causal-learn/pull/62#discussion_r927239028
3560
'''
61+
# First check if method is a registered custom CI test
62+
if method in _custom_ci_tests:
63+
return _custom_ci_tests[method](data, **kwargs)
64+
65+
# Otherwise use built-in methods
3666
if method == fisherz:
3767
return FisherZ(data, **kwargs)
3868
elif method == kci:
@@ -50,7 +80,7 @@ def CIT(data, method='fisherz', **kwargs):
5080
elif method == d_separation:
5181
return D_Separation(data, **kwargs)
5282
else:
53-
raise ValueError("Unknown method: {}".format(method))
83+
raise ValueError(f"Unknown method: {method}. If using a custom CI test, make sure it's registered with register_ci_test()")
5484

5585

5686
class CIT_Base(object):
@@ -145,6 +175,20 @@ def _stringize(ulist1, ulist2, clist):
145175
len(set(Ys).intersection(condition_set)) == 0, "X, Y cannot be in condition_set."
146176
return Xs, Ys, condition_set, _stringize(Xs, Ys, condition_set)
147177

178+
def __call__(self, X, Y, condition_set=None):
179+
"""
180+
Perform an independence test.
181+
182+
Parameters
183+
----------
184+
X, Y: column indices of data
185+
condition_set: conditioning variables, default None
186+
187+
Returns
188+
-------
189+
p: the p-value of the test
190+
"""
191+
raise NotImplementedError("Subclasses must implement __call__ method")
148192
class FisherZ(CIT_Base):
149193
def __init__(self, data, **kwargs):
150194
super().__init__(data, **kwargs)
@@ -544,3 +588,4 @@ def __call__(self, X, Y, condition_set=None):
544588
# 2. GeneralGraph class will be hugely refactored in the near future.
545589
self.pvalue_cache[cache_key] = p
546590
return p
591+

tests/test_custom_ci.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
import numpy as np
2+
from math import log, sqrt
3+
from scipy.stats import norm
4+
from causallearn.utils.cit import CIT, CIT_Base, register_ci_test, NO_SPECIFIED_PARAMETERS_MSG
5+
from causallearn.search.ConstraintBased.PC import pc
6+
import time
7+
8+
# Import our modified modules
9+
# Assuming the modified code is saved in a file called modified_cit.py
10+
#from modified_cit import CIT, CIT_Base, register_ci_test, NO_SPECIFIED_PARAMETERS_MSG
11+
12+
from causallearn.utils.cit import CIT, CIT_Base, register_ci_test, NO_SPECIFIED_PARAMETERS_MSG
13+
14+
# Define a custom implementation of Fisher Z test
15+
16+
class CustomFisherZ(CIT_Base):
17+
def __init__(self, data, **kwargs):
18+
super().__init__(data, **kwargs)
19+
self.check_cache_method_consistent('custom_fisherz', NO_SPECIFIED_PARAMETERS_MSG)
20+
self.assert_input_data_is_valid()
21+
# Calculate the correlation matrix just like the original FisherZ
22+
self.correlation_matrix = np.corrcoef(data.T)
23+
print("Initialized CustomFisherZ test")
24+
25+
def __call__(self, X, Y, condition_set=None):
26+
print("Using the CI test")
27+
'''
28+
Custom implementation of Fisher-Z's test that mirrors the original.
29+
30+
Parameters
31+
----------
32+
X, Y and condition_set : column indices of data
33+
34+
Returns
35+
-------
36+
p : the p-value of the test
37+
'''
38+
Xs, Ys, condition_set, cache_key = self.get_formatted_XYZ_and_cachekey(X, Y, condition_set)
39+
if cache_key in self.pvalue_cache:
40+
# print(f"Using cached result for {cache_key}")
41+
return self.pvalue_cache[cache_key]
42+
43+
# print(f"Computing new result for {cache_key}")
44+
var = Xs + Ys + condition_set
45+
sub_corr_matrix = self.correlation_matrix[np.ix_(var, var)]
46+
47+
try:
48+
inv = np.linalg.inv(sub_corr_matrix)
49+
except np.linalg.LinAlgError:
50+
raise ValueError('Data correlation matrix is singular. Cannot run fisherz test. Please check your data.')
51+
52+
r = -inv[0, 1] / sqrt(abs(inv[0, 0] * inv[1, 1]))
53+
if abs(r) >= 1:
54+
r = (1. - np.finfo(float).eps) * np.sign(r)
55+
56+
Z = 0.5 * log((1 + r) / (1 - r))
57+
X = sqrt(self.sample_size - len(condition_set) - 3) * abs(Z)
58+
p = 2 * (1 - norm.cdf(abs(X)))
59+
60+
self.pvalue_cache[cache_key] = p
61+
return p
62+
63+
register_ci_test("custom_fisherz", CustomFisherZ)
64+
def run_test():
65+
# Generate some random data
66+
np.random.seed(42) # For reproducibility
67+
n_samples = 100
68+
n_features = 5
69+
data = np.random.randn(n_samples, n_features)
70+
71+
print("=== Testing with original FisherZ ===")
72+
# Create a CI test with the original method
73+
original_ci_test = CIT(data, method="fisherz")
74+
75+
# Run some tests
76+
original_p1 = original_ci_test(0, 1)
77+
original_p2 = original_ci_test(0, 1, [2])
78+
original_p3 = original_ci_test(0, 1, [2, 3])
79+
80+
print(f"Original FisherZ p-values:")
81+
print(f" X=0, Y=1: {original_p1}")
82+
print(f" X=0, Y=1, Z=[2]: {original_p2}")
83+
print(f" X=0, Y=1, Z=[2, 3]: {original_p3}")
84+
85+
print("\n=== Testing with custom FisherZ ===")
86+
# Create a CI test with our custom method
87+
custom_ci_test = CIT(data, method="custom_fisherz")
88+
89+
# Run the same tests
90+
custom_p1 = custom_ci_test(0, 1)
91+
custom_p2 = custom_ci_test(0, 1, [2])
92+
custom_p3 = custom_ci_test(0, 1, [2, 3])
93+
94+
print(f"Custom FisherZ p-values:")
95+
print(f" X=0, Y=1: {custom_p1}")
96+
print(f" X=0, Y=1, Z=[2]: {custom_p2}")
97+
print(f" X=0, Y=1, Z=[2, 3]: {custom_p3}")
98+
99+
# Compare results
100+
print("\n=== Comparing results ===")
101+
print(f"P-value match for X=0, Y=1: {original_p1 == custom_p1}")
102+
print(f"P-value match for X=0, Y=1, Z=[2]: {original_p2 == custom_p2}")
103+
print(f"P-value match for X=0, Y=1, Z=[2, 3]: {original_p3 == custom_p3}")
104+
105+
# Test caching mechanism by running the same test again
106+
print("\n=== Testing caching mechanism ===")
107+
custom_p1_cached = custom_ci_test(0, 1) # Should use cached result
108+
print(f"Cached result matches: {custom_p1 == custom_p1_cached}")
109+
# Register the custom class
110+
111+
112+
def generate_synthetic_data(n_samples=500):
113+
"""
114+
Generate synthetic data with a known causal structure:
115+
X1 -> X3 <- X2
116+
X4 -> X5 -> X6
117+
"""
118+
np.random.seed(42)
119+
120+
# X1, X2, X4 are exogenous
121+
X1 = np.random.normal(0, 1, n_samples)
122+
X2 = np.random.normal(0, 1, n_samples)
123+
X4 = np.random.normal(0, 1, n_samples)
124+
125+
# X3 depends on X1 and X2
126+
X3 = 0.7 * X1 + 0.8 * X2 + np.random.normal(0, 1, n_samples)
127+
128+
# X5 depends on X4
129+
X5 = 0.9 * X4 + np.random.normal(0, 0.5, n_samples)
130+
131+
# X6 depends on X5
132+
X6 = 0.8 * X5 + np.random.normal(0, 0.5, n_samples)
133+
134+
# Combine all variables
135+
data = np.column_stack([X1, X2, X3, X4, X5, X6])
136+
137+
# Ground truth DAG adjacency matrix (1 if i->j)
138+
true_dag = np.zeros((6, 6))
139+
true_dag[0, 2] = 1 # X1 -> X3
140+
true_dag[1, 2] = 1 # X2 -> X3
141+
true_dag[3, 4] = 1 # X4 -> X5
142+
true_dag[4, 5] = 1 # X5 -> X6
143+
144+
return data, true_dag
145+
146+
def print_graph_edges(adj_matrix, title):
147+
"""Print edges from an adjacency matrix"""
148+
print(f"\n{title}:")
149+
edge_count = 0
150+
for i in range(adj_matrix.shape[0]):
151+
for j in range(adj_matrix.shape[1]):
152+
if adj_matrix[i, j] != 0:
153+
edge_count += 1
154+
print(f" X{i+1} -> X{j+1}")
155+
if edge_count == 0:
156+
print(" No edges found")
157+
else:
158+
print(f" Total: {edge_count} edges")
159+
160+
def compare_results(adj1, adj2):
161+
"""Compare two adjacency matrices and return metrics"""
162+
# Check if the matrices have the same shape
163+
if adj1.shape != adj2.shape:
164+
raise ValueError("Adjacency matrices must have the same shape")
165+
166+
# Convert to binary (just in case)
167+
adj1_bin = (adj1 != 0).astype(int)
168+
adj2_bin = (adj2 != 0).astype(int)
169+
170+
# Count matches and mismatches
171+
matches = np.sum(adj1_bin == adj2_bin)
172+
total = adj1.shape[0] * adj1.shape[1]
173+
174+
# Calculate edge presence match
175+
match_percentage = (matches / total) * 100
176+
177+
# Count missing and extra edges
178+
in_1_not_2 = np.sum((adj1_bin == 1) & (adj2_bin == 0))
179+
in_2_not_1 = np.sum((adj1_bin == 0) & (adj2_bin == 1))
180+
181+
return {
182+
'match_percentage': match_percentage,
183+
'edges_in_first_not_second': in_1_not_2,
184+
'edges_in_second_not_first': in_2_not_1
185+
}
186+
187+
def run_pc_algorithm_test():
188+
"""Run the PC algorithm with both built-in and custom Fisher-Z tests"""
189+
print("\n=== Testing PC Algorithm with Custom CI Test ===")
190+
191+
# Generate synthetic data
192+
data, true_dag = generate_synthetic_data(n_samples=500)
193+
194+
print("Data shape:", data.shape)
195+
196+
# Print true DAG edges
197+
print_graph_edges(true_dag, "True DAG Edges")
198+
199+
# Run PC with built-in Fisher-Z
200+
print("\nRunning PC with built-in Fisher-Z...")
201+
start_time = time.time()
202+
pc_result_built_in = pc(data, 0.05, indep_test="fisherz")
203+
built_in_time = time.time() - start_time
204+
print(f"Built-in Fisher-Z test took {built_in_time:.4f} seconds")
205+
206+
# Run PC with custom Fisher-Z
207+
print("\nRunning PC with custom Fisher-Z...")
208+
start_time = time.time()
209+
pc_result_custom = pc(data, 0.05, indep_test="custom_fisherz")
210+
custom_time = time.time() - start_time
211+
print(f"Custom Fisher-Z test took {custom_time:.4f} seconds")
212+
213+
# Get the adjacency matrices
214+
adj_built_in = pc_result_built_in.G.graph
215+
adj_custom = pc_result_custom.G.graph
216+
217+
# Print edges from both results
218+
print_graph_edges(adj_built_in, "PC with Built-in Fisher-Z")
219+
print_graph_edges(adj_custom, "PC with Custom Fisher-Z")
220+
221+
# Compare built-in and custom results
222+
comparison = compare_results(adj_built_in, adj_custom)
223+
print(f"\nComparison of built-in vs custom CI test:")
224+
print(f"Match percentage: {comparison['match_percentage']:.2f}%")
225+
print(f"Edges in built-in result not in custom: {comparison['edges_in_first_not_second']}")
226+
print(f"Edges in custom result not in built-in: {comparison['edges_in_second_not_first']}")
227+
228+
229+
if __name__ == "__main__":
230+
run_test()
231+
run_pc_algorithm_test()
232+

0 commit comments

Comments
 (0)