Files @ d443541818b2
Branch filter:

Location: Languedoc/src/languedoc/train.py - annotation

Laman
changed the project layout
import os
import random
import itertools
import json
import gzip

from predict import preprocess, identify, extract_ngram_freqs, rank_ngram_freqs, Sample

random.seed(19181028)

CROSSVALIDATION_SOURCE_COUNT = 5
TEST_LENS = [8, 16, 32, 64]


def merge_ngram_freqs(freqs):
	"""Average a list of n-gram frequency dicts into a single dict.

	Every input dict contributes its values scaled by 1/len(freqs); an n-gram
	absent from some dict simply gets no contribution from it. Keys appear in
	the order they are first encountered. An empty list yields {}.
	"""
	weight = len(freqs)
	merged = {}

	for table in freqs:
		for ngram, count in table.items():
			merged[ngram] = merged.get(ngram, 0) + count / weight

	return merged


class SampleSet:
	"""Training texts for a single language together with their n-gram frequencies."""

	def __init__(self, language):
		self.language = language
		self.texts = []  # texts passed to add(), in insertion order
		self.frequencies = []  # extract_ngram_freqs(text) for each entry of self.texts

	def add(self, text):
		"""Store `text` and its extracted n-gram frequencies."""
		self.texts.append(text)
		self.frequencies.append(extract_ngram_freqs(text))

	def create_model(self):
		"""Return a Sample built from the averaged frequencies of all stored texts."""
		averaged = merge_ngram_freqs(self.frequencies)
		return Sample(self.language, rank_ngram_freqs(averaged))

	def generate_tests(self, n):
		"""Yield up to `n` (text, model) pairs for hold-one-out testing.

		The model paired with each text is built from the frequencies of every
		*other* stored text (excluded by identity). Texts are revisited
		cyclically if `n` exceeds their count; nothing is yielded when no texts
		have been added.
		"""
		produced = 0
		for text, held_out in itertools.cycle(zip(self.texts, self.frequencies)):
			if produced >= n:
				return
			remaining = [freq for freq in self.frequencies if freq is not held_out]
			model = Sample(self.language, rank_ngram_freqs(merge_ngram_freqs(remaining)))
			yield (text, model)
			produced += 1


def cross_validate(sample_sets, test_lens=None, trials=10, source_count=None):
	"""Estimate identification accuracy by testing held-out texts.

	For each sample set, `source_count` held-out texts (default
	CROSSVALIDATION_SOURCE_COUNT) are taken from SampleSet.generate_tests. Each
	one is paired with a model of its language built without that text. The
	other languages are represented by their full models. From every held-out
	text, `trials` random snippets of each length in `test_lens` (default
	TEST_LENS) are classified with identify(). Misclassified snippets are
	printed as "<real> <predicted> <snippet>".

	A text shorter than a snippet length is skipped for that length.

	Returns:
		(accuracy, (correct, total)). accuracy is 0.0 when no snippet could be
		tested (e.g. `sample_sets` is empty) instead of dividing by zero.
	"""
	if test_lens is None:
		test_lens = TEST_LENS
	if source_count is None:
		source_count = CROSSVALIDATION_SOURCE_COUNT

	models = [s.create_model() for s in sample_sets]
	score = 0
	max_score = 0

	for s in sample_sets:
		for (test_text, partial_model) in s.generate_tests(source_count):
			real_lang = partial_model.language
			test_models = [partial_model] + [m for m in models if m.language != real_lang]

			for k in test_lens:
				# Valid snippet starts are 0..len-k inclusive, i.e. len-k+1 choices.
				start_count = len(test_text) - k + 1
				if start_count <= 0:
					continue  # text too short for a snippet of this length
				for _ in range(trials):
					j = random.randrange(start_count)
					t = test_text[j:j+k]
					predicted_lang = identify(t, test_models)
					if predicted_lang == real_lang:
						score += 1
					else:
						print(real_lang, predicted_lang, t)
					max_score += 1

	accuracy = score / max_score if max_score else 0.0
	return accuracy, (score, max_score)


DATA_DIR = os.path.join(os.path.dirname(__file__), "../../data")


def _list_lang_dirs(data_dir):
	"""Return the sorted paths of the subdirectories of `data_dir`, one per language.

	Plain files in `data_dir` are ignored so that a stray non-directory entry is
	not treated as a language. The scandir handle is closed deterministically.
	"""
	with os.scandir(data_dir) as entries:
		return sorted(entry.path for entry in entries if entry.is_dir())


LANG_DIRS = _list_lang_dirs(DATA_DIR)
MODEL_PATH = os.path.join(os.path.dirname(__file__), "../../models.json.gz")

def main():
	"""Train one model per language directory, write them to MODEL_PATH, and print cross-validation results."""
	samples = []

	for d in LANG_DIRS:
		lang = os.path.basename(d)
		lang_samples = SampleSet(lang)
		samples.append(lang_samples)

		for file in sorted(os.scandir(d), key=lambda f: f.name):
			if not file.is_file():
				continue  # skip nested directories and other non-file entries
			# Read as UTF-8 explicitly: the platform default encoding may not be able
			# to decode multilingual text (models are written as UTF-8 below as well).
			with open(file, encoding="utf-8") as f:
				text = preprocess(f.read())
			print(f"{lang}: {file.name} ({len(text)})")

			lang_samples.add(text)

	with gzip.open(MODEL_PATH, mode="wt", encoding="utf-8") as f:
		json.dump([sample_set.create_model().export() for sample_set in samples], f, ensure_ascii=False)

	print(cross_validate(samples))


if __name__ == "__main__":
	main()