diff --git a/languedoc.py b/languedoc.py
--- a/languedoc.py
+++ b/languedoc.py
@@ -1,6 +1,8 @@
 import os
 import random
 import itertools
+import json
+import gzip
 
 from shared import preprocess, identify, extract_ngram_freqs, rank_ngram_freqs, Sample
 
@@ -72,6 +74,7 @@ def cross_validate(sample_sets):
 
 DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
 LANG_DIRS = sorted([x.path for x in os.scandir(DATA_DIR)])
+MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")
 
 if __name__ == "__main__":
     samples = []
@@ -89,4 +92,7 @@ if __name__ == "__main__":
 
             lang_samples.add(text)
 
+    with gzip.open(MODEL_PATH, mode="wt", encoding="utf-8") as f:
+        json.dump([sample_set.create_model().export() for sample_set in samples], f, ensure_ascii=False)
+
     print(cross_validate(samples))
diff --git a/shared.py b/shared.py
--- a/shared.py
+++ b/shared.py
@@ -1,7 +1,11 @@
+import os
 import re
 import itertools
+import json
+import gzip
 
 TOP_NGRAM_COUNT = 3000
+MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")
 
 
 def preprocess(text):
@@ -65,10 +69,6 @@ class Sample:
         }
 
     def compare(self, other):
-        """take k most common
-        use frequencies x order
-        use letter, digrams, trigrams
-        use absolute x square"""
         m = len(other.ranked_ngrams)
 
         res = sum(
@@ -79,7 +79,23 @@ class Sample:
         return res
 
 
-def identify(text, models):
+def load_models(model_path):
+    """Load the gzip-compressed JSON list at model_path as Sample objects."""
+    with gzip.open(model_path, mode="rt", encoding="utf-8") as f:
+        return [Sample.load(obj) for obj in json.load(f)]
+
+
+def identify(text, models=None):
+    """Return the language of the model with the lowest compare() score.
+
+    When models is None or empty, the models stored at MODEL_PATH are
+    loaded and used instead.
+    """
+    # None replaces a shared mutable [] default; `not models` keeps an
+    # explicitly passed empty list falling back to the stored models too.
+    if not models:
+        models = load_models(MODEL_PATH)
+
     sample = Sample.extract(text)
 
     return sorted(models, key=lambda m: sample.compare(m))[0].language