Languedoc/src/languedoc/train.py
import os
import random
import itertools
import json
import gzip
from typing import Iterable
from languedoc.predict import preprocess, identify, extract_ngram_counts, rank_ngram_counts, Sample

random.seed(19181028)

# how many leave-one-out tests to generate per language during cross-validation
CROSSVALIDATION_SOURCE_COUNT = 5
# lengths (in characters) of the random excerpts identified during cross-validation
TEST_LENS = [8, 16, 32, 64]


def merge_ngram_freqs(counts: list[dict[str, int]]) -> dict[str, float]:
    """Merge ngram counts from multiple source texts into averaged relative frequencies."""
    n = len(counts)
    res = dict()
    for d in counts:
        k = sum(d.values())
        for (key, val) in d.items():
            res.setdefault(key, 0)
            # normalize each source's counts to relative frequencies and average across sources
            res[key] += val/k/n
    return res
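
# Quick sanity check of the averaging above (toy counts, illustrative only):
#   merge_ngram_freqs([{"ab": 3, "ba": 1}, {"ab": 1}])
#   -> {"ab": 3/4/2 + 1/1/2, "ba": 1/4/2} == {"ab": 0.875, "ba": 0.125}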


class SampleSet:
    def __init__(self, language):
        self.language = language
        self.texts = []
        self.counts = []

    def add(self, text: str):
        """Add another source text and its ngram counts."""
        self.texts.append(text)
        self.counts.append(extract_ngram_counts(text))

    def create_model(self) -> Sample:
        """Create a language model based on SampleSet data."""
        merged_frequencies = merge_ngram_freqs(self.counts)
        res = Sample(self.language, rank_ngram_counts(merged_frequencies))
        return res

    def generate_tests(self, n: int) -> Iterable[tuple[str, Sample]]:
        """Generate tests for cross-validation.

        Yield source texts paired with models built from the remaining texts,
        cycling over the sources as necessary, so that each text can be tested
        against a model that has never seen it.

        :param n: how many tests to generate
        :return: pairs of texts and models"""
        for (i, (text, freqs)) in enumerate(itertools.cycle(zip(self.texts, self.counts))):
            if i >= n:
                break
            # build a leave-one-out model from all sources except the current one
            ranked_ngrams = rank_ngram_counts(merge_ngram_freqs([f for f in self.counts if f is not freqs]))
            yield (text, Sample(self.language, ranked_ngrams))
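
# Illustrative use of generate_tests (`sample_set` is a hypothetical SampleSet
# with several source texts added):
#   for text, model in sample_set.generate_tests(CROSSVALIDATION_SOURCE_COUNT):
#       ...  # `model` was built from every source text except `text`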


def cross_validate(sample_sets: list[SampleSet]) -> tuple[float, int, int]:
    """Run cross-validation on the samples.

    Iterate through the languages; for each, generate `CROSSVALIDATION_SOURCE_COUNT`
    leave-one-out tests, then identify ten random excerpts of each length from
    `TEST_LENS`.

    :param sample_sets: sample sets of all target languages
    :return: ratio of correctly predicted samples, its absolute number and the theoretical maximum"""
    models = [s.create_model() for s in sample_sets]
    score = 0
    max_score = 0
    for s in sample_sets:
        for (test_text, partial_model) in s.generate_tests(CROSSVALIDATION_SOURCE_COUNT):
            real_lang = partial_model.language
            # swap in the leave-one-out model for the tested language,
            # keeping the full models for all the other languages
            test_models = [partial_model] + [m for m in models if m.language != real_lang]
            for k in TEST_LENS:
                for i in range(10):
                    j = random.randrange(0, len(test_text)-k)
                    t = test_text[j:j+k]
                    predicted_lang = identify(t, test_models)
                    if predicted_lang == real_lang:
                        score += 1
                    else:
                        # log misidentified excerpts for inspection
                        print(real_lang, predicted_lang, t)
                    max_score += 1
    return score/max_score, score, max_score
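
# Example of the returned triple (numbers are illustrative only):
#   (0.95, 950, 1000) means 950 of 1000 excerpts were identified correctly.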


def train(data_dir: str, model_path: str):
    """Run the training and create a prediction model.

    :param data_dir: path to the data directory, with one subdirectory for each language
        containing several text files as separate sources
    :param model_path: where to save the resulting language model as a .json.gz"""
    samples = []
    lang_dirs = sorted([x.path for x in os.scandir(data_dir)])
    for d in lang_dirs:
        lang = os.path.basename(d)
        lang_samples = SampleSet(lang)
        samples.append(lang_samples)
        for file in sorted(os.scandir(d), key=lambda f: f.name):
            with open(file) as f:
                text = f.read()
            text = preprocess(text)
            print(f"{lang}: {file.name} ({len(text)})")
            lang_samples.add(text)

    # mtime=0 and sort_keys=True keep the gzipped model byte-for-byte reproducible
    with gzip.GzipFile(model_path, mode="wb", mtime=0) as f:
        s = json.dumps([sample_set.create_model().export() for sample_set in samples], ensure_ascii=False, sort_keys=True)
        f.write(s.encode("utf-8"))

    print(cross_validate(samples))
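
# Expected layout of `data_dir` (directory and file names are illustrative):
#   data/
#       cs/
#           source1.txt
#           source2.txt
#       en/
#           source1.txt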


DATA_DIR = os.path.join(os.path.dirname(__file__), "../../data")
MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")

if __name__ == "__main__":
    train(DATA_DIR, MODEL_PATH)