diff --git a/src/languedoc/train.py b/src/languedoc/train.py
--- a/src/languedoc/train.py
+++ b/src/languedoc/train.py
@@ -3,6 +3,7 @@ import random
 import itertools
 import json
 import gzip
+from typing import Iterable
 
 from languedoc.predict import preprocess, identify, extract_ngram_counts, rank_ngram_counts, Sample
 
@@ -12,7 +13,8 @@ CROSSVALIDATION_SOURCE_COUNT = 5
 TEST_LENS = [8, 16, 32, 64]
 
 
-def merge_ngram_freqs(counts):
+def merge_ngram_freqs(counts: list[dict[str, int]]) -> dict[str, float]:
+    """Merge ngram frequencies from multiple source texts."""
     n = len(counts)
     res = dict()
 
@@ -31,16 +33,25 @@ class SampleSet:
         self.texts = []
         self.counts = []
 
-    def add(self, text):
+    def add(self, text: str):
+        """Add another source text and its ngram counts."""
         self.texts.append(text)
         self.counts.append(extract_ngram_counts(text))
 
-    def create_model(self):
+    def create_model(self) -> Sample:
+        """Create a language model based on SampleSet data."""
         merged_frequencies = merge_ngram_freqs(self.counts)
         res = Sample(self.language, rank_ngram_counts(merged_frequencies))
         return res
 
-    def generate_tests(self, n):
+    def generate_tests(self, n: int) -> Iterable[tuple[str, Sample]]:
+        """Generate tests for crossvalidation.
+
+        Yield pairs of a source text and a model built from the remaining texts,
+        cycling through the texts as necessary, so each model is tested on a text it was not built from.
+
+        :param n: how many tests to generate
+        :return: pairs of texts and models"""
         for (i, (text, freqs)) in enumerate(itertools.cycle(zip(self.texts, self.counts))):
             if i >= n:
                 break
@@ -49,7 +60,14 @@ class SampleSet:
             yield (text, Sample(self.language, ranked_ngrams))
 
 
-def cross_validate(sample_sets):
+def cross_validate(sample_sets: list[SampleSet]) -> tuple[float, int, int]:
+    """Run crossvalidation on the samples.
+
+    Iterate through the languages; for each, generate `CROSSVALIDATION_SOURCE_COUNT` tests
+    with one source text left out, then identify ten random excerpts for each length in `TEST_LENS`.
+
+    :param sample_sets: sample sets of all target languages
+    :return: ratio of correctly predicted samples, its absolute number, and the theoretical maximum"""
     models = [s.create_model() for s in sample_sets]
     score = 0
     max_score = 0
@@ -70,17 +88,19 @@ def cross_validate(sample_sets):
                     print(real_lang, predicted_lang, t)
                 max_score += 1
 
-    return score / max_score, (score, max_score)
+    return score / max_score, score, max_score
 
 
-DATA_DIR = os.path.join(os.path.dirname(__file__), "../../data")
-LANG_DIRS = sorted([x.path for x in os.scandir(DATA_DIR)])
-MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")
+def train(data_dir: str, model_path: str):
+    """Run the training and create a prediction model.
+
+    :param data_dir: path to the data directory, with one subdirectory for each language
+        containing several text files as separate sources
+    :param model_path: where to save the resulting language model as a .json.gz"""
+    samples = []
+    lang_dirs = sorted([x.path for x in os.scandir(data_dir)])
 
-if __name__ == "__main__":
-    samples = []
-
-    for d in LANG_DIRS:
+    for d in lang_dirs:
         lang = os.path.basename(d)
         lang_samples = SampleSet(lang)
         samples.append(lang_samples)
@@ -93,7 +113,14 @@ if __name__ == "__main__":
            lang_samples.add(text)
 
-    with gzip.open(MODEL_PATH, mode="wt", encoding="utf-8") as f:
+    with gzip.open(model_path, mode="wt", encoding="utf-8") as f:
         json.dump([sample_set.create_model().export() for sample_set in samples], f, ensure_ascii=False)
 
     print(cross_validate(samples))
+
+
+DATA_DIR = os.path.join(os.path.dirname(__file__), "../../data")
+MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")
+
+if __name__ == "__main__":
+    train(DATA_DIR, MODEL_PATH)
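
Usage note: with this change the training entry point is importable instead of living only under `if __name__ == "__main__"`. A minimal sketch of how the new `train()` might be called from another script; the paths here are hypothetical, not part of the patch:

    from languedoc.train import train

    # one subdirectory per language under the data directory,
    # each holding several source text files
    train(data_dir="data", model_path="models.json.gz")

Running the module as a script (e.g. `python -m languedoc.train`, assuming the package is importable) still trains on the bundled DATA_DIR and writes MODEL_PATH next to the module, as before.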