Languedoc/languedoc.py
import os
import random
import itertools
import json
import gzip
from shared import preprocess, identify, extract_ngram_freqs, rank_ngram_freqs, Sample

random.seed(19181028)

CROSSVALIDATION_SOURCE_COUNT = 5
TEST_LENS = [8, 16, 32, 64]
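# Each language contributes CROSSVALIDATION_SOURCE_COUNT held-out test texts;
# from every test text, ten random snippets of each length in TEST_LENS are
# classified during cross-validation.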

def merge_ngram_freqs(freqs):
    """Average a list of n-gram frequency dicts into a single dict."""
    n = len(freqs)
    res = dict()

    for d in freqs:
        for (key, val) in d.items():
            res.setdefault(key, 0)
            res[key] += val/n

    return res
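
# For example (hypothetical values): merging {"ab": 0.4} with {"ab": 0.2, "cd": 0.2}
# yields {"ab": 0.3, "cd": 0.1}; every key is averaged over all input dicts,
# with a missing key counted as zero.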

class SampleSet:
    """Sample texts of one language, together with their n-gram frequencies."""

    def __init__(self, language):
        self.language = language
        self.texts = []
        self.frequencies = []

    def add(self, text):
        self.texts.append(text)
        self.frequencies.append(extract_ngram_freqs(text))

    def create_model(self):
        """Build a model of the language from all of its sample texts."""
        merged_frequencies = merge_ngram_freqs(self.frequencies)
        res = Sample(self.language, rank_ngram_freqs(merged_frequencies))
        return res

    def generate_tests(self, n):
        """Yield n (text, model) pairs, each model trained without its text."""
        for (i, (text, freqs)) in enumerate(itertools.cycle(zip(self.texts, self.frequencies))):
            if i >= n:
                break
            # Leave one out: merge the frequencies of all texts except the
            # one that is being tested.
            ranked_ngrams = rank_ngram_freqs(merge_ngram_freqs([f for f in self.frequencies if f is not freqs]))
            yield (text, Sample(self.language, ranked_ngrams))

def cross_validate(sample_sets):
    """Estimate accuracy by classifying random snippets of held-out texts."""
    models = [s.create_model() for s in sample_sets]
    score = 0
    max_score = 0

    for s in sample_sets:
        for (test_text, partial_model) in s.generate_tests(CROSSVALIDATION_SOURCE_COUNT):
            real_lang = partial_model.language
            # Swap in the leave-one-out model for the tested language and
            # keep the full models of all other languages.
            test_models = [partial_model] + [m for m in models if m.language != real_lang]

            for k in TEST_LENS:
                for i in range(10):
                    # Pick a random snippet of length k (this assumes every
                    # sample text is longer than the longest test length).
                    j = random.randrange(0, len(test_text)-k)
                    t = test_text[j:j+k]

                    predicted_lang = identify(t, test_models)
                    if predicted_lang == real_lang:
                        score += 1
                    else:
                        print(real_lang, predicted_lang, t)
                    max_score += 1

    return score / max_score, (score, max_score)
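
# cross_validate() returns (accuracy, (correct, total)), where accuracy is the
# fraction of snippets attributed to the correct language.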
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
LANG_DIRS = sorted([x.path for x in os.scandir(DATA_DIR)])
MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")

if __name__ == "__main__":
    samples = []

    for d in LANG_DIRS:
        lang = os.path.basename(d)
        lang_samples = SampleSet(lang)
        samples.append(lang_samples)

        for file in sorted(os.scandir(d), key=lambda f: f.name):
            with open(file, encoding="utf-8") as f:
                text = preprocess(f.read())
                print(f"{lang}: {file.name} ({len(text)})")
                lang_samples.add(text)

    # Save one ranked n-gram model per language as gzipped JSON.
    with gzip.open(MODEL_PATH, mode="wt", encoding="utf-8") as f:
        json.dump([sample_set.create_model().export() for sample_set in samples], f, ensure_ascii=False)

    print(cross_validate(samples))
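
# A possible counterpart for loading the saved models back, sketched here as a
# comment rather than as part of the script. It assumes Sample can be
# reconstructed directly from the structure Sample.export() emits (e.g. a
# (language, ranked_ngrams) pair defined in shared.py); adjust the call to the
# actual export format.
#
# def load_models(path=MODEL_PATH):
#     with gzip.open(path, mode="rt", encoding="utf-8") as f:
#         return [Sample(*exported) for exported in json.load(f)]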