Files @ f6c340624f3f
Branch filter:

Location: Languedoc/languedoc.py

Laman
optimized the prediction code
import os
import re
import random
import itertools

# cross_validate draws its test windows with `random`; a fixed seed keeps runs
# reproducible.
random.seed(19181028)

CROSSVALIDATION_SOURCE_COUNT = 5  # held-out texts tested per language (SampleSet.generate_tests)
TEST_LENS = [2 ** e for e in range(3, 7)]  # test window lengths in characters: 8, 16, 32, 64
TOP_NGRAM_COUNT = 6_000  # how many most frequent n-grams a ranked profile keeps


def preprocess(text):
	"""Normalize raw text for n-gram extraction.

	The text is padded with a space on each side. Then every run of non-word
	characters, digits and underscores collapses into a single space. The
	result is lower-cased, so it always starts and ends with one space.
	"""
	padded = " " + text + " "
	letters_only = re.sub(r"[\W\d_]+", " ", padded)
	return letters_only.lower()


def extract_ngram_freqs(text, k):
	"""Return the relative frequency of every k-character substring of `text`.

	Substrings made only of whitespace are not counted. The returned values
	sum to (approximately) 1. The dict is empty, with no division performed,
	when no substring qualifies. Keys are ordered by first occurrence.
	"""
	counts = dict()
	for start in range(len(text) - k + 1):
		gram = text[start:start + k]
		if not gram.isspace():
			counts[gram] = counts.get(gram, 0) + 1

	total = sum(counts.values())
	return {gram: hits / total for (gram, hits) in counts.items()}


def merge_ngram_freqs(freqs):
	"""Average a list of n-gram frequency dicts into a single dict.

	Each input contributes its value divided by len(freqs) to every key it
	has. Keys absent from an input get no contribution from it. Keys appear
	in the order they are first met.
	"""
	weight_count = len(freqs)
	merged = dict()
	for table in freqs:
		for gram, freq in table.items():
			merged[gram] = merged.get(gram, 0) + freq / weight_count
	return merged


class Sample:
	"""N-gram profile of a text, or of several merged texts, for one language.

	`frequencies` is one flat dict mapping 1-, 2- and 3-character n-grams to
	their relative frequency among n-grams of the same length.
	"""

	def __init__(self, language="??", text=""):
		self.language = language
		self.frequencies = dict()
		# Lazily filled cache for `ranked_ngrams`; empty means "not computed yet".
		self._ranked_ngrams = dict()

		if text:
			self._extract(text)

	def _extract(self, text):
		"""Add the unigram, bigram and trigram frequencies of `text`."""
		# Keys of different lengths never collide, so one flat dict holds all three.
		for k in range(1, 4):
			self.frequencies.update(extract_ngram_freqs(text, k))

	@staticmethod
	def merge(samples):
		"""Return a new Sample whose frequencies average those of `samples`.

		All samples must share a single language and the list must be
		non-empty. This is checked only by an assert, which `python -O`
		strips.
		"""
		assert len({x.language for x in samples}) == 1

		res = Sample(samples[0].language)
		res.frequencies = merge_ngram_freqs([x.frequencies for x in samples])

		return res

	@property
	def ranked_ngrams(self):
		"""Map the TOP_NGRAM_COUNT most frequent n-grams to their rank (0 = most frequent).

		The mapping is computed on first access and cached. Equal frequencies
		keep insertion order. The cache is not invalidated if `frequencies`
		is modified afterwards.
		"""
		if not self._ranked_ngrams:
			ordered_ngrams = sorted(self.frequencies.items(), key=lambda kv: -kv[1])[:TOP_NGRAM_COUNT]
			self._ranked_ngrams = {key: rank for (rank, (key, _freq)) in enumerate(ordered_ngrams)}

		return self._ranked_ngrams

	def compare(self, other):
		"""Return the symmetric out-of-place distance between two rank profiles.

		For every n-gram ranked in either profile, add the absolute difference
		between its ranks in the two profiles. An n-gram missing from a
		profile counts as ranked just past that profile's last position.
		Lower values mean more similar samples.
		"""
		# Hoisted: the property and len() would otherwise be re-evaluated per n-gram.
		mine = self.ranked_ngrams
		theirs = other.ranked_ngrams
		mine_missing = len(mine)
		theirs_missing = len(theirs)

		forward = sum(abs(rank - theirs.get(key, theirs_missing)) for (key, rank) in mine.items())
		backward = sum(abs(rank - mine.get(key, mine_missing)) for (key, rank) in theirs.items())
		return forward + backward

	def print_overview(self):
		"""Print the 8 most and 8 least frequent n-grams for each n-gram length."""
		print(f"Sample({self.language}):")

		# `frequencies` is flat, so split it back into one table per n-gram length.
		by_len = dict()
		for (key, freq) in self.frequencies.items():
			by_len.setdefault(len(key), dict())[key] = freq

		for k in sorted(by_len):
			x = [
				(key, round(v, 3))
				for (key, v) in sorted(by_len[k].items(), key=lambda kv: -kv[1])
			]
			print("  ", x[:8], "...", x[-8:])

		print()


