# Source: Regular-Expresso/regexp.py @ revision c532c271f407 (branch: Laman)
# Commit message: "refactoring: rules as a vector instead of a hashmap"
import itertools
import math
from abc import ABC, abstractmethod
from collections import deque


class ParsingError(Exception):
	"""Raised when a regexp pattern cannot be parsed (unbalanced parentheses, dangling '*', empty alternative variant)."""


class Token(ABC):
	"""Abstract base class for regexp syntax-tree nodes.

	Positions yielded by the listing methods are indices of literal symbols
	within the original pattern string.

	Note: the class now derives from ABC so that @abstractmethod is actually
	enforced; previously the decorators were inert and Token was instantiable.
	"""

	# Whether this token can match the empty string.
	is_skippable = False

	@abstractmethod
	def list_first(self):
		"""Yield positions that can begin a match of this token."""

	@abstractmethod
	def list_last(self):
		"""Yield positions that can end a match of this token."""

	@abstractmethod
	def list_neighbours(self):
		"""Yield (x, y) pairs of positions that can occur adjacently in a match."""


class Lambda(Token):
	"""The empty-string token (written '_' in a pattern)."""

	# Matches the empty string, so it can always be skipped over.
	is_skippable = True

	def list_first(self):
		yield from ()

	def list_last(self):
		yield from ()

	def list_neighbours(self):
		yield from ()


class Symbol(Token):
	"""A single literal character, identified by its position in the pattern."""

	def __init__(self, position, value):
		# position: index of this character within the full pattern string
		# value: the character itself
		self.position = position
		self.value = value

	def list_first(self):
		yield self.position

	def list_last(self):
		yield self.position

	def list_neighbours(self):
		# A lone symbol contains no adjacent position pairs.
		yield from ()

	def __str__(self):
		return self.value


class Asterisk(Token):
	"""The Kleene star: zero or more repetitions of its content token."""

	# Zero repetitions are allowed, so the star is always skippable.
	is_skippable = True

	def __init__(self, content: Token):
		self.content = content

	def list_first(self):
		return self.content.list_first()

	def list_last(self):
		return self.content.list_last()

	def list_neighbours(self):
		yield from self.content.list_neighbours()
		# Because of repetition, any ending position of the content may be
		# followed by any of its starting positions.
		for last in self.list_last():
			for first in self.list_first():
				yield (last, first)

	def __str__(self):
		return f"{self.content}*"


class Alternative(Token):
	"""A union of variants separated by AlternativeSeparator tokens ('|' or '+')."""

	def __init__(self, content: list):
		self.variants = []
		subsequence = []

		# Split the token stream on separators; each non-empty segment becomes
		# a Chain variant. A sentinel separator appended at the end flushes the
		# final segment through the same code path.
		for token in [*content, AlternativeSeparator()]:
			if isinstance(token, AlternativeSeparator):
				if not subsequence:
					raise ParsingError("Found an empty Alternative variant.")
				self.variants.append(Chain(subsequence))
				subsequence = []
			else:
				subsequence.append(token)

	def list_first(self):
		for variant in self.variants:
			yield from variant.list_first()

	def list_last(self):
		for variant in self.variants:
			yield from variant.list_last()

	def list_neighbours(self):
		for variant in self.variants:
			yield from variant.list_neighbours()

	@property
	def is_skippable(self):
		# The alternative matches the empty string iff any variant does.
		return any(variant.is_skippable for variant in self.variants)

class AlternativeSeparator:
	"""Marker emitted by the parser for '|' and '+'; consumed by Alternative."""

class Chain(Token):
	"""A concatenation of tokens, matched one after another."""

	def __init__(self, content: list):
		self.content = content

	def list_first(self):
		# First positions of every leading skippable token, plus those of the
		# first non-skippable one, which ends the possibilities.
		for item in self.content:
			yield from item.list_first()
			if not item.is_skippable:
				return

	def list_last(self):
		# Mirror image of list_first, scanning from the right.
		for item in reversed(self.content):
			yield from item.list_last()
			if not item.is_skippable:
				return

	def list_neighbours(self):
		predecessors = []
		for item in self.content:
			# Every last position of a possible predecessor can be followed by
			# every first position of the current token.
			for pred in predecessors:
				for last in pred.list_last():
					for first in item.list_first():
						yield (last, first)
			yield from item.list_neighbours()

			# A skippable token can be jumped over, so earlier predecessors
			# remain reachable; a mandatory token becomes the sole predecessor.
			if item.is_skippable:
				predecessors.append(item)
			else:
				predecessors = [item]

	@property
	def is_skippable(self):
		# The chain matches the empty string iff all members do.
		return all(item.is_skippable for item in self.content)

	def __str__(self):
		return "(" + "".join(map(str, self.content)) + ")"


def find_closing_parenthesis(pattern, k):
	"""Return the index of the ')' matching the '(' at position k.

	:param pattern: the pattern string being scanned
	:param k: index of the opening parenthesis
	:raises ParsingError: if no matching closing parenthesis exists
	"""
	depth = 0

	for i in range(k, len(pattern)):
		ch = pattern[i]
		if ch == "(":
			depth += 1
		elif ch == ")":
			depth -= 1
		# Depth returns to zero exactly at the matching closer.
		if depth == 0:
			return i

	raise ParsingError(f'A closing parenthesis not found. Pattern: "{pattern}", position: {k}')


