Files
@ 252d3b1bca60
Branch filter:
Location: Languedoc/src/languedoc/train.py - annotation
252d3b1bca60
2.5 KiB
text/x-python
models file included in the package
d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 252d3b1bca60 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 d443541818b2 | import os
import random
import itertools
import json
import gzip
from predict import preprocess, identify, extract_ngram_freqs, rank_ngram_freqs, Sample
# Fixed seed so the random snippet offsets drawn in cross_validate are the same
# on every run.
random.seed(19181028)
# Number of leave-one-out test cases requested per language in cross_validate.
CROSSVALIDATION_SOURCE_COUNT = 5
# Snippet lengths (in characters) to classify: the powers of two from 8 to 64.
TEST_LENS = [2 ** power for power in range(3, 7)]
def merge_ngram_freqs(freqs):
    """Average several n-gram frequency dicts into a single dict.

    Every value is divided by the number of input dicts and summed per key, so a
    key absent from some dicts effectively counts as 0 there. Keys appear in the
    order they are first seen. An empty input list yields an empty dict.
    """
    count = len(freqs)
    merged = {}
    for table in freqs:
        for ngram, freq in table.items():
            merged[ngram] = merged.get(ngram, 0) + freq / count
    return merged
class SampleSet:
    """The texts of one language together with the n-gram frequencies of each.

    ``texts[i]`` and ``frequencies[i]`` always describe the same text; models are
    built either from all texts (``create_model``) or from all but one of them
    (``generate_tests``).
    """

    def __init__(self, language):
        # Label handed to every Sample this set produces.
        self.language = language
        # Parallel lists: frequencies[i] holds the n-gram frequencies of texts[i].
        self.texts = []
        self.frequencies = []

    def add(self, text):
        """Store ``text`` and its n-gram frequencies.

        The frequencies are computed before either list is modified, so if
        ``extract_ngram_freqs`` raises, ``texts`` and ``frequencies`` stay in step.
        """
        freqs = extract_ngram_freqs(text)
        self.texts.append(text)
        self.frequencies.append(freqs)

    def create_model(self):
        """Return a ``Sample`` built from the averaged frequencies of all texts."""
        merged_frequencies = merge_ngram_freqs(self.frequencies)
        return Sample(self.language, rank_ngram_freqs(merged_frequencies))

    def generate_tests(self, n):
        """Yield ``n`` leave-one-out test cases as ``(text, model)`` pairs.

        Each model is a ``Sample`` built from every text in the set except the
        yielded one (with a single text, it is built from no texts at all).
        Texts are taken in insertion order and reused cyclically when ``n``
        exceeds their number; nothing is yielded for an empty set or ``n <= 0``.
        """
        pairs = itertools.cycle(zip(self.texts, self.frequencies))
        # islice rejects negative stops; clamp so n <= 0 simply yields nothing.
        for text, freqs in itertools.islice(pairs, max(n, 0)):
            # Exclude by identity, not equality, so only the held-out text's own
            # entry is dropped even if another text has identical frequencies.
            others = [f for f in self.frequencies if f is not freqs]
            ranked_ngrams = rank_ngram_freqs(merge_ngram_freqs(others))
            yield (text, Sample(self.language, ranked_ngrams))
def cross_validate(sample_sets, *, samples_per_len=10):
    """Estimate identification accuracy with leave-one-out cross-validation.

    For each language, ``generate_tests(CROSSVALIDATION_SOURCE_COUNT)`` supplies
    test texts, each paired with a model that was built without that text. That
    model competes against the full models of all other languages. From every
    test text, ``samples_per_len`` random snippets of each length in TEST_LENS are
    classified with ``identify``. Misclassified snippets are printed as
    ``real_lang predicted_lang snippet``.

    Args:
        sample_sets: one SampleSet per language.
        samples_per_len: random snippets tried per test text and snippet length.

    Returns:
        ``(accuracy, (correct, total))``. Snippet lengths longer than a test text
        are skipped for that text. Accuracy is 0.0 when nothing could be tested
        (for example, when ``sample_sets`` is empty).
    """
    models = [s.create_model() for s in sample_sets]
    score = 0
    max_score = 0
    for s in sample_sets:
        for test_text, partial_model in s.generate_tests(CROSSVALIDATION_SOURCE_COUNT):
            real_lang = partial_model.language
            # Replace this language's full model with the one that never saw test_text.
            test_models = [partial_model] + [m for m in models if m.language != real_lang]
            for k in TEST_LENS:
                # Valid start offsets for a k-character snippet are 0 .. len - k.
                start_count = len(test_text) - k + 1
                if start_count <= 0:
                    continue  # text too short for a snippet of this length
                for _ in range(samples_per_len):
                    j = random.randrange(start_count)
                    t = test_text[j:j + k]
                    predicted_lang = identify(t, test_models)
                    if predicted_lang == real_lang:
                        score += 1
                    else:
                        print(real_lang, predicted_lang, t)
                    max_score += 1
    accuracy = score / max_score if max_score else 0.0
    return accuracy, (score, max_score)
# Training corpus root: one subdirectory per language, whose name is used as the
# language label (see the __main__ loop).
DATA_DIR = os.path.join(os.path.dirname(__file__), "../../data")
# Only directories count as languages; stray files in DATA_DIR (e.g. a README)
# would otherwise be passed to os.scandir() by the training loop and fail.
LANG_DIRS = sorted(entry.path for entry in os.scandir(DATA_DIR) if entry.is_dir())
# Gzipped JSON file with the exported models, stored next to this module.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")
if __name__ == "__main__":
    samples = []
    for d in LANG_DIRS:
        # The directory name is the language label stored in the models.
        lang = os.path.basename(d)
        lang_samples = SampleSet(lang)
        samples.append(lang_samples)
        # Regular files only, in name order, so training order is deterministic
        # (and with it the seeded sampling in cross_validate).
        files = sorted((e for e in os.scandir(d) if e.is_file()), key=lambda e: e.name)
        for file in files:
            # Texts span many languages: decode explicitly rather than with the
            # platform's locale encoding. Assumes the corpus files are UTF-8.
            with open(file, encoding="utf-8") as f:
                text = f.read()
            text = preprocess(text)
            print(f"{lang}: {file.name} ({len(text)})")
            lang_samples.add(text)

    # Export one model per language, each trained on all of that language's texts.
    with gzip.open(MODEL_PATH, mode="wt", encoding="utf-8") as f:
        json.dump([sample_set.create_model().export() for sample_set in samples], f, ensure_ascii=False)

    print(cross_validate(samples))
|