# Languedoc/src/languedoc/train.py
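"""Training for languedoc: build per-language ngram models from source texts,
export them as a gzip-compressed JSON file and evaluate them by crossvalidation."""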
import os
import random
import itertools
import json
import gzip
from typing import Iterable

from languedoc.predict import preprocess, identify, extract_ngram_counts, rank_ngram_counts, Sample

random.seed(19181028)

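# how many leave-one-out tests to generate per language during crossvalidation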
CROSSVALIDATION_SOURCE_COUNT = 5
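# lengths (in characters) of the random excerpts identified during crossvalidation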
TEST_LENS = [8, 16, 32, 64]


def merge_ngram_freqs(counts: list[dict[str, int]]) -> dict[str, float]:
	"""Merge together ngram frequencies from multiple source texts."""
	n = len(counts)
	res = dict()

	for d in counts:
		k = sum(d.values())
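		# val/k is the ngram's relative frequency within this source text;
		# dividing by n averages these frequencies evenly across all sources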
		for (key, val) in d.items():
			res.setdefault(key, 0)
			res[key] += val/k/n

	return res


class SampleSet:
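	"""Source texts and their ngram counts for a single language."""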
	def __init__(self, language):
		self.language = language
		self.texts = []
		self.counts = []

	def add(self, text: str):
		"""Add another source text and its ngram counts."""
		self.texts.append(text)
		self.counts.append(extract_ngram_counts(text))

	def create_model(self) -> Sample:
		"""Create a language model based on SampleSet data."""
		merged_frequencies = merge_ngram_freqs(self.counts)
		res = Sample(self.language, rank_ngram_counts(merged_frequencies))
		return res

	def generate_tests(self, n: int) -> Iterable[tuple[str, Sample]]:
		"""Generate tests for crossvalidation.

		Yield pairs of a held-out source text and a model built from the remaining texts,
		cycling over the sources as necessary, so each model can be tested on a text it has not seen.

		:param n: how many tests to generate
		:return: pairs of texts and models"""
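		# cycle over the (text, counts) pairs so that more tests than source texts can be generated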
		for (i, (text, freqs)) in enumerate(itertools.cycle(zip(self.texts, self.counts))):
			if i >= n:
				break

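			# leave-one-out: build the model from every source text except the one yielded as a test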
			ranked_ngrams = rank_ngram_counts(merge_ngram_freqs([f for f in self.counts if f is not freqs]))
			yield (text, Sample(self.language, ranked_ngrams))


def cross_validate(sample_sets: list[SampleSet]) -> tuple[float, int, int]:
	"""Run 10-fold crossvalidation on the samples.

	Iterate through the languages, for each generate `CROSSVALIDATION_SOURCE_COUNT` tests
	with one source text left out, then identify ten random excerpts for each length from `TEST_LENS`.

	:param sample_sets: sample sets of all target languages
	:return: ratio of correctly predicted samples, its absolute number and the theoretical maximum"""
	models = [s.create_model() for s in sample_sets]
	score = 0
	max_score = 0

	for s in sample_sets:
		for (test_text, partial_model) in s.generate_tests(CROSSVALIDATION_SOURCE_COUNT):
			real_lang = partial_model.language
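			# replace the full model of the true language with the leave-one-out model,
			# so the test text never appears in the model it is scored against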
			test_models = [partial_model] + [m for m in models if m.language != real_lang]

			for k in TEST_LENS:
				for i in range(10):
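					# draw a random excerpt of length k from the held-out text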
					j = random.randrange(0, len(test_text)-k)
					t = test_text[j:j+k]
					predicted_lang = identify(t, test_models)
					if predicted_lang == real_lang:
						score += 1
					else:
						print(real_lang, predicted_lang, t)
					max_score += 1

	return score/max_score, score, max_score


def train(data_dir: str, model_path: str):
	"""Run the training and create a prediction model.
	files
	:param data_dir: path to the data directory, with one subdirectory for each language
		containing several text files as separate sources.
	:param model_path: where to save the result language model as a .json.gz"""
	samples = []
	lang_dirs = sorted([x.path for x in os.scandir(data_dir) if x.is_dir()])

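	# each subdirectory of data_dir is one language; each file inside it is one source text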
	for d in lang_dirs:
		lang = os.path.basename(d)
		lang_samples = SampleSet(lang)
		samples.append(lang_samples)

		for file in sorted(os.scandir(d), key=lambda f: f.name):
			with open(file, encoding="utf-8") as f:
				text = f.read()
				text = preprocess(text)
				print(f"{lang}: {file.name} ({len(text)})")

				lang_samples.add(text)

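	# export one model per language and write them all as a gzip-compressed JSON list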
	with gzip.open(model_path, mode="wt", encoding="utf-8") as f:
		json.dump([sample_set.create_model().export() for sample_set in samples], f, ensure_ascii=False)

	print(cross_validate(samples))


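# default locations: the repository's data directory (one subdirectory per language)
# and the model file created next to this module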
DATA_DIR = os.path.join(os.path.dirname(__file__), "../../data")
MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")

if __name__ == "__main__":
	train(DATA_DIR, MODEL_PATH)