# Languedoc/src/languedoc/train.py
import os
import random
import itertools
import json
import gzip
from typing import Iterable

from languedoc.predict import preprocess, identify, extract_ngram_counts, rank_ngram_counts, Sample

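# fixed seed so the random excerpts drawn during crossvalidation are reproducible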
random.seed(19181028)

CROSSVALIDATION_SOURCE_COUNT = 5
TEST_LENS = [8, 16, 32, 64]


def merge_ngram_freqs(counts: list[dict[str, int]]) -> dict[str, float]:
	"""Merge together ngram frequencies from multiple source texts."""
	n = len(counts)
	res = dict()

	for d in counts:
		k = sum(d.values())
		for (key, val) in d.items():
			res.setdefault(key, 0)
			res[key] += val/k/n

	return res
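
# A quick worked example of merge_ngram_freqs (hypothetical counts, for illustration only):
# each source's counts are first normalized by that source's total, then averaged, so
#	merge_ngram_freqs([{"a": 2, "b": 2}, {"a": 1, "c": 1}])
# returns {"a": 2/4/2 + 1/2/2 = 0.5, "b": 2/4/2 = 0.25, "c": 1/2/2 = 0.25}.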


class SampleSet:
	def __init__(self, language: str):
		self.language = language
		self.texts = []
		self.counts = []

	def add(self, text: str):
		"""Add another source text and its ngram counts."""
		self.texts.append(text)
		self.counts.append(extract_ngram_counts(text))

	def create_model(self) -> Sample:
		"""Create a language model based on SampleSet data."""
		merged_frequencies = merge_ngram_freqs(self.counts)
		res = Sample(self.language, rank_ngram_counts(merged_frequencies))
		return res

	def generate_tests(self, n: int) -> Iterable[tuple[str, Sample]]:
		"""Generate tests for crossvalidation.

		Yield source texts and the corresponding models built from the other texts, cycling as necessary.
		Therefore, one can test the models with the texts.

		:param n: how many tests to generate
		:return: pairs of texts and models"""
		for (i, (text, freqs)) in enumerate(itertools.cycle(zip(self.texts, self.counts))):
			if i >= n:
				break

			ranked_ngrams = rank_ngram_counts(merge_ngram_freqs([f for f in self.counts if f is not freqs]))
			yield (text, Sample(self.language, ranked_ngrams))
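
		# For illustration (hypothetical three-source set): with texts [t0, t1, t2],
		# generate_tests(4) yields (t0, model from t1+t2), (t1, model from t0+t2),
		# (t2, model from t0+t1) and then cycles back to (t0, model from t1+t2).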


def cross_validate(sample_sets: list[SampleSet]) -> tuple[float, int, int]:
	"""Run 10-fold crossvalidation on the samples.

	Iterate through the languages, for each generate `CROSSVALIDATION_SOURCE_COUNT` tests
	with one source text left out, then identify ten random excerpts for each length from `TEST_LENS`.

	:param sample_sets: sample sets of all target languages
	:return: ratio of correctly predicted samples, its absolute number and the theoretical maximum"""
	models = [s.create_model() for s in sample_sets]
	score = 0
	max_score = 0

	print("# Crossvalidation:")
	for s in sample_sets:
		for (test_text, partial_model) in s.generate_tests(CROSSVALIDATION_SOURCE_COUNT):
			real_lang = partial_model.language
			test_models = [partial_model] + [m for m in models if m.language != real_lang]

			for k in TEST_LENS:
				for i in range(10):
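					# draw a random excerpt of length k; assumes each source text is longer than max(TEST_LENS)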
					j = random.randrange(0, len(test_text)-k)
					t = test_text[j:j+k]
					predicted_lang = identify(t, test_models)
					if predicted_lang == real_lang:
						score += 1
					else:
						print(f"{real_lang} misidentified as {predicted_lang}: {t}")
					max_score += 1

	return score/max_score, score, max_score
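
# With the constants above, every language contributes
# CROSSVALIDATION_SOURCE_COUNT * len(TEST_LENS) * 10 = 5 * 4 * 10 = 200 excerpt tests,
# so max_score ends up as 200 times the number of languages.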


def train(data_dir: str, model_path: str):
	"""Run the training and create a prediction model.
	files
	:param data_dir: path to the data directory, with one subdirectory for each language
		containing several text files as separate sources.
	:param model_path: where to save the result language model as a .json.gz"""
	samples = []
	lang_dirs = sorted([x.path for x in os.scandir(data_dir)])

	print("# Source texts:")
	for d in lang_dirs:
		lang = os.path.basename(d)
		lang_samples = SampleSet(lang)
		samples.append(lang_samples)

		for file in sorted(os.scandir(d), key=lambda f: f.name):
			with open(file, encoding="utf-8") as f:
				text = f.read()
				text = preprocess(text)
				print(f"{lang}: {file.name} ({len(text)} chars)")

				lang_samples.add(text)

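	# mtime=0 fixes the timestamp in the gzip header, so repeated runs on the same
	# data produce byte-identical model files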
	with gzip.GzipFile(model_path, mode="wb", mtime=0) as f:
		s = json.dumps([sample_set.create_model().export() for sample_set in samples], ensure_ascii=False, sort_keys=True)
		f.write(s.encode("utf-8"))

	print()
	(acc, success, count) = cross_validate(samples)
	print(f"Accuracy: {acc*100:.4f}%, ({success}/{count} tests during crossvalidation)")


DATA_DIR = os.path.join(os.path.dirname(__file__), "../../data")
MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")
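
# Reading the exported model back (a sketch; the structure of each entry is whatever
# Sample.export in languedoc.predict produces):
#	with gzip.open(MODEL_PATH, mode="rt", encoding="utf-8") as f:
#		models = json.load(f)  # a list with one entry per language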

if __name__ == "__main__":
	train(DATA_DIR, MODEL_PATH)