Files @ f82e9a5b1c2c
Branch filter:

Location: Languedoc/tests/test_train.py

Laman
added more tests
import unittest
from unittest import TestCase

from languedoc.train import merge_ngram_freqs, cross_validate, SampleSet


class TestTrain(TestCase):
	"""Tests for the module-level helpers imported from languedoc.train."""

	def test_merge_ngram_freqs(self):
		"""Check merging of two n-gram count dicts into a single frequency dict.

		The expected values encode the contract under test: every count is divided
		by the total of its own input dict (8 and 15 here) and the resulting
		relative frequencies are averaged over the number of inputs (2).
		"""
		a = {"a": 3, "b": 1, "c": 4}  # counts sum to 8
		b = {"b": 1, "c": 5, "d": 9}  # counts sum to 15
		expected = {"a": 3/8/2, "b": (1/8+1/15)/2, "c": (4/8+5/15)/2, "d": 9/15/2}

		c = merge_ngram_freqs([a, b])

		# The keys have to match exactly, but the values are floats: comparing them
		# with a tolerance keeps the test independent of rounding and of the order
		# in which the implementation performs its divisions and additions.
		self.assertEqual(set(c), set(expected))
		for (ngram, freq) in expected.items():
			with self.subTest(ngram=ngram):
				self.assertAlmostEqual(c[ngram], freq)
		# the merged frequencies still form a distribution
		self.assertAlmostEqual(sum(c.values()), 1)

	# unittest.skip is documented as skip(reason); give it one so that the test
	# is reported as skipped, with an explanation, on every Python version
	@unittest.skip("not implemented yet")
	def test_crossvalidate(self):
		pass


class TestSampleSet(TestCase):
	"""Tests for the SampleSet class imported from languedoc.train."""

	def test_add(self):
		"""A fresh SampleSet is empty; add() records the text and its n-gram counts."""
		sample_set = SampleSet("xy")
		self.assertEqual(sample_set.texts, [])
		self.assertEqual(sample_set.counts, [])

		sample_set.add("aaab")
		self.assertEqual(sample_set.texts, ["aaab"])
		# the expected counts cover every substring of "aaab" of length 1 to 3
		self.assertEqual(sample_set.counts, [{"a": 3, "b": 1, "aa": 2, "ab": 1, "aaa": 1, "aab": 1}])

	def test_create_model(self):
		"""create_model() returns a model carrying the language and ranked n-grams."""
		sample_set = SampleSet("xy")
		sample_set.add("aaab")
		sample_set.add("aab")

		res = sample_set.create_model()

		self.assertEqual(res.language, "xy")
		# ranked_ngrams maps each n-gram to its rank, starting from 0
		self.assertEqual(res.ranked_ngrams, {"a": 0, "aa": 1, "b": 2, "ab": 3, "aab": 4, "aaa": 5})

	# unittest.skip is documented as skip(reason); give it one so that the test
	# is reported as skipped, with an explanation, on every Python version
	@unittest.skip("not implemented yet")
	def test_generate_tests(self):
		pass