Files
@ 4ea2a5eb6cf4
Branch filter:
Location: Languedoc/src/languedoc/train.py
4ea2a5eb6cf4
4.1 KiB
text/x-python
merge default
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 | import os
import random
import itertools
import json
import gzip
from typing import Iterable
from languedoc.predict import preprocess, identify, extract_ngram_counts, rank_ngram_counts, Sample
random.seed(19181028)
CROSSVALIDATION_SOURCE_COUNT = 5
TEST_LENS = [8, 16, 32, 64]
def merge_ngram_freqs(counts: list[dict[str, int]]) -> dict[str, float]:
    """Merge together ngram frequencies from multiple source texts.

    Each source is first normalized by its own total ngram count, and every
    source then contributes with equal weight (1/len(counts)) to the result.

    :param counts: one ngram->count table per source text
    :return: averaged ngram->relative-frequency table over all sources
    """
    source_count = len(counts)
    merged: dict[str, float] = {}

    for table in counts:
        total = sum(table.values())
        for ngram, occurrences in table.items():
            # Same division order as a per-source normalization followed by
            # averaging: (occurrences / total) / source_count.
            merged[ngram] = merged.get(ngram, 0) + occurrences / total / source_count

    return merged
class SampleSet:
    """Collects the source texts of one language and their ngram statistics."""

    def __init__(self, language):
        self.language = language
        self.texts = []   # raw source texts
        self.counts = []  # ngram counts per source text, parallel to self.texts

    def add(self, text: str):
        """Add another source text and its ngram counts."""
        self.counts.append(extract_ngram_counts(text))
        self.texts.append(text)

    def create_model(self) -> Sample:
        """Create a language model based on SampleSet data."""
        return Sample(self.language, rank_ngram_counts(merge_ngram_freqs(self.counts)))

    def generate_tests(self, n: int) -> Iterable[tuple[str, Sample]]:
        """Generate tests for crossvalidation.

        Yield source texts paired with models built from the *other* texts,
        cycling over the sources as necessary, so each model can be tested
        against the text that was held out of it.

        :param n: how many tests to generate
        :return: pairs of texts and models
        """
        pairs = itertools.cycle(zip(self.texts, self.counts))
        for text, held_out in itertools.islice(pairs, n):
            # Exclude the held-out source by identity, not equality.
            remaining = [c for c in self.counts if c is not held_out]
            ranked = rank_ngram_counts(merge_ngram_freqs(remaining))
            yield text, Sample(self.language, ranked)
def cross_validate(sample_sets: list[SampleSet]) -> tuple[float, int, int]:
    """Run leave-one-source-out crossvalidation on the samples.

    Iterate through the languages, for each generate `CROSSVALIDATION_SOURCE_COUNT`
    tests with one source text left out, then identify ten random excerpts for
    each length from `TEST_LENS`.

    :param sample_sets: sample sets of all target languages
    :return: ratio of correctly predicted samples, its absolute number and the theoretical maximum
    """
    models = [s.create_model() for s in sample_sets]
    score = 0
    max_score = 0
    print("# Crossvalidation:")

    for s in sample_sets:
        for (test_text, partial_model) in s.generate_tests(CROSSVALIDATION_SOURCE_COUNT):
            real_lang = partial_model.language
            # Swap in the partial model for the tested language so the
            # held-out text never contributes to the model scoring it.
            test_models = [partial_model] + [m for m in models if m.language != real_lang]

            for k in TEST_LENS:
                if len(test_text) < k:
                    # Text too short to cut an excerpt of this length;
                    # skip instead of crashing on an empty randrange.
                    continue
                for i in range(10):
                    # +1 so the last valid start index is included and a text
                    # of exactly k chars still yields a non-empty range.
                    j = random.randrange(0, len(test_text) - k + 1)
                    t = test_text[j:j+k]
                    predicted_lang = identify(t, test_models)

                    if predicted_lang == real_lang:
                        score += 1
                    else:
                        print(f"{real_lang} misidentified as {predicted_lang}: {t}")
                    max_score += 1

    # Guard against division by zero when no tests could be generated.
    if max_score == 0:
        return 0.0, 0, 0
    return score/max_score, score, max_score
def train(data_dir: str, model_path: str):
    """Run the training and create a prediction model.

    :param data_dir: path to the data directory, with one subdirectory for each language
        containing several text files as separate sources
    :param model_path: where to save the result language model as a .json.gz
    """
    samples = []
    lang_dirs = sorted([x.path for x in os.scandir(data_dir)])

    print("# Source texts:")
    for d in lang_dirs:
        lang = os.path.basename(d)
        lang_samples = SampleSet(lang)
        samples.append(lang_samples)

        # Sort by filename so training order (and thus the log) is deterministic.
        for file in sorted(os.scandir(d), key=lambda f: f.name):
            # Read explicitly as UTF-8 instead of the locale's default encoding.
            with open(file, encoding="utf-8") as f:
                text = preprocess(f.read())
                print(f"{lang}: {file.name} ({len(text)} chars)")
                lang_samples.add(text)

    # mtime=0 keeps the gzip output byte-for-byte reproducible across runs.
    with gzip.GzipFile(model_path, mode="wb", mtime=0) as f:
        s = json.dumps([sample_set.create_model().export() for sample_set in samples],
                       ensure_ascii=False, sort_keys=True)
        f.write(s.encode("utf-8"))

    print()
    (acc, success, count) = cross_validate(samples)
    print(f"Accuracy: {acc*100:.4f}%, ({success}/{count} tests during crossvalidation)")
# Default paths relative to this file: the training data directory sits two
# levels up, the trained model is written next to the package sources.
DATA_DIR = os.path.join(os.path.dirname(__file__), "../../data")
MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")

if __name__ == "__main__":
    train(DATA_DIR, MODEL_PATH)
|