Files @ 781dc476bf41
Branch filter:

Location: Languedoc/src/languedoc/predict.py - annotation

Laman
some tests and some documentation
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
252d3b1bca60
d443541818b2
d443541818b2
781dc476bf41
781dc476bf41
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
781dc476bf41
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
781dc476bf41
781dc476bf41
781dc476bf41
781dc476bf41
d443541818b2
d443541818b2
d443541818b2
d443541818b2
781dc476bf41
781dc476bf41
781dc476bf41
781dc476bf41
781dc476bf41
d443541818b2
d443541818b2
d443541818b2
781dc476bf41
781dc476bf41
781dc476bf41
781dc476bf41
d443541818b2
d443541818b2
d443541818b2
781dc476bf41
781dc476bf41
781dc476bf41
781dc476bf41
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
781dc476bf41
781dc476bf41
781dc476bf41
781dc476bf41
781dc476bf41
781dc476bf41
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
781dc476bf41
781dc476bf41
d443541818b2
d443541818b2
d443541818b2
d443541818b2
781dc476bf41
781dc476bf41
781dc476bf41
781dc476bf41
781dc476bf41
781dc476bf41
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
d443541818b2
import functools
import gzip
import itertools
import json
import os
import re
import unicodedata

# Maximum number of most frequent n-grams kept in a ranked profile (applied in rank_ngram_freqs).
TOP_NGRAM_COUNT = 3000
# Default language models used by identify(): a gzip-compressed JSON list of exported Samples,
# located in the same directory as this module.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")


def preprocess(text: str) -> str:
	"""Preprocess text by stripping non-letter characters, collapsing whitespace and converting to lowercase.

	The text is first normalized to Unicode NFC. Letters written as a base character
	followed by combining marks (e.g. "c" + U+030C) then become single code points.
	Without this, the regex below would treat the combining mark as a non-letter and
	split the word in two.

	Every run of non-letter characters (punctuation, digits, underscores, whitespace)
	is replaced by one space. The text is padded with a space on each side before
	that, so the result always starts and ends with exactly one space and word
	boundaries show up in the n-grams.

	:param text: raw input text
	:return: normalized, lowercased text surrounded by single spaces"""
	# Compose decomposed characters so combining marks are not mistaken for separators.
	text = unicodedata.normalize("NFC", text)
	text = re.sub(r"[\W\d_]+", " ", " " + text + " ")
	return text.lower()


def extract_kgram_freqs(text, k):
	"""Return the relative frequency of every k-character substring of text.

	Substrings made up only of whitespace are ignored. The returned frequencies
	sum to 1, unless no substring qualifies, in which case the result is empty.

	:param text: text to scan, typically the output of preprocess()
	:param k: length of the substrings to count
	:return: dict of k-gram -> occurrences / total occurrences, in first-seen order"""
	counts = {}
	for start in range(len(text) - k + 1):
		gram = text[start:start + k]
		if not gram.isspace():
			counts[gram] = counts.get(gram, 0) + 1

	total = sum(counts.values())
	return {gram: occurrences / total for gram, occurrences in counts.items()}


def extract_ngram_freqs(text):
	"""Return the relative frequencies of all 1-, 2- and 3-grams of text in one dict.

	Each n-gram length is normalized separately by extract_kgram_freqs. N-grams of
	different lengths are distinct keys, so merging the three dicts loses nothing.

	:param text: text to scan, typically the output of preprocess()
	:return: dict of n-gram -> relative frequency within its own length"""
	merged = {}
	for k in (1, 2, 3):
		merged |= extract_kgram_freqs(text, k)
	return merged


def rank_ngram_freqs(frequencies):
	"""Convert an n-gram -> frequency mapping into an n-gram -> rank mapping.

	N-grams are ordered by descending frequency. Ties are broken by shorter n-gram
	first, then by code-point order. Only the first TOP_NGRAM_COUNT are kept, and
	ranks start at 0.

	:param frequencies: dict of n-gram -> frequency
	:return: dict of n-gram -> rank, in rank order"""
	def ordering(item):
		ngram, freq = item
		return (-freq, len(ngram), ngram)

	best = sorted(frequencies.items(), key=ordering)[:TOP_NGRAM_COUNT]
	return {ngram: rank for rank, (ngram, _freq) in enumerate(best)}


