diff --git a/languedoc.py b/languedoc.py --- a/languedoc.py +++ b/languedoc.py @@ -5,6 +5,7 @@ import itertools random.seed(19181028) +CROSSVALIDATION_SOURCE_COUNT = 5 TEST_LENS = [8, 16, 32, 64] TOP_TRIGRAM_COUNT = 6000 @@ -107,8 +108,11 @@ class SampleSet: def create_model(self): return Sample.merge(self.samples) - def generate_tests(self): - for (text, sample) in zip(self.texts, self.samples): + def generate_tests(self, n): + for (i, (text, sample)) in enumerate(itertools.cycle(zip(self.texts, self.samples))): + if i >= n: + break + yield (text, Sample.merge([x for x in self.samples if x is not sample])) @@ -118,7 +122,7 @@ def cross_validate(sample_sets): max_score = 0 for s in sample_sets: - for (test_text, partial_model) in s.generate_tests(): + for (test_text, partial_model) in s.generate_tests(CROSSVALIDATION_SOURCE_COUNT): real_lang = partial_model.language test_models = [partial_model] + [m for m in models if m.language != real_lang] @@ -157,9 +161,7 @@ if __name__ == "__main__": with open(file) as f: text = f.read() text = preprocess(text) - print(f"{file.name} ({len(text)})") - print(text[:256]) - print() + print(f"{lang}: {file.name} ({len(text)})") lang_samples.add(text)