Files @ 2de09682747e
Branch filter:

Location: Languedoc/languedoc.py - annotation

Laman
crossvalidation handling a variable number of input texts
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
5ab4acb6f293
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
2de09682747e
6fce04d6aa8d
167aab0c3103
6fce04d6aa8d
1c7a7c3926e6
1c7a7c3926e6
1cae4ecc8978
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
5ab4acb6f293
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1c7a7c3926e6
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
5ab4acb6f293
5ab4acb6f293
5ab4acb6f293
5ab4acb6f293
167aab0c3103
167aab0c3103
5ab4acb6f293
5ab4acb6f293
5ab4acb6f293
167aab0c3103
167aab0c3103
167aab0c3103
5ab4acb6f293
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
1cae4ecc8978
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
2de09682747e
2de09682747e
2de09682747e
2de09682747e
2de09682747e
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
2de09682747e
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
4efb46769b28
4efb46769b28
4efb46769b28
4efb46769b28
4efb46769b28
4efb46769b28
4efb46769b28
4efb46769b28
4efb46769b28
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
1c7a7c3926e6
3980aeb455b0
1c7a7c3926e6
6fce04d6aa8d
1cae4ecc8978
1c7a7c3926e6
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
1c7a7c3926e6
3980aeb455b0
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
2de09682747e
1c7a7c3926e6
6fce04d6aa8d
6fce04d6aa8d
6fce04d6aa8d
import collections
import itertools
import os
import random
import re

# Fixed seed so the random test snippets (and thus the reported accuracy)
# are reproducible across runs.
random.seed(19181028)

# How many (text, partial model) test pairs are drawn per language during
# cross-validation; texts are cycled if a language has fewer source texts.
CROSSVALIDATION_SOURCE_COUNT = 5
# Lengths (in characters) of the random snippets used as identification tests.
TEST_LENS = [8, 16, 32, 64]
# Only this many of the most frequent trigrams take part in the
# rank-distance comparison; rarer trigrams get the maximum rank penalty.
TOP_TRIGRAM_COUNT = 6000


def preprocess(text):
	"""Normalize raw text for n-gram extraction.

	Pads the text with a space on each side, collapses every run of
	non-letter characters (punctuation, digits, underscores, whitespace)
	into a single space, and lowercases the result."""
	padded = " " + text + " "
	collapsed = re.sub(r"[\W\d_]+", " ", padded)
	return collapsed.lower()


def extract_ngram_freqs(text, k):
	"""Return the relative frequencies of all k-grams of `text`.

	Purely-whitespace k-grams are skipped (they carry no language
	information). For a non-empty result the values sum to 1; if `text`
	is shorter than `k` or contains only whitespace, an empty dict is
	returned.
	"""
	# Counter replaces the hand-rolled d.get(key, 0) + 1 accumulation.
	grams = (text[i:i+k] for i in range(len(text) - k + 1))
	counts = collections.Counter(g for g in grams if not g.isspace())

	total = sum(counts.values())

	# Only reached per-item, so an empty `counts` never divides by zero.
	return {gram: c/total for (gram, c) in counts.items()}


def merge_ngram_freqs(freqs):
	"""Average a sequence of frequency dicts into a single dict.

	Each key's value in the result is the mean of its values across all
	input dicts, with missing keys counting as zero."""
	total = len(freqs)
	merged = dict()

	for freq_dict in freqs:
		for (key, val) in freq_dict.items():
			merged[key] = merged.get(key, 0) + val/total

	return merged


