diff --git a/regexp.py b/regexp.py
--- a/regexp.py
+++ b/regexp.py
@@ -318,8 +318,8 @@ class RegexpDFA:
         return st in self.end_states
 
     def reduce(self):
-        equivalents = self._find_equivalent_states()
-        (rules, end_states) = self._collapse_states(equivalents)
+        partition = self._find_equivalent_states()
+        (rules, end_states) = self._collapse_states(partition)
 
         return RegexpDFA(rules, end_states, self.alphabet_index)
 
@@ -368,45 +368,70 @@ class RegexpDFA:
 
     def _find_equivalent_states(self):
         n = len(self.alphabet_index)
-        state_list = list(range(len(self.rules) // n))
-        equivalents = {(s1, s2) for (i, s1) in enumerate(state_list) for s2 in state_list[i+1:] if (s1 in self.end_states) == (s2 in self.end_states)}
+        m = len(self.rules) // n
+        inverse_rules = [set() for i in range(m*n)]
+
+        for i in range(m):
+            for j in range(n):
+                target = self.rules[i*n + j]
+                inverse_rules[target*n + j].add(i)
+
+        set_bag = [self.end_states, set(range(m))-self.end_states]
+        res = {0, 1}
+        work = {0, 1}
+
+        while work:
+            key = work.pop()
+            target_set = set_bag[key]
+            for j in range(n):
+                source_set = set(itertools.chain.from_iterable(inverse_rules[t*n + j] for t in target_set))
+                for k in res.copy():
+                    part = set_bag[k]
+                    intersection = part & source_set
+                    diff = part - source_set
+                    if not intersection or not diff:
+                        continue
+                    res.remove(k)
+                    ki = len(set_bag)
+                    set_bag.append(intersection)
+                    res.add(ki)
+                    kd = len(set_bag)
+                    set_bag.append(diff)
+                    res.add(kd)
+                    if k in work:
+                        work.remove(k)
+                        work.add(ki)
+                        work.add(kd)
+                    elif len(intersection) < len(diff):
+                        work.add(ki)
+                    else:
+                        work.add(kd)
 
-        ctrl = True
-        while ctrl:
-            ctrl = False
-            for (s1, s2) in equivalents.copy():
-                for ci in range(n):
-                    t1 = self.rules[s1*n + ci]
-                    t2 = self.rules[s2*n + ci]
-                    key = (min(t1, t2), max(t1, t2))
-                    if t1 != t2 and key not in equivalents:
-                        equivalents.remove((s1, s2))
-                        ctrl = True
-                        break
-
-        return equivalents
+        return [set_bag[k] for k in res]
 
-    def _collapse_states(self, equivalents):
+    def _collapse_states(self, partition):
         n = len(self.alphabet_index)
         rules = []
 
         eq_mapping = dict()
-        for (s1, s2) in equivalents:
-            eq_mapping[s2] = min(s1, eq_mapping.get(s2, math.inf))
+        for eq_set in partition:
+            states = sorted(eq_set)
+            for st in states:
+                eq_mapping[st] = states[0]
 
         discard_mapping = dict()
         discard_count = 0
 
         for i in range(0, len(self.rules), n):
             si = i//n
-            if si in eq_mapping:
+            if eq_mapping[si] != si:
                 discard_count += 1
                 continue
             discard_mapping[si] = si - discard_count
-            rules.extend(map(lambda st: eq_mapping.get(st, st), self.rules[i:i+n]))
+            rules.extend(map(lambda st: eq_mapping[st], self.rules[i:i+n]))
 
         rules = [discard_mapping[st] for st in rules]
-        end_states = {discard_mapping[eq_mapping.get(st, st)] for st in self.end_states}
+        end_states = {discard_mapping[eq_mapping[st]] for st in self.end_states}
 
         return (rules, end_states)
 
@@ -452,7 +477,7 @@ class RegexpDFA:
 
 def test():
     tests = ["", "a", "ab", "aabb", "abab", "abcd", "abcbcdbcd"]
-    for pattern in ["a(b|c)", "a*b*", "(ab)*", "a((bc)*d)*", "(a|b)*a(a|b)(a|b)(a|b)", "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz"]:
+    for pattern in ["a(b|c)", "a*b*", "(ab)*", "a((bc)*d)*", "(a|b)*a(a|b)(a|b)(a|b)", "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"]:
         print("#", pattern)
         try:
             r = RegexpDFA.create(pattern).reduce().normalize()