diff --git a/src/languedoc/train.py b/src/languedoc/train.py --- a/src/languedoc/train.py +++ b/src/languedoc/train.py @@ -4,7 +4,7 @@ import itertools import json import gzip -from .predict import preprocess, identify, extract_ngram_freqs, rank_ngram_freqs, Sample +from languedoc.predict import preprocess, identify, extract_ngram_counts, rank_ngram_counts, Sample random.seed(19181028) @@ -12,14 +12,15 @@ CROSSVALIDATION_SOURCE_COUNT = 5 TEST_LENS = [8, 16, 32, 64] -def merge_ngram_freqs(freqs): - n = len(freqs) +def merge_ngram_freqs(counts): + n = len(counts) res = dict() - for d in freqs: + for d in counts: + k = sum(d.values()) for (key, val) in d.items(): res.setdefault(key, 0) - res[key] += val/n + res[key] += val/k/n return res @@ -28,23 +29,23 @@ class SampleSet: def __init__(self, language): self.language = language self.texts = [] - self.frequencies = [] + self.counts = [] def add(self, text): self.texts.append(text) - self.frequencies.append(extract_ngram_freqs(text)) + self.counts.append(extract_ngram_counts(text)) def create_model(self): - merged_frequencies = merge_ngram_freqs(self.frequencies) - res = Sample(self.language, rank_ngram_freqs(merged_frequencies)) + merged_frequencies = merge_ngram_freqs(self.counts) + res = Sample(self.language, rank_ngram_counts(merged_frequencies)) return res def generate_tests(self, n): - for (i, (text, freqs)) in enumerate(itertools.cycle(zip(self.texts, self.frequencies))): + for (i, (text, freqs)) in enumerate(itertools.cycle(zip(self.texts, self.counts))): if i >= n: break - ranked_ngrams = rank_ngram_freqs(merge_ngram_freqs([f for f in self.frequencies if f is not freqs])) + ranked_ngrams = rank_ngram_counts(merge_ngram_freqs([f for f in self.counts if f is not freqs])) yield (text, Sample(self.language, ranked_ngrams))