Files @ f6c340624f3f

Location: Languedoc/languedoc.py - annotation

Laman: optimized the prediction code

import os
import re
import random
import itertools

random.seed(19181028)

CROSSVALIDATION_SOURCE_COUNT = 5
TEST_LENS = [8, 16, 32, 64]
TOP_NGRAM_COUNT = 6000


def preprocess(text):
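	"""Pad the text with spaces, collapse every run of non-letter characters
	(digits, punctuation, underscores, whitespace) into a single space and
	lowercase the result."""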
	text = re.sub(r"[\W\d_]+", " ", " "+text+" ")
	return text.lower()


def extract_ngram_freqs(text, k):
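	"""Return each k-gram's share of all k-grams in text; k-grams consisting
	only of whitespace are skipped.

	Example:
	>>> extract_ngram_freqs(" na na ", 2) == {" n": 1/3, "na": 1/3, "a ": 1/3}
	True
	"""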
	n = len(text)
	d = dict()

	for i in range(0, n-k+1):
		key = text[i:i+k]
		if key.isspace():
			continue

		d[key] = d.get(key, 0) + 1

	count = sum(d.values())

	return {key: val/count for (key, val) in d.items()}


def merge_ngram_freqs(freqs):
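	"""Average several n-gram frequency dicts into one, weighting each
	source dict equally.

	Example:
	>>> merge_ngram_freqs([{"a": 1.0}, {"a": 0.5, "b": 0.5}]) == {"a": 0.75, "b": 0.25}
	True
	"""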
	n = len(freqs)
	res = dict()

	for d in freqs:
		for (key, val) in d.items():
			res.setdefault(key, 0)
			res[key] += val/n

	return res


class Sample:
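	"""N-gram frequency profile of a single text, or of a whole language
	when several samples are combined via Sample.merge."""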
	def __init__(self, language="??", text=""):
		self.language = language
		self.frequencies = dict()
		self._ranked_ngrams = dict()

		if text:
			self._extract(text)

	def _extract(self, text):
		for k in range(1, 4):
			self.frequencies.update(extract_ngram_freqs(text, k))

	@staticmethod
	def merge(samples):
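		"""Merge samples of the same language into a single Sample by
		averaging their n-gram frequencies."""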
		assert len({x.language for x in samples}) == 1

		res = Sample(samples[0].language)
		res.frequencies = merge_ngram_freqs([x.frequencies for x in samples])

		return res

	@property
	def ranked_ngrams(self):
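		"""Mapping of the TOP_NGRAM_COUNT most frequent n-grams to their
		rank (0 = most frequent), computed lazily on first access."""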
		if not self._ranked_ngrams:
			ordered_ngrams = sorted(self.frequencies.items(), key=lambda kv: -kv[1])[:TOP_NGRAM_COUNT]
			self._ranked_ngrams = {key: rank for (rank, (key, _freq)) in enumerate(ordered_ngrams)}

		return self._ranked_ngrams

	def compare(self, other):
		"""take k most common
		use frequencies x order
		use letter, digrams, trigrams
		use absolute x square"""
		res = (
			sum(abs(v - other.ranked_ngrams.get(k, len(other.ranked_ngrams))) for (k, v) in self.ranked_ngrams.items())
			+ sum(abs(v - self.ranked_ngrams.get(k, len(self.ranked_ngrams))) for (k, v) in other.ranked_ngrams.items())
		)

		return res

	def print_overview(self):
		print(f"Sample({self.language}):")

		for k in range(1, 4):
			x = [
				(key, round(freq, 3))
				for (key, freq) in sorted(self.frequencies.items(), key=lambda kv: -kv[1])
				if len(key) == k
			]
			print("  ", x[:8], "...", x[-8:])

		print()


class SampleSet:
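	"""Training texts of one language and the Samples extracted from them."""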
	def __init__(self, language):
		self.language = language
		self.texts = []
		self.samples = []

	def add(self, text):
		self.texts.append(text)
		self.samples.append(Sample(self.language, text))

	def create_model(self):
		return Sample.merge(self.samples)

	def generate_tests(self, n):
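		"""Yield n (text, model) pairs, cycling over this set's texts;
		the model paired with each text is merged from all other samples
		of the language (leave-one-out)."""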
		for (i, (text, sample)) in enumerate(itertools.cycle(zip(self.texts, self.samples))):
			if i >= n:
				break

			yield (text, Sample.merge([x for x in self.samples if x is not sample]))


def cross_validate(sample_sets):
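	"""Cross-validate the models: for each language, hold out
	CROSSVALIDATION_SOURCE_COUNT source texts in turn (cycling if needed),
	classify ten random snippets of every length in TEST_LENS against a
	model rebuilt without that text, and return (accuracy, (score, max_score))."""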
	models = [s.create_model() for s in sample_sets]
	score = 0
	max_score = 0

	for s in sample_sets:
		for (test_text, partial_model) in s.generate_tests(CROSSVALIDATION_SOURCE_COUNT):
			real_lang = partial_model.language
			test_models = [partial_model] + [m for m in models if m.language != real_lang]

			for k in TEST_LENS:
				for i in range(10):
					j = random.randrange(0, len(test_text)-k)
					t = test_text[j:j+k]
					predicted_lang = identify(t, test_models)
					if predicted_lang == real_lang:
						score += 1
					else:
						print(real_lang, predicted_lang, t)
					max_score += 1

	return score / max_score, (score, max_score)


def identify(text, models):
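	"""Return the language of the model most similar to the given text,
	i.e. the one with the smallest Sample.compare distance (ties broken
	by language name)."""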
	sample = Sample(text=text)

	return min(map(lambda m: (m.compare(sample), m.language), models))[1]


DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
LANG_DIRS = sorted([x.path for x in os.scandir(DATA_DIR)])

if __name__ == "__main__":
	samples = []

	for d in LANG_DIRS:
		lang = os.path.basename(d)
		lang_samples = SampleSet(lang)
		samples.append(lang_samples)

		for file in sorted(os.scandir(d), key=lambda f: f.name):
			with open(file) as f:
				text = f.read()
				text = preprocess(text)
				print(f"{lang}: {file.name} ({len(text)})")

				lang_samples.add(text)

	print(cross_validate(samples))