class Sample:
	"""N-gram frequency profile (unigrams, bigrams, trigrams) of a text
	in a single language."""

	def __init__(self, language="??", text=""):
		self.language = language
		# frequencies[k-1] maps each k-gram to its relative frequency.
		self.frequencies = [dict(), dict(), dict()]

		if text:
			self._extract(text)

	def _extract(self, text):
		"""Fill the unigram, bigram and trigram tables from `text`."""
		self.frequencies = [extract_ngram_freqs(text, k) for k in (1, 2, 3)]

	@staticmethod
	def merge(samples):
		"""Combine several same-language samples into one averaged sample."""
		languages = {sample.language for sample in samples}
		assert len(languages) == 1

		merged = Sample(samples[0].language)
		merged.frequencies = [
			merge_ngram_freqs(freq_group)
			for freq_group in zip(*(sample.frequencies for sample in samples))
		]

		return merged

	def _trigram_ranks(self):
		"""Rank the TOP_TRIGRAM_COUNT most frequent trigrams from 0 up.

		Sorting is stable, so trigrams with equal frequency keep their
		insertion order — same ordering as a sort on the negated value."""
		by_frequency = sorted(self.frequencies[2].items(), key=lambda kv: kv[1], reverse=True)
		return {
			trigram: rank
			for (rank, (trigram, freq)) in enumerate(by_frequency[:TOP_TRIGRAM_COUNT])
		}

	def compare(self, other):
		"""Rank-distance between the trigram profiles of two samples.

		Each side ranks its TOP_TRIGRAM_COUNT most frequent trigrams by
		descending frequency; a trigram absent from the other side gets
		the maximum rank TOP_TRIGRAM_COUNT. The score is the symmetric
		sum of absolute rank differences — lower means more similar."""
		own_ranks = self._trigram_ranks()
		other_ranks = other._trigram_ranks()

		distance = 0
		for (trigram, rank) in own_ranks.items():
			distance += abs(rank - other_ranks.get(trigram, TOP_TRIGRAM_COUNT))
		for (trigram, rank) in other_ranks.items():
			distance += abs(rank - own_ranks.get(trigram, TOP_TRIGRAM_COUNT))

		return distance

	def print_overview(self):
		"""Print the most and least frequent n-grams of each order."""
		print(f"Sample({self.language}):")

		for freqs in self.frequencies:
			rounded = [
				(key, round(val, 3))
				for (key, val) in sorted(freqs.items(), key=lambda kv: -kv[1])
			]
			print("  ", rounded[:8], "...", rounded[-8:])

		print()


class SampleSet:
	"""All training texts (and their per-text samples) for one language."""

	def __init__(self, language):
		self.language = language
		self.texts = []    # raw (preprocessed) source texts
		self.samples = []  # Sample built from each text, same order

	def add(self, text):
		"""Register a new source text and build its Sample."""
		self.texts.append(text)
		self.samples.append(Sample(self.language, text))

	def create_model(self):
		"""Merge all per-text samples into one full language model."""
		return Sample.merge(self.samples)

	def generate_tests(self, n):
		"""Yield `n` pairs of (test text, model built from the other texts).

		Texts are cycled, so `n` may exceed the number of source texts;
		the paired model never includes the test text's own sample."""
		pairs = itertools.cycle(zip(self.texts, self.samples))
		for (i, (text, sample)) in enumerate(pairs):
			if i >= n:
				break

			held_out = [s for s in self.samples if s is not sample]
			yield (text, Sample.merge(held_out))


def cross_validate(sample_sets):
	"""Estimate identification accuracy by cross-validation.

	For every language, each text is in turn held out of its own model;
	random snippets of the lengths in TEST_LENS are drawn from it and
	identified against the reduced model plus all other languages' full
	models. Misidentified snippets are printed as they occur.

	Returns (accuracy, (correct, total)).
	"""
	models = [sample_set.create_model() for sample_set in sample_sets]
	correct = 0
	total = 0

	for sample_set in sample_sets:
		tests = sample_set.generate_tests(CROSSVALIDATION_SOURCE_COUNT)
		for (test_text, partial_model) in tests:
			real_lang = partial_model.language
			# Swap the true language's full model for the partial one,
			# which has never seen the test text.
			candidates = [partial_model]
			candidates += [m for m in models if m.language != real_lang]

			for snippet_len in TEST_LENS:
				for _ in range(10):
					start = random.randrange(0, len(test_text) - snippet_len)
					snippet = test_text[start:start + snippet_len]
					predicted_lang = identify(snippet, candidates)
					if predicted_lang == real_lang:
						correct += 1
					else:
						print(real_lang, predicted_lang, snippet)
					total += 1

	return correct / total, (correct, total)


def identify(text, models):
	"""Return the language whose model is most similar to `text`.

	Ties on the distance score are broken by language name (tuple
	comparison), matching the original behavior."""
	sample = Sample(text=text)
	scored = [(model.compare(sample), model.language) for model in models]

	return min(scored)[1]


# Per-language training data lives in data/<language>/<text files>.
# NOTE(review): os.scandir runs at import time, so importing this module
# requires the data directory to exist next to this file.
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
LANG_DIRS = sorted([x.path for x in os.scandir(DATA_DIR)])

if __name__ == "__main__":
	# Build one SampleSet per language directory, then report the
	# cross-validation accuracy over all of them.
	samples = []

	for lang_dir in LANG_DIRS:
		lang = os.path.basename(lang_dir)
		lang_samples = SampleSet(lang)
		samples.append(lang_samples)

		for file in sorted(os.scandir(lang_dir), key=lambda f: f.name):
			with open(file) as f:
				text = preprocess(f.read())
				print(f"{lang}: {file.name} ({len(text)})")
				lang_samples.add(text)

	print(cross_validate(samples))