diff --git a/tests/test_train.py b/tests/test_train.py
new file mode 100644
--- /dev/null
+++ b/tests/test_train.py
@@ -0,0 +1,66 @@
+import unittest
+from unittest import TestCase
+
+from languedoc.train import merge_ngram_freqs, cross_validate, SampleSet
+
+
+class TestTrain(TestCase):
+    """Tests for the module-level helpers of ``languedoc.train``."""
+
+    def test_merge_ngram_freqs(self):
+        """Expect each input normalised by its own total, then averaged."""
+        a = {"a": 3, "b": 1, "c": 4}
+        b = {"b": 1, "c": 5, "d": 9}
+        c = merge_ngram_freqs([a, b])
+        # Expected: count / dict total (8 and 15), averaged over both dicts.
+        expected = {
+            "a": 3/8/2,
+            "b": (1/8+1/15)/2,
+            "c": (4/8+5/15)/2,
+            "d": 9/15/2,
+        }
+        # Float results depend on evaluation order, so compare with a tolerance
+        # instead of exact equality.
+        self.assertEqual(set(c), set(expected))
+        for ngram, freq in expected.items():
+            self.assertAlmostEqual(c[ngram], freq, msg=ngram)
+        self.assertAlmostEqual(sum(c.values()), 1)
+
+    @unittest.skip("not implemented yet")
+    def test_crossvalidate(self):
+        pass
+
+
+class TestSampleSet(TestCase):
+    """Tests for ``SampleSet``."""
+
+    def test_add(self):
+        """A new set starts empty; ``add`` stores the text and its counts."""
+        sample_set = SampleSet("xy")
+        self.assertEqual(sample_set.texts, [])
+        self.assertEqual(sample_set.counts, [])
+
+        sample_set.add("aaab")
+        self.assertEqual(sample_set.texts, ["aaab"])
+        self.assertEqual(
+            sample_set.counts,
+            [{"a": 3, "b": 1, "aa": 2, "ab": 1, "aaa": 1, "aab": 1}],
+        )
+
+    def test_create_model(self):
+        """The model exposes the language and the expected n-gram ranking."""
+        sample_set = SampleSet("xy")
+        sample_set.add("aaab")
+        sample_set.add("aab")
+
+        res = sample_set.create_model()
+
+        self.assertEqual(res.language, "xy")
+        self.assertEqual(
+            res.ranked_ngrams,
+            {"a": 0, "aa": 1, "b": 2, "ab": 3, "aab": 4, "aaa": 5},
+        )
+
+    @unittest.skip("not implemented yet")
+    def test_generate_tests(self):
+        pass