diff --git a/regexp.py b/regexp.py --- a/regexp.py +++ b/regexp.py @@ -206,6 +206,14 @@ def parse(pattern, offset=0): return Chain(res) +def print_dfa(dfa, label=""): + n = len(dfa.alphabet_index) + print(label) + for i in range(0, len(dfa.rules), n): + print(i//n, dfa.rules[i:i+n]) + print(dfa.end_states) + + class Regexp: def __init__(self, pattern): r = parse(pattern) @@ -340,7 +348,9 @@ class RegexpDFA: if self.rules == r.rules and self.end_states == r.end_states: return None - product = self._build_product_automaton(r) + r1 = self._expand_alphabet(r.alphabet_index) + r2 = r._expand_alphabet(self.alphabet_index) + product = r1._build_product_automaton(r2) n = len(product.alphabet_index) reverse_alphabet_index = {v: k for (k, v) in product.alphabet_index.items()} @@ -401,6 +411,28 @@ class RegexpDFA: return (rules, end_states) + def _expand_alphabet(self, alphabet_index): + if alphabet_index == self.alphabet_index: + return self + + n1 = len(self.alphabet_index) + m = len(self.rules) // n1 + + combined_alphabet = set(self.alphabet_index.keys()) | set(alphabet_index.keys()) + combined_index = {c: i for (i, c) in enumerate(sorted(combined_alphabet))} + conversion_index = {v: combined_index[k] for (k, v) in self.alphabet_index.items()} + n2 = len(combined_alphabet) + + rules = [] + for i in range(0, len(self.rules), n1): + row = ([m]*n2) + for (j, st) in enumerate(self.rules[i:i+n1]): + row[conversion_index[j]] = st + rules.extend(row) + rules.extend([m]*n2) + + return RegexpDFA(rules, self.end_states, combined_index).reduce().normalize() + def _build_product_automaton(self, r): n = len(self.alphabet_index) m = len(r.rules) // n