diff --git a/src/languedoc/predict.py b/src/languedoc/predict.py
--- a/src/languedoc/predict.py
+++ b/src/languedoc/predict.py
@@ -14,39 +14,37 @@ def preprocess(text: str) -> str:
     return text.lower()
 
 
-def extract_kgram_freqs(text, k):
+def extract_kgram_counts(text, k):
     n = len(text)
-    d = dict()
+    counts = dict()
 
     for i in range(0, n-k+1):
         key = text[i:i+k]
         if key.isspace():
             continue
 
-        d[key] = d.get(key, 0) + 1
+        counts[key] = counts.get(key, 0) + 1
 
-    count = sum(d.values())
-
-    return {key: val/count for (key, val) in d.items()}
+    return counts
 
 
-def extract_ngram_freqs(text):
-    frequencies = {}
+def extract_ngram_counts(text):
+    counts = dict()
 
     for k in range(1, 4):
-        frequencies.update(extract_kgram_freqs(text, k))
+        counts.update(extract_kgram_counts(text, k))
 
-    return frequencies
+    return counts
 
 
-def rank_ngram_freqs(frequencies):
-    ordered_ngrams = sorted(frequencies.items(), key=lambda kv: (-kv[1], len(kv[0]), kv[0]))[:TOP_NGRAM_COUNT]
+def rank_ngram_counts(counts):
+    ordered_ngrams = sorted(counts.items(), key=lambda kv: (-kv[1], len(kv[0]), kv[0]))[:TOP_NGRAM_COUNT]
     return dict(zip([key for (key, freq) in ordered_ngrams], itertools.count(0)))
 
 
 def extract_ranked_ngrams(text):
-    frequencies = extract_ngram_freqs(text)
-    return rank_ngram_freqs(frequencies)
+    counts = extract_ngram_counts(text)
+    return rank_ngram_counts(counts)
 
 
 class Sample:
diff --git a/src/languedoc/train.py b/src/languedoc/train.py
--- a/src/languedoc/train.py
+++ b/src/languedoc/train.py
@@ -4,7 +4,7 @@ import itertools
 import json
 import gzip
 
-from .predict import preprocess, identify, extract_ngram_freqs, rank_ngram_freqs, Sample
+from languedoc.predict import preprocess, identify, extract_ngram_counts, rank_ngram_counts, Sample
 
 random.seed(19181028)
 
@@ -12,14 +12,15 @@ CROSSVALIDATION_SOURCE_COUNT = 5
 TEST_LENS = [8, 16, 32, 64]
 
 
-def merge_ngram_freqs(freqs):
-    n = len(freqs)
+def merge_ngram_freqs(counts):
+    n = len(counts)
     res = dict()
 
-    for d in freqs:
+    for d in counts:
+        k = sum(d.values())
         for (key, val) in d.items():
             res.setdefault(key, 0)
-            res[key] += val/n
+            res[key] += val/k/n
 
     return res
 
@@ -28,23 +29,23 @@ class SampleSet:
     def __init__(self, language):
         self.language = language
         self.texts = []
-        self.frequencies = []
+        self.counts = []
 
     def add(self, text):
         self.texts.append(text)
-        self.frequencies.append(extract_ngram_freqs(text))
+        self.counts.append(extract_ngram_counts(text))
 
     def create_model(self):
-        merged_frequencies = merge_ngram_freqs(self.frequencies)
-        res = Sample(self.language, rank_ngram_freqs(merged_frequencies))
+        merged_frequencies = merge_ngram_freqs(self.counts)
+        res = Sample(self.language, rank_ngram_counts(merged_frequencies))
         return res
 
     def generate_tests(self, n):
-        for (i, (text, freqs)) in enumerate(itertools.cycle(zip(self.texts, self.frequencies))):
+        for (i, (text, freqs)) in enumerate(itertools.cycle(zip(self.texts, self.counts))):
             if i >= n:
                 break
 
-            ranked_ngrams = rank_ngram_freqs(merge_ngram_freqs([f for f in self.frequencies if f is not freqs]))
+            ranked_ngrams = rank_ngram_counts(merge_ngram_freqs([f for f in self.counts if f is not freqs]))
             yield (text, Sample(self.language, ranked_ngrams))
 
 
diff --git a/tests/test_predict.py b/tests/test_predict.py
--- a/tests/test_predict.py
+++ b/tests/test_predict.py
@@ -1,6 +1,6 @@
 from unittest import TestCase
 
-from languedoc.predict import preprocess, rank_ngram_freqs, Sample, identify
+from languedoc.predict import preprocess, rank_ngram_counts, Sample, identify
 
 
 class TestPredict(TestCase):
@@ -10,24 +10,24 @@ class TestPredict(TestCase):
         self.assertEqual(preprocess("1% "), " ")
         self.assertEqual(preprocess("Глава ĚŠČŘŽ"), " глава ěščřž ")
 
-    def test_rank_ngram_freqs(self):
+    def test_rank_ngram_counts(self):
         freqs = {"a": 3, "aa": 1, "b": 4, "bb": 1, "c": 1}
         expected = {"b": 0, "a": 1, "c": 2, "aa": 3, "bb": 4}
-        self.assertEqual(rank_ngram_freqs(freqs), expected)
+        self.assertEqual(rank_ngram_counts(freqs), expected)
 
 
 class TestSample(TestCase):
     def test_extract(self):
         a = Sample.extract("aaaaaa", "a")
         self.assertEqual(a.language, "a")
-        self.assertEqual(a.ranked_ngrams, {'a': 0, 'aa': 1, 'aaa': 2, ' aa': 3, 'aa ': 4, ' a': 5, 'a ': 6})
+        self.assertEqual(a.ranked_ngrams, {'a': 0, 'aa': 1, 'aaa': 2, ' a': 3, 'a ': 4, ' aa': 5, 'aa ': 6})
 
         b = Sample.extract("aa aa aa", "b")
-        self.assertEqual(b.ranked_ngrams, {'a': 0, ' aa': 1, 'aa ': 2, ' a': 3, 'a ': 4, 'aa': 5, 'a a': 6})
+        self.assertEqual(b.ranked_ngrams, {'a': 0, ' a': 1, 'a ': 2, 'aa': 3, ' aa': 4, 'aa ': 5, 'a a': 6})
 
         c = Sample.extract("aa")
         self.assertEqual(c.language, "??")
-        self.assertEqual(c.ranked_ngrams, {'a': 0, ' aa': 1, 'aa ': 2, ' a': 3, 'a ': 4, 'aa': 5})
+        self.assertEqual(c.ranked_ngrams, {'a': 0, ' a': 1, 'a ': 2, 'aa': 3, ' aa': 4, 'aa ': 5})
 
 
 class TestIdentify(TestCase):