class SampleSet:
	"""The texts of one language, each kept alongside its n-gram Sample."""

	def __init__(self, language):
		self.language = language
		self.texts = []  # texts in insertion order, parallel to self.samples
		self.samples = []  # one Sample per entry of self.texts

	def add(self, text):
		"""Store `text` together with its Sample labelled with this set's language."""
		self.texts.append(text)
		self.samples.append(Sample(self.language, text))

	def create_model(self):
		"""Return one Sample merging the frequencies of every stored text."""
		return Sample.merge(self.samples)

	def generate_tests(self, n):
		"""Yield up to `n` leave-one-out cases as (held_out_text, model_of_the_others).

		Stored texts are cycled in insertion order, so they repeat when `n`
		exceeds their count. Nothing is yielded when the set is empty. With a
		single stored text the "others" list is empty, and Sample.merge's
		assert rejects it.
		"""
		produced = 0
		for (held_out_text, held_out_sample) in itertools.cycle(zip(self.texts, self.samples)):
			if produced >= n:
				return
			others = [s for s in self.samples if s is not held_out_sample]
			yield (held_out_text, Sample.merge(others))
			produced += 1


def cross_validate(sample_sets, *, windows_per_len=10):
	"""Estimate identification accuracy by leave-one-out cross-validation.

	For each language, CROSSVALIDATION_SOURCE_COUNT held-out texts come from
	SampleSet.generate_tests. From each held-out text, `windows_per_len`
	random substrings of every length in TEST_LENS are classified with
	`identify`. The candidates are the partial model of the text's own
	language (built without that text) plus the full models of all other
	languages. Each misclassification is printed.

	Args:
		sample_sets: sequence of SampleSet, one per language. It is iterated
			twice, so it must not be a one-shot iterator.
		windows_per_len: random windows tested per length and held-out text.

	Returns:
		(accuracy, (correct, total)). Accuracy is 0.0 when no window could
		be tested, e.g. with no sample sets.
	"""
	models = [s.create_model() for s in sample_sets]
	score = 0
	max_score = 0

	for s in sample_sets:
		for (test_text, partial_model) in s.generate_tests(CROSSVALIDATION_SOURCE_COUNT):
			real_lang = partial_model.language
			# The held-out text's language is represented only by the model that never saw it.
			test_models = [partial_model] + [m for m in models if m.language != real_lang]

			for k in TEST_LENS:
				# Valid window starts are 0..len-k inclusive (randint includes its upper
				# bound). A text shorter than k has no window of that length.
				last_start = len(test_text) - k
				if last_start < 0:
					continue
				for _ in range(windows_per_len):
					j = random.randint(0, last_start)
					t = test_text[j:j+k]
					predicted_lang = identify(t, test_models)
					if predicted_lang == real_lang:
						score += 1
					else:
						print(real_lang, predicted_lang, t)
					max_score += 1

	accuracy = score / max_score if max_score else 0.0
	return accuracy, (score, max_score)


def identify(text, models):
	"""Return the language of the model closest to `text`.

	Closeness is Sample.compare (lower is closer). Equal distances go to the
	lexicographically smallest language name. An empty `models` raises
	ValueError from min().
	"""
	probe = Sample(text=text)
	best = min(models, key=lambda model: (model.compare(probe), model.language))
	return best.language


# Training corpus: one sub-directory per language under ./data next to this file.
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
# Only sub-directories count as languages; stray files (READMEs, .DS_Store)
# would otherwise be scanned as directories later and crash. A missing data
# directory yields no languages instead of making this module unimportable.
if os.path.isdir(DATA_DIR):
	LANG_DIRS = sorted(entry.path for entry in os.scandir(DATA_DIR) if entry.is_dir())
else:
	LANG_DIRS = []

if __name__ == "__main__":
	# Load every file of every language directory into a SampleSet (the
	# directory name is the language label), then report cross-validated accuracy.
	samples = []

	for d in LANG_DIRS:
		lang = os.path.basename(d)
		lang_samples = SampleSet(lang)
		samples.append(lang_samples)

		for file in sorted(os.scandir(d), key=lambda f: f.name):
			if not file.is_file():
				continue  # nested directories are not training texts
			# NOTE(review): assumes the corpus is UTF-8; without an explicit encoding
			# the result would depend on the platform's locale.
			with open(file, encoding="utf-8") as f:
				text = preprocess(f.read())
			print(f"{lang}: {file.name} ({len(text)})")

			lang_samples.add(text)

	print(cross_validate(samples))