# HG changeset patch # User Laman # Date 2022-10-18 22:58:35 # Node ID 781dc476bf41be33b2c5885baa3abc417e23b2ef # Parent 252d3b1bca6051198452fee201d4c88202712b5f some tests and some documentation diff --git a/src/languedoc/predict.py b/src/languedoc/predict.py --- a/src/languedoc/predict.py +++ b/src/languedoc/predict.py @@ -8,7 +8,8 @@ TOP_NGRAM_COUNT = 3000 MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz") -def preprocess(text): +def preprocess(text: str) -> str: + """Preprocess text by stripping non-letter characters, collapsing whitespace and converting to lowercase.""" text = re.sub(r"[\W\d_]+", " ", " "+text+" ") return text.lower() @@ -39,7 +40,7 @@ def extract_ngram_freqs(text): def rank_ngram_freqs(frequencies): - ordered_ngrams = sorted(frequencies.items(), key=lambda kv: -kv[1])[:TOP_NGRAM_COUNT] + ordered_ngrams = sorted(frequencies.items(), key=lambda kv: (-kv[1], len(kv[0]), kv[0]))[:TOP_NGRAM_COUNT] return dict(zip([key for (key, freq) in ordered_ngrams], itertools.count(0))) @@ -49,26 +50,44 @@ def extract_ranked_ngrams(text): class Sample: - def __init__(self, language, ranked_ngrams): + def __init__(self, language: str, ranked_ngrams: dict[str, float]): + """Create a new Sample from language and ngrams. + + This is usually impractical and Sample.extract or Sample.load are preferred.""" self.language = language self.ranked_ngrams = ranked_ngrams @classmethod - def extract(cls, text, language="??"): + def extract(cls, text: str, language="??") -> "Sample": + """Create a new Sample by extracting it from text. + + :param text: a string, from which to extract the ngrams into a Sample + :param language: a two letter language code if it is known (cs|de|en|...)""" return cls(language, extract_ranked_ngrams(preprocess(text))) @classmethod - def load(cls, exported): + def load(cls, exported: dict) -> "Sample": + """Load a previously exported dict and create a new Sample. 
+ + :param exported: {"language": str, "ngrams": [str, ...]}""" ranked_ngrams = {key: order for (order, key) in enumerate(exported["ngrams"])} return cls(exported["language"], ranked_ngrams) - def export(self): + def export(self) -> dict: + """Export to a dict. Complement to Sample.load(). + + :return: {"language": str, "ngrams": [str, ...]}""" return { "language": self.language, "ngrams": [key for (key, order) in sorted(self.ranked_ngrams.items(), key=lambda key_order: key_order[1])] } - def compare(self, other): + def compare(self, other: "Sample") -> int: + """Compute a similarity score between self and other. + + The method is asymmetric. You are supposed to use sample.compare(model), not model.compare(sample). + + :param other: a reference model in a known language""" m = len(other.ranked_ngrams) res = sum( @@ -79,12 +98,18 @@ class Sample: return res -def load_models(model_path): +def load_models(model_path: str) -> list[Sample]: + """Load language models from path and return as a list.""" with gzip.open(model_path, mode="rt", encoding="utf-8") as f: return [Sample.load(obj) for obj in json.load(f)] -def identify(text, models=[]): +def identify(text: str, models=[]) -> str: + """Return the language closest to text among the models. + + :param text: the text to identify + :param models: list of models to choose from.
The default is loaded from MODEL_PATH + :return: best matching language (cs, de, en, ...)""" if not models: models = load_models(MODEL_PATH) diff --git a/src/languedoc/train.py b/src/languedoc/train.py --- a/src/languedoc/train.py +++ b/src/languedoc/train.py @@ -4,7 +4,7 @@ import itertools import json import gzip -from predict import preprocess, identify, extract_ngram_freqs, rank_ngram_freqs, Sample +from .predict import preprocess, identify, extract_ngram_freqs, rank_ngram_freqs, Sample random.seed(19181028) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 diff --git a/tests/test_predict.py b/tests/test_predict.py new file mode 100644 --- /dev/null +++ b/tests/test_predict.py @@ -0,0 +1,46 @@ +from unittest import TestCase + +from languedoc.predict import preprocess, rank_ngram_freqs, Sample, identify + + +class TestPredict(TestCase): + def test_preprocess(self): + self.assertEqual(preprocess("abc"), " abc ") + self.assertEqual(preprocess("A b.c"), " a b c ") + self.assertEqual(preprocess("1% "), " ") + self.assertEqual(preprocess("Глава ĚŠČŘŽ"), " глава ěščřž ") + + def test_rank_ngram_freqs(self): + freqs = {"a": 3, "aa": 1, "b": 4, "bb": 1, "c": 1} + expected = {"b": 0, "a": 1, "c": 2, "aa": 3, "bb": 4} + self.assertEqual(rank_ngram_freqs(freqs), expected) + + +class TestSample(TestCase): + def test_extract(self): + a = Sample.extract("aaaaaa", "a") + self.assertEqual(a.language, "a") + self.assertEqual(a.ranked_ngrams, {'a': 0, 'aa': 1, 'aaa': 2, ' aa': 3, 'aa ': 4, ' a': 5, 'a ': 6}) + + b = Sample.extract("aa aa aa", "b") + self.assertEqual(b.ranked_ngrams, {'a': 0, ' aa': 1, 'aa ': 2, ' a': 3, 'a ': 4, 'aa': 5, 'a a': 6}) + + c = Sample.extract("aa") + self.assertEqual(c.language, "??") + self.assertEqual(c.ranked_ngrams, {'a': 0, ' aa': 1, 'aa ': 2, ' a': 3, 'a ': 4, 'aa': 5}) + + +class TestIdentify(TestCase): + def test_identify(self): + samples = [ + ("cs", "Severní ledový oceán je nejmenší světový oceán."), + ("de", "Der 
Arktische Ozean ist der kleinste Ozean der Erde."), + ("en", "The Arctic Ocean is the smallest of the world's oceans."), + ("es", "Océano Ártico más pequeña y más septentrional del planeta"), + ("fr", "L'océan Arctique ce qui en fait le plus petit des océans."), + ("it", "Il Mar Glaciale Artico è una massa d'acqua..."), + ("ru", "Се́верный Ледови́тый океа́н — наименьший по площади океан Земли") + ] + + for (lang, sample) in samples: + self.assertEqual(lang, identify(sample))