def extract_ranked_ngrams(text):
	"""Build the ranked 1- to 3-gram profile of text.

	Sample.extract passes text that has already gone through preprocess().

	:param text: preprocessed text
	:return: dict of n-gram -> rank (0 = most frequent)"""
	return rank_ngram_freqs(extract_ngram_freqs(text))


class Sample:
	"""Ranked n-gram profile of a text, labelled with its language when known.

	``ranked_ngrams`` maps each n-gram to its rank, where 0 is the most frequent.
	A Sample of unknown text is scored against Samples of known language (models)
	with :meth:`compare`."""

	def __init__(self, language: str, ranked_ngrams: dict[str, int]):
		"""Create a new Sample from language and ngrams.

		This is usually impractical and Sample.extract or Sample.load are preferred.

		:param language: language code, or "??" when unknown
		:param ranked_ngrams: dict of n-gram -> rank, 0 being the most frequent n-gram"""
		self.language = language
		self.ranked_ngrams = ranked_ngrams

	def __repr__(self) -> str:
		return f"{type(self).__name__}(language={self.language!r}, ngrams={len(self.ranked_ngrams)})"

	@classmethod
	def extract(cls, text: str, language="??") -> "Sample":
		"""Create a new Sample by extracting it from text.

		The text is run through preprocess() before its n-grams are ranked.

		:param text: a string, from which to extract the ngrams into a Sample
		:param language: a two letter language code if it is known (cs|de|en|...)"""
		return cls(language, extract_ranked_ngrams(preprocess(text)))

	@classmethod
	def load(cls, exported: dict) -> "Sample":
		"""Load a previously exported dict and create a new Sample.

		:param exported: {"language": str, "ngrams": [str, ...]}, with ngrams listed from rank 0 upward"""
		ranked_ngrams = {key: order for (order, key) in enumerate(exported["ngrams"])}
		return cls(exported["language"], ranked_ngrams)

	def export(self) -> dict:
		"""Export to a dict. Complement to Sample.load()

		:return: {"language": str, "ngrams": [str, ...]}, with ngrams sorted by ascending rank"""
		return {
			"language": self.language,
			"ngrams": [key for (key, order) in sorted(self.ranked_ngrams.items(), key=lambda key_order: key_order[1])]
		}

	def compare(self, other: "Sample") -> int:
		"""Compute the rank-order ("out-of-place") distance from self to other.

		For every n-gram of self, the difference between its rank here and its rank
		in other is added. An n-gram absent from other costs the size of other's
		profile. Lower values mean more similar texts, and 0 means every n-gram of
		self has the same rank in other.

		The method is asymmetric. You are supposed to use sample.compare(model), not model.compare(sample).

		:param other: a reference model in known language
		:return: the distance; smaller is a better match"""
		reference = other.ranked_ngrams
		# Maximum penalty for an n-gram the reference profile does not contain.
		missing_penalty = len(reference)

		return sum(
			abs(rank - reference[ngram]) if ngram in reference else missing_penalty
			for (ngram, rank) in self.ranked_ngrams.items()
		)


def load_models(model_path: str) -> list[Sample]:
	"""Read language models from a gzip-compressed JSON file.

	:param model_path: path to a gzip file holding a JSON list of dicts in the
		Sample.export() format
	:return: one Sample per list entry, in file order"""
	with gzip.open(model_path, mode="rt", encoding="utf-8") as stream:
		exported_models = json.load(stream)
	return list(map(Sample.load, exported_models))


@functools.lru_cache(maxsize=None)
def _load_default_models(model_path: str) -> tuple["Sample", ...]:
	"""Load the models stored at model_path once per path and reuse them afterwards."""
	return tuple(load_models(model_path))


def identify(text: str, models=None) -> str:
	"""Return the language closest to text among the models.

	:param text: the text to identify
	:param models: list of models to choose from. When None or empty, the models stored
		at MODEL_PATH are used. They are read from disk on first use and then cached
		for the rest of the process.
	:return: best matching language (cs, de, en, ...)
	:raises ValueError: if no models are available at all"""
	# None instead of a shared mutable [] default; an explicit [] still means "use the defaults".
	if not models:
		models = _load_default_models(MODEL_PATH)

	sample = Sample.extract(text)

	# Sample.compare returns a distance, so the best model is the minimum.
	# min() keeps the first model on ties, like the previous stable sort did.
	return min(models, key=sample.compare).language