"""Build n-gram language-identification models from the corpus in
``../../data`` and evaluate them with leave-one-out cross-validation."""

import gzip
import itertools
import json
import os
import random

from .predict import (
    Sample,
    extract_ngram_freqs,
    identify,
    preprocess,
    rank_ngram_freqs,
)

# Fixed seed so the randomly sampled test snippets are reproducible.
random.seed(19181028)

# Number of held-out texts tested per language during cross-validation.
CROSSVALIDATION_SOURCE_COUNT = 5
# Lengths (in characters) of the random snippets used as test inputs.
TEST_LENS = [8, 16, 32, 64]


def merge_ngram_freqs(freqs):
    """Average a list of n-gram frequency dicts into a single dict."""
    n = len(freqs)
    res = dict()
    for d in freqs:
        for (key, val) in d.items():
            res.setdefault(key, 0)
            res[key] += val / n
    return res


class SampleSet:
    """All training texts for one language, plus their n-gram frequencies."""

    def __init__(self, language):
        self.language = language
        self.texts = []
        self.frequencies = []

    def add(self, text):
        self.texts.append(text)
        self.frequencies.append(extract_ngram_freqs(text))

    def create_model(self):
        """Build the language model from all texts of this language."""
        merged_frequencies = merge_ngram_freqs(self.frequencies)
        return Sample(self.language, rank_ngram_freqs(merged_frequencies))

    def generate_tests(self, n):
        """Yield ``n`` leave-one-out test pairs ``(text, model)``.

        Each model is built from every text of this language *except* the
        one being tested, so a test text never contributes to the model
        that has to identify it.
        """
        pairs = itertools.cycle(zip(self.texts, self.frequencies))
        for (i, (text, freqs)) in enumerate(pairs):
            if i >= n:
                break
            # `is not` compares identity: exclude exactly this text's dict.
            ranked_ngrams = rank_ngram_freqs(
                merge_ngram_freqs([f for f in self.frequencies if f is not freqs])
            )
            yield (text, Sample(self.language, ranked_ngrams))


def cross_validate(sample_sets):
    """Return the snippet-identification accuracy as
    ``(accuracy, (score, max_score))``.

    For every language, each held-out text is sliced into random snippets
    of the lengths in TEST_LENS, and each snippet is identified against
    the partial model of its own language plus the full models of all
    other languages.
    """
    models = [s.create_model() for s in sample_sets]
    score = 0
    max_score = 0
    for s in sample_sets:
        for (test_text, partial_model) in s.generate_tests(CROSSVALIDATION_SOURCE_COUNT):
            real_lang = partial_model.language
            test_models = [partial_model] + [
                m for m in models if m.language != real_lang
            ]
            for k in TEST_LENS:
                # Ten random snippets of each length per held-out text;
                # assumes every text is longer than the largest test length.
                for _ in range(10):
                    j = random.randrange(0, len(test_text) - k)
                    t = test_text[j:j + k]
                    predicted_lang = identify(t, test_models)
                    if predicted_lang == real_lang:
                        score += 1
                    else:
                        print(real_lang, predicted_lang, t)
                    max_score += 1
    return score / max_score, (score, max_score)


DATA_DIR = os.path.join(os.path.dirname(__file__), "../../data")
# One subdirectory per language; skip stray files such as .DS_Store.
LANG_DIRS = sorted(x.path for x in os.scandir(DATA_DIR) if x.is_dir())
MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")

if __name__ == "__main__":
    samples = []

    for d in LANG_DIRS:
        lang = os.path.basename(d)
        lang_samples = SampleSet(lang)
        samples.append(lang_samples)

        for file in sorted(os.scandir(d), key=lambda f: f.name):
            # Corpus files are expected to be UTF-8; be explicit so the
            # script does not depend on the platform's default encoding.
            with open(file, encoding="utf-8") as f:
                text = f.read()
            text = preprocess(text)
            print(f"{lang}: {file.name} ({len(text)})")
            lang_samples.add(text)

    with gzip.open(MODEL_PATH, mode="wt", encoding="utf-8") as f:
        json.dump(
            [sample_set.create_model().export() for sample_set in samples],
            f,
            ensure_ascii=False,
        )

    print(cross_validate(samples))
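

# --- Illustrative sketch, not used by the training run above ---
# A consumer of models.json.gz would read the gzipped JSON back the same way
# it was written. The function name `load_exported_models` is hypothetical
# (it does not exist in this project), and the shape of each returned entry
# is whatever Sample.export() in .predict emits, so reconstructing Sample
# objects from the entries is left out here.
def load_exported_models(path=MODEL_PATH):
    """Return the raw list of exported model entries from ``path``."""
    with gzip.open(path, mode="rt", encoding="utf-8") as f:
        return json.load(f)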