def parse(pattern, offset=0):
	"""Parse a regexp pattern into a token tree.

	:param pattern: the (sub)pattern to parse
	:param offset: index of pattern[0] within the full original pattern, so
		that Symbol positions always refer to the original string
	:return: a Chain, or an Alternative if a top-level '|' / '+' is present
	:raises ParsingError: on unbalanced parentheses or a dangling '*'
	"""
	res = []
	is_alternative = False

	i = 0
	while i < len(pattern):
		c = pattern[i]
		if c == "(":
			# Recurse into the parenthesised group; the group becomes one token.
			j = find_closing_parenthesis(pattern, i)
			inner_content = parse(pattern[i+1:j], offset+i+1)
			res.append(inner_content)
			i = j+1
		elif c == "*":
			# The star wraps the immediately preceding token.
			try:
				token = res.pop()
			except IndexError as e:
				# Chain the cause so the original IndexError is preserved.
				raise ParsingError(f'The asterisk operator is missing an argument. Pattern: "{pattern}", position {i}') from e
			res.append(Asterisk(token))
			i += 1
		elif c == ")":
			raise ParsingError(f'An opening parenthesis not found. Pattern: "{pattern}", position: {i}')
		elif c == "|" or c == "+":
			# Separators are resolved by the Alternative constructor below.
			is_alternative = True
			res.append(AlternativeSeparator())
			i += 1
		elif c == "_":
			res.append(Lambda())
			i += 1
		else:
			res.append(Symbol(i+offset, c))
			i += 1

	if is_alternative:
		return Alternative(res)
	else:
		return Chain(res)


class Regexp:
	"""Nondeterministic finite automaton built from a regexp pattern.

	States are the positions of literal symbols in the pattern; -1 is the
	initial state. ``rules`` maps (state, character) to a set of successor
	states.
	"""

	def __init__(self, pattern):
		"""Parse the pattern and construct the NFA transition rules.

		:param pattern: a regular expression string
		:raises ParsingError: if the pattern is malformed
		"""
		r = parse(pattern)
		rules = dict()
		alphabet = set()

		# Transitions out of the initial state, to every possible first symbol.
		for i in r.list_first():
			c = pattern[i]
			alphabet.add(c)
			rules.setdefault((-1, c), set()).add(i)

		# Transitions between adjacent symbol positions.
		for (i, j) in r.list_neighbours():
			c = pattern[j]
			alphabet.add(c)
			rules.setdefault((i, c), set()).add(j)

		end_states = set(r.list_last())
		if r.is_skippable:
			# The pattern accepts the empty string, so the start state accepts.
			end_states.add(-1)

		self.rules = rules
		self.end_states = end_states
		self.alphabet = alphabet

	def match(self, s):
		"""Return True iff the whole string ``s`` matches the pattern."""
		current = {-1}

		for c in s:
			new_state = set()
			for st in current:
				new_state.update(self.rules.get((st, c), ()))
			current = new_state

		return any(st in self.end_states for st in current)

	def determinize(self):
		"""Convert the NFA into a DFA via the powerset construction.

		:return: (compact_rules, end_states, alphabet_index) where
			compact_rules is a flat row-major transition table with -1 for a
			missing transition, end_states is the set of accepting DFA states
			and alphabet_index maps each character to its column index.
		"""
		alphabet_index = {c: i for (i, c) in enumerate(sorted(self.alphabet))}
		n = len(alphabet_index)
		compact_rules = [-1] * n
		end_states = {0} if -1 in self.end_states else set()

		# Each DFA state is a sorted tuple of NFA states; (-1,) is the start.
		index = {(-1,): 0}
		stack = [(-1,)]
		while stack:
			multistate = stack.pop()
			new_rules = dict()

			# Union the transitions of all member NFA states, grouped by symbol.
			for ((st, c), target) in self.rules.items():
				if st in multistate:
					new_rules.setdefault(c, set()).update(target)

			for (c, target_set) in new_rules.items():
				target_tup = tuple(sorted(target_set))
				if target_tup not in index:
					# A newly discovered DFA state: allocate its table row.
					index[target_tup] = len(index)
					compact_rules.extend([-1] * n)
					stack.append(target_tup)
				compact_rules[index[multistate]*n + alphabet_index[c]] = index[target_tup]
				if any(st in self.end_states for st in target_set):
					end_states.add(index[target_tup])

		return (compact_rules, end_states, alphabet_index)


