Skip to content

Commit 2f4c4e7

Browse files
committed
Use cls/self.__class__ in more places in DFA class
This will make subclassing DFA more flexible.
1 parent b1a6062 commit 2f4c4e7

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

automata/fa/dfa.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def minify(self, retain_names: bool = False) -> DFA:
338338
"""
339339

340340
# Compute reachable states and final states
341-
bfs_states = DFA._bfs_states(
341+
bfs_states = self.__class__._bfs_states(
342342
self.initial_state, lambda state: iter(self.transitions[state].items())
343343
)
344344
reachable_states = {*bfs_states}
@@ -456,7 +456,7 @@ def union_function(state_pair: Tuple[DFAStateT, DFAStateT]) -> bool:
456456
q_a, q_b = state_pair
457457
return q_a in self.final_states or q_b in other.final_states
458458

459-
initial_state, expand_state_fn = DFA._cross_product(self, other)
459+
initial_state, expand_state_fn = self.__class__._cross_product(self, other)
460460

461461
return self.__class__._expand_dfa(
462462
union_function,
@@ -480,7 +480,7 @@ def intersection_function(state_pair: Tuple[DFAStateT, DFAStateT]) -> bool:
480480
q_a, q_b = state_pair
481481
return q_a in self.final_states and q_b in other.final_states
482482

483-
initial_state, expand_state_fn = DFA._cross_product(self, other)
483+
initial_state, expand_state_fn = self.__class__._cross_product(self, other)
484484

485485
return self.__class__._expand_dfa(
486486
intersection_function,
@@ -504,7 +504,7 @@ def difference_function(state_pair: Tuple[DFAStateT, DFAStateT]) -> bool:
504504
q_a, q_b = state_pair
505505
return q_a in self.final_states and q_b not in other.final_states
506506

507-
initial_state, expand_state_fn = DFA._cross_product(self, other)
507+
initial_state, expand_state_fn = self.__class__._cross_product(self, other)
508508

509509
return self.__class__._expand_dfa(
510510
difference_function,
@@ -530,7 +530,7 @@ def symmetric_difference_function(
530530
q_a, q_b = state_pair
531531
return (q_a in self.final_states) ^ (q_b in other.final_states)
532532

533-
initial_state, expand_state_fn = DFA._cross_product(self, other)
533+
initial_state, expand_state_fn = self.__class__._cross_product(self, other)
534534

535535
return self.__class__._expand_dfa(
536536
symmetric_difference_function,
@@ -545,7 +545,7 @@ def complement(self, *, retain_names: bool = False, minify: bool = True) -> DFA:
545545
"""Return the complement of this DFA."""
546546

547547
if minify:
548-
bfs_states = DFA._bfs_states(
548+
bfs_states = self.__class__._bfs_states(
549549
self.initial_state, lambda state: iter(self.transitions[state].items())
550550
)
551551
reachable_states = {*bfs_states}
@@ -612,8 +612,9 @@ def _bfs_states(
612612
visited_set.add(tgt_state)
613613
queue.append(tgt_state)
614614

615-
@staticmethod
615+
@classmethod
616616
def _expand_dfa(
617+
cls,
617618
final_state_fn: IsFinalStateFn,
618619
initial_state: DFAStateT,
619620
expand_state_fn: ExpandStateFn,
@@ -645,7 +646,7 @@ def get_name_original(state):
645646
states = {initial_state_name}
646647
final_states = {initial_state_name} if final_state_fn(initial_state) else set()
647648

648-
for cur_state, chr, tgt_state in DFA._bfs_edges(initial_state, expand_state_fn):
649+
for cur_state, chr, tgt_state in cls._bfs_edges(initial_state, expand_state_fn):
649650
cur_state_name = get_name(cur_state)
650651
tgt_state_name = get_name(tgt_state)
651652

@@ -660,7 +661,7 @@ def get_name_original(state):
660661

661662
if minify:
662663
# From the construction, the states/final states are reachable
663-
return DFA._minify(
664+
return cls._minify(
664665
reachable_states=states,
665666
input_symbols=input_symbols,
666667
transitions=transitions,
@@ -676,8 +677,9 @@ def get_name_original(state):
676677
final_states=final_states,
677678
)
678679

679-
@staticmethod
680+
@classmethod
680681
def _find_state(
682+
cls,
681683
target_state_fn: TargetStateFn,
682684
initial_state: DFAStateT,
683685
expand_state_fn: ExpandStateFn,
@@ -688,7 +690,7 @@ def _find_state(
688690
searched for. The expand_state_fn should return an iterator with the
689691
successors of each state in the product.
690692
"""
691-
bfs_states = DFA._bfs_states(initial_state, expand_state_fn)
693+
bfs_states = cls._bfs_states(initial_state, expand_state_fn)
692694
return any(target_state_fn(state) for state in bfs_states)
693695

694696
@staticmethod
@@ -718,8 +720,10 @@ def subset_state_fn(state_pair: Tuple[DFAStateT, DFAStateT]) -> bool:
718720
q_a, q_b = state_pair
719721
return q_a in self.final_states and q_b not in other.final_states
720722

721-
initial_state, expand_state_fn = DFA._cross_product(self, other)
722-
return not DFA._find_state(subset_state_fn, initial_state, expand_state_fn)
723+
initial_state, expand_state_fn = self.__class__._cross_product(self, other)
724+
return not self.__class__._find_state(
725+
subset_state_fn, initial_state, expand_state_fn
726+
)
723727

724728
def issuperset(self, other: DFA) -> bool:
725729
"""Return True if this DFA is a superset of another DFA."""
@@ -733,12 +737,14 @@ def disjoint_state_fn(state_pair: Tuple[DFAStateT, DFAStateT]) -> bool:
733737
q_a, q_b = state_pair
734738
return q_a in self.final_states and q_b in other.final_states
735739

736-
initial_state, expand_state_fn = DFA._cross_product(self, other)
737-
return not DFA._find_state(disjoint_state_fn, initial_state, expand_state_fn)
740+
initial_state, expand_state_fn = self.__class__._cross_product(self, other)
741+
return not self.__class__._find_state(
742+
disjoint_state_fn, initial_state, expand_state_fn
743+
)
738744

739745
def isempty(self) -> bool:
740746
"""Return True if this DFA is completely empty."""
741-
return not DFA._find_state(
747+
return not self.__class__._find_state(
742748
lambda state: state in self.final_states,
743749
self.initial_state,
744750
lambda state: iter(self.transitions[state].items()),
@@ -1376,7 +1382,7 @@ def from_finite_language(
13761382
SignatureT = Tuple[bool, FrozenSet[Tuple[str, str]]]
13771383

13781384
if not language:
1379-
return DFA.empty_language(input_symbols)
1385+
return cls.empty_language(input_symbols)
13801386

13811387
transitions: Dict[DFAStateT, Dict[str, DFAStateT]] = {}
13821388
back_map: Dict[str, Set[str]] = {"": set()}

0 commit comments

Comments
 (0)