diff --git a/languedoc.py b/languedoc.py
--- a/languedoc.py
+++ b/languedoc.py
@@ -5,6 +5,8 @@
 import itertools
 
 random.seed(19181028)
 
+TEST_LENS = [8, 16, 32, 64]
+
 def preprocess(text):
     text = re.sub(r"[\W\d_]+", " ", " "+text+" ")
@@ -91,25 +93,72 @@ class Sample:
         print()
 
 
+class SampleSet:
+    def __init__(self, language):
+        self.language = language
+        self.texts = []
+        self.samples = []
+
+    def add(self, text):
+        self.texts.append(text)
+        self.samples.append(Sample(self.language, text))
+
+    def create_model(self):
+        return Sample.merge(self.samples)
+
+    def generate_tests(self):
+        for (text, sample) in zip(self.texts, self.samples):
+            yield (text, Sample.merge([x for x in self.samples if x is not sample]))
+
+
+def cross_validate(sample_sets):
+    models = [s.create_model() for s in sample_sets]
+    score = 0
+    max_score = 0
+
+    for s in sample_sets:
+        for (test_text, partial_model) in s.generate_tests():
+            partial_model.print_overview()
+            real_lang = partial_model.language
+            test_models = [partial_model] + [m for m in models if m.language != real_lang]
+
+            for k in TEST_LENS:
+                j = random.randrange(0, len(test_text)-k)
+                t = test_text[j:j+k]
+                predicted_lang = identify(t, test_models)
+                print(real_lang, predicted_lang, t)
+                if predicted_lang == real_lang:
+                    score += 1
+                max_score += 1
+
+    return score / max_score, (score, max_score)
+
+
+def identify(text, models):
+    sample = Sample(text=text)
+
+    return min(map(lambda m: (m.compare(sample), m.language), models))[1]
+
+
 DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
 LANG_DIRS = [x.path for x in os.scandir(DATA_DIR)]
 
-models = dict()
-
-for d in LANG_DIRS:
-    lang = os.path.basename(d)
+if __name__ == "__main__":
     samples = []
 
-    for file in os.scandir(d):
-        with open(file) as f:
-            text = f.read()
-            text = preprocess(text)
-            print(f"{file.name} ({len(text)})")
-            print(text[:256])
-            print()
+    for d in LANG_DIRS:
+        lang = os.path.basename(d)
+        lang_samples = SampleSet(lang)
+        samples.append(lang_samples)
 
-            samples.append(Sample(lang, text))
-            samples[-1].print_overview()
+        for file in os.scandir(d):
+            with open(file) as f:
+                text = f.read()
+                text = preprocess(text)
+                print(f"{file.name} ({len(text)})")
+                print(text[:256])
+                print()
 
-    models[lang] = Sample.merge(samples)
-    models[lang].print_overview()
+            lang_samples.add(text)
+
+    print(cross_validate(samples))