class RegexpDFA:
	"""Deterministic finite automaton over a fixed alphabet.

	The transition table ``rules`` is a flat, row-major list: state ``st``
	owns ``rules[st*n:(st+1)*n]`` where ``n = len(alphabet_index)``, and -1
	marks a missing transition (reject). State 0 is the initial state.
	"""

	def __init__(self, rules, end_states, alphabet_index):
		# rules: flat transition table, one row of n entries per state
		# end_states: set of accepting state numbers
		# alphabet_index: character -> column offset within a row
		self.rules = rules
		self.end_states = end_states
		self.alphabet_index = alphabet_index

	@classmethod
	def create(cls, pattern):
		"""Build a (non-minimized) DFA from a regexp pattern string."""
		r = Regexp(pattern)
		(rules, end_states, alphabet_index) = r.determinize()

		return cls(rules, end_states, alphabet_index)

	def match(self, s):
		"""Return True iff the whole string ``s`` is accepted by the automaton."""
		st = 0
		n = len(self.alphabet_index)

		for c in s:
			# A character outside the alphabet, or the -1 sink, can never match.
			if c not in self.alphabet_index or st < 0:
				return False
			key = (st*n + self.alphabet_index[c])
			st = self.rules[key]

		return st in self.end_states

	def reduce(self):
		"""Return an equivalent DFA with all pairwise-equivalent states merged."""
		equivalents = self._find_equivalent_states()
		(rules, end_states) = self._collapse_states(equivalents)

		return RegexpDFA(rules, end_states, self.alphabet_index)

	def normalize(self):
		"""Renumber states in breadth-first discovery order from state 0.

		Isomorphic automata normalize to identical tables.
		NOTE(review): looks like every end state is assumed reachable from
		state 0; an unreachable end state would raise KeyError below — confirm
		that callers run reduce() (or otherwise prune) first.
		"""
		n = len(self.alphabet_index)
		# -1 stays -1; len(index)-1 below compensates for this sentinel entry.
		index = {-1: -1, 0: 0}
		queue = deque([0])

		rules = []

		while queue:
			si = queue.popleft()
			row = self.rules[si*n:(si+1)*n]
			for sj in row:
				if sj not in index:
					index[sj] = len(index)-1
					queue.append(sj)
			# Emit the row with all targets translated to the new numbering.
			rules.extend(index[sj] for sj in row)

		end_states = {index[si] for si in self.end_states}

		return RegexpDFA(rules, end_states, self.alphabet_index)

	def _find_equivalent_states(self):
		"""Return the set of (s1, s2) pairs (s1 < s2) of equivalent states.

		Starts from all pairs that agree on acceptance and repeatedly removes
		any pair whose transitions on some symbol lead to a distinguishable
		pair, until a fixed point is reached.
		"""
		n = len(self.alphabet_index)
		state_list = list(range(len(self.rules) // n))
		# Initial candidates: both accepting, or both rejecting.
		equivalents = {(s1, s2) for (i, s1) in enumerate(state_list) for s2 in state_list[i+1:] if (s1 in self.end_states) == (s2 in self.end_states)}

		ctrl = True
		while ctrl:
			ctrl = False
			for (s1, s2) in equivalents.copy():
				for ci in range(n):
					t1 = self.rules[s1*n + ci]
					t2 = self.rules[s2*n + ci]
					# Targets must be equal or themselves a surviving candidate
					# pair; pairs are stored with the smaller state first.
					key = (min(t1, t2), max(t1, t2))
					if t1 != t2 and key not in equivalents:
						equivalents.remove((s1, s2))
						ctrl = True
						break

		return equivalents

	def _collapse_states(self, equivalents):
		"""Merge equivalent states and compact the transition table.

		:param equivalents: (s1, s2) pairs from _find_equivalent_states
		:return: (rules, end_states) with redundant rows dropped and the
			surviving states renumbered to be contiguous
		"""
		n = len(self.alphabet_index)
		rules = []

		# Map each redundant state to the smallest member of its class.
		eq_mapping = dict()
		for (s1, s2) in equivalents:
			eq_mapping[s2] = min(s1, eq_mapping.get(s2, math.inf))

		# Renumber survivors to close the gaps left by dropped rows.
		discard_mapping = {-1: -1}
		discard_count = 0

		for i in range(0, len(self.rules), n):
			si = i//n
			if si in eq_mapping:
				discard_count += 1
				continue
			discard_mapping[si] = si - discard_count
			rules.extend(map(lambda st: eq_mapping.get(st, st), self.rules[i:i+n]))

		rules = [discard_mapping[st] for st in rules]
		end_states = {discard_mapping[eq_mapping.get(st, st)] for st in self.end_states}

		return (rules, end_states)


if __name__ == "__main__":
	# Smoke test: build, minimize and normalize a DFA for each pattern, then
	# run it against a fixed set of sample strings.
	samples = ["", "a", "ab", "aabb", "abab", "abcd", "abcbcdbcd"]
	patterns = [
		"a(b|c)",
		"a*b*",
		"(ab)*",
		"a((bc)*d)*",
		"(a|b)*a(a|b)(a|b)(a|b)",
		"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz",
	]
	for pattern in patterns:
		print("#", pattern)
		try:
			automaton = RegexpDFA.create(pattern).reduce().normalize()
		except ParsingError as exc:
			print("Failed to parse the regexp:")
			print(exc)
			continue
		for sample in samples:
			print(sample, automaton.match(sample))
		print()