Changeset - 3998161856de
[Not reviewed]
default
0 1 0
Laman - 21 months ago 2023-07-10 18:47:29

optimized the model loading
1 file changed with 5 insertions and 1 deletions:
0 comments (0 inline, 0 general)
src/languedoc/predict.py
Show inline comments
 
import os
 
import re
 
import itertools
 
import json
 
import gzip
 
from typing import Union
 

	
 
TOP_NGRAM_COUNT = 3000
 
MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")
 
MODEL = []
 

	
 

	
 
def preprocess(text: str) -> str:
 
	"""Preprocess text by stripping non-letter characters, collapsing whitespace and converting to lowercase."""
 
	text = re.sub(r"[\W\d_]+", " ", " "+text+" ")
 
	return text.lower()
 

	
 

	
 
def extract_kgram_counts(text: str, k: int) -> dict[str, int]:
 
	"""Extract k-gram counts from the text for a provided k.
 

	
 
	:param text: the source text
 
	:param k: length of the kgrams to extract. 1 for letters, 2 for bigrams, ...
 
	:return: a dict mapping kgrams to their counts in the text"""
 
	n = len(text)
 
	counts = dict()
 

	
 
	for i in range(0, n-k+1):
 
		key = text[i:i+k]
 
		if key.isspace():
 
			continue
 

	
 
		counts[key] = counts.get(key, 0) + 1
 

	
 
	return counts
 

	
 

	
 
def extract_ngram_counts(text: str) -> dict[str, int]:
 
	"""Extract counts of 1- to 3-grams from the text.
 

	
 
	:param text: the source text
 
	:return: a dict mapping ngrams to their counts in the text"""
 
	counts = dict()
 

	
 
	for k in range(1, 4):
 
		counts.update(extract_kgram_counts(text, k))
 

	
 
	return counts
 

	
 

	
 
def rank_ngram_counts(counts: dict[str, Union[int, float]]) -> dict[str, int]:
 
	"""Order supplied ngrams by their counts (then length, then alphabetically) and return their ranking.
 

	
 
	:param counts: a dict mapping ngrams to their counts
 
	:return: a dict mapping ngrams to their rank (the most frequent: 0, the second: 1, ...)"""
 
	ordered_ngrams = sorted(counts.items(), key=lambda kv: (-kv[1], len(kv[0]), kv[0]))[:TOP_NGRAM_COUNT]
 
	return dict(zip([key for (key, count) in ordered_ngrams], itertools.count(0)))
 

	
 
@@ -81,54 +82,57 @@ class Sample:
 
		:param language: a two letter language code if it is known (cs|de|en|...)"""
 
		return cls(language, extract_ranked_ngrams(preprocess(text)))
 

	
 
	@classmethod
 
	def load(cls, exported: dict) -> "Sample":
 
		"""Load a previously exported dict and create a new Sample.
 

	
 
		:param exported: {"language": str, "ngrams": [str, ...]}"""
 
		ranked_ngrams = {key: order for (order, key) in enumerate(exported["ngrams"])}
 
		return cls(exported["language"], ranked_ngrams)
 

	
 
	def export(self) -> dict:
 
		"""Export to a dict. Complement to Sample.load()
 

	
 
		:return: {"language": str, "ngrams": [str, ...]}"""
 
		return {
 
			"language": self.language,
 
			"ngrams": [key for (key, order) in sorted(self.ranked_ngrams.items(), key=lambda key_order: key_order[1])]
 
		}
 

	
 
	def compare(self, other: "Sample") -> int:
 
		"""Compute a similarity score between self and other.
 

	
 
		The method is asymmetric. You are supposed to use sample.compare(model), not model.compare(sample).
 

	
 
		:param other: a reference model in known language"""
 
		m = len(other.ranked_ngrams)
 

	
 
		res = sum(
 
			(abs(v - other.ranked_ngrams[k]) if k in other.ranked_ngrams else m)
 
			for (k, v) in self.ranked_ngrams.items()
 
		)
 

	
 
		return res
 

	
 

	
 
def load_models(model_path: str) -> list[Sample]:
 
	"""Load language models from path and return as a list."""
 
	with gzip.open(model_path, mode="rt", encoding="utf-8") as f:
 
		return [Sample.load(obj) for obj in json.load(f)]
 

	
 

	
 
def identify(text: str, models=[]) -> str:
 
	"""Return the language closest to text among the models.
 

	
 
	:param text: the text to identify
 
	:param models: list of models to choose from. The default is loaded from MODEL_PATH
 
	:return: best matching language (cs, de, en, ...)"""
 
	global MODEL
 
	if not MODEL and not models:
 
		MODEL = load_models(MODEL_PATH)
 
	if not models:
 
		models = load_models(MODEL_PATH)
 
		models = MODEL
 

	
 
	sample = Sample.extract(text)
 

	
 
	return sorted(models, key=lambda m: sample.compare(m))[0].language
0 comments (0 inline, 0 general)