Changeset - 167aab0c3103
[Not reviewed]
default
0 1 0
Laman - 3 years ago 2022-09-28 22:55:21

configurable top k trigrams
1 file changed with 6 insertions and 5 deletions:
0 comments (0 inline, 0 general)
languedoc.py
Show inline comments
 
import os
 
import re
 
import random
 
import itertools
 

	
 
# fixed seed so the randomly chosen test snippets are reproducible between runs
random.seed(19181028)

# lengths (in characters) of the random snippets used in cross-validation
TEST_LENS = [8, 16, 32, 64]
# how many of the most frequent trigrams enter the rank comparison in Sample.compare
TOP_TRIGRAM_COUNT = 6000
 

	
 

	
 
def preprocess(text):
	"""Normalize raw text for n-gram extraction.

	Pads the text with a space on both sides, collapses every run of
	non-letter characters (whitespace, punctuation, digits, underscore)
	into a single space, and lowercases the result.

	:param text: raw input text
	:return: normalized text, always starting and ending with a space
	"""
	padded = " " + text + " "
	collapsed = re.sub(r"[\W\d_]+", " ", padded)
	return collapsed.lower()
 

	
 

	
 
def extract_ngram_freqs(text, k):
	"""Extract relative frequencies of character k-grams from text.

	Substrings consisting purely of whitespace are skipped.

	:param text: the text to analyse
	:param k: n-gram length (k >= 1)
	:return: dict mapping each k-gram to its relative frequency;
		empty dict when the text contains no usable k-gram
	"""
	n = len(text)
	d = dict()

	for i in range(0, n-k+1):
		key = text[i:i+k]
		if key.isspace():
			continue

		d[key] = d.get(key, 0) + 1

	count = sum(d.values())

	# guard: a text shorter than k, or consisting only of whitespace,
	# yields no n-grams and would otherwise divide by zero below
	if count == 0:
		return {}

	return {key: val/count for (key, val) in d.items()}
 

	
 

	
 
def merge_ngram_freqs(freqs):
	"""Average several n-gram frequency tables into one.

	:param freqs: sequence of {ngram: relative frequency} dicts
	:return: single dict with each n-gram's mean frequency over all tables
	"""
	total = len(freqs)
	merged = dict()

	for table in freqs:
		for (ngram, freq) in table.items():
			merged[ngram] = merged.get(ngram, 0) + freq/total

	return merged
 

	
 

	
 
class Sample:
	"""A language sample holding character 1-, 2- and 3-gram frequency tables."""

	def __init__(self, language="??", text=""):
		"""
		:param language: language code, "??" when unknown
		:param text: optional preprocessed text to extract n-grams from
		"""
		self.language = language
		# relative frequencies of 1-grams, 2-grams and 3-grams, in that order
		self.frequencies = [dict(), dict(), dict()]

		if text:
			self._extract(text)

	def _extract(self, text):
		# fill self.frequencies with the 1-, 2- and 3-gram tables of the text
		for k in range(1, 4):
			self.frequencies[k-1] = extract_ngram_freqs(text, k)

	@staticmethod
	def merge(samples):
		"""Merge several same-language samples into one averaged Sample.

		:param samples: non-empty sequence of Samples, all of one language
		:return: new Sample with n-gram frequencies averaged over the input
		"""
		assert len({x.language for x in samples}) == 1

		res = Sample(samples[0].language)
		res.frequencies = []

		for freqs in zip(*(x.frequencies for x in samples)):
			res.frequencies.append(merge_ngram_freqs(freqs))

		return res

	def compare(self, other):
		"""Compute a dissimilarity score between self and another Sample.

		Ranks the TOP_TRIGRAM_COUNT most frequent trigrams of both samples
		and sums the absolute rank differences in both directions; a trigram
		absent from the other ranking is penalized with rank TOP_TRIGRAM_COUNT.

		Design notes (kept from the original):
		take k most common; use frequencies x order;
		use letter, digrams, trigrams; use absolute x square

		:param other: Sample to compare against
		:return: non-negative int score, lower means more similar
		"""
		# NOTE(review): the previous revision computed the orderings and the
		# score twice — once with a stale hard-coded 400 cutoff that was
		# immediately overwritten. The dead computation is removed here.
		ordered_own_trigrams = sorted(self.frequencies[2].items(), key=lambda kv: -kv[1])[:TOP_TRIGRAM_COUNT]
		ordered_other_trigrams = sorted(other.frequencies[2].items(), key=lambda kv: -kv[1])[:TOP_TRIGRAM_COUNT]
		# map each trigram to its rank (0 = most frequent)
		ranked_own_trigrams = dict(zip([key for (key, freq) in ordered_own_trigrams], itertools.count(0)))
		ranked_other_trigrams = dict(zip([key for (key, freq) in ordered_other_trigrams], itertools.count(0)))

		res = sum(abs(v-ranked_other_trigrams.get(k, TOP_TRIGRAM_COUNT)) for (k, v) in ranked_own_trigrams.items()) + \
					sum(abs(v-ranked_own_trigrams.get(k, TOP_TRIGRAM_COUNT)) for (k, v) in ranked_other_trigrams.items())
		print(">", self.language, res)

		return res

	def print_overview(self):
		"""Print the language and the most/least frequent entries of each table."""
		print(f"Sample({self.language}):")

		for freqs in self.frequencies:
			x = [
				(k, round(v, 3))
				for (k, v) in sorted(freqs.items(), key=lambda kv: -kv[1])
			]
			print("  ", x[:8], "...", x[-8:])

		print()
 

	
 

	
 
class SampleSet:
	"""All training texts of one language together with their Samples."""

	def __init__(self, language):
		self.language = language
		self.texts = []    # preprocessed training texts
		self.samples = []  # one Sample per text, same order as self.texts

	def add(self, text):
		"""Register another training text of this language."""
		self.texts.append(text)
		self.samples.append(Sample(self.language, text))

	def create_model(self):
		"""Build the full language model by merging all samples."""
		return Sample.merge(self.samples)

	def generate_tests(self):
		"""Yield (text, model) pairs for leave-one-out validation.

		Each text is paired with a model merged from all the OTHER samples,
		so the model has never seen the text it will be tested on.
		"""
		for (text, sample) in zip(self.texts, self.samples):
			others = [s for s in self.samples if s is not sample]
			yield (text, Sample.merge(others))
 

	
 

	
 
def cross_validate(sample_sets):
	"""Run leave-one-out cross-validation over all languages.

	For every training text, a model built without that text competes
	against the full models of all other languages on random snippets
	of the lengths listed in TEST_LENS.

	:param sample_sets: list of SampleSet, one per language
	:return: (accuracy, (correct, total)) tuple
	"""
	models = [s.create_model() for s in sample_sets]
	hits = 0
	attempts = 0

	for sample_set in sample_sets:
		for (test_text, partial_model) in sample_set.generate_tests():
			partial_model.print_overview()
			real_lang = partial_model.language
			# the tested language is represented by its partial model,
			# every other language by its full model
			candidates = [partial_model] + [m for m in models if m.language != real_lang]

			for snippet_len in TEST_LENS:
				start = random.randrange(0, len(test_text)-snippet_len)
				snippet = test_text[start:start+snippet_len]
				predicted_lang = identify(snippet, candidates)
				print(real_lang, predicted_lang, snippet)
				if predicted_lang == real_lang:
					hits += 1
				attempts += 1

	return hits / attempts, (hits, attempts)
 

	
 

	
 
def identify(text, models):
	"""Return the language of the model most similar to the given text.

	:param text: preprocessed text snippet
	:param models: list of Sample language models
	:return: language code of the best-matching (lowest-scoring) model
	"""
	sample = Sample(text=text)

	scored = [(model.compare(sample), model.language) for model in models]
	return min(scored)[1]
 

	
 

	
 
# training data live in per-language subdirectories of ./data next to this file;
# NOTE(review): os.scandir runs at import time and raises if the directory is missing
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
LANG_DIRS = [x.path for x in os.scandir(DATA_DIR)]
 

	
 
if __name__ == "__main__":
	# build one SampleSet per language directory, then cross-validate
	sample_sets = []

	for lang_dir in LANG_DIRS:
		language = os.path.basename(lang_dir)
		lang_set = SampleSet(language)
		sample_sets.append(lang_set)

		for entry in os.scandir(lang_dir):
			with open(entry) as f:
				content = preprocess(f.read())
				print(f"{entry.name} ({len(content)})")
				print(content[:256])
				print()

				lang_set.add(content)

	print(cross_validate(sample_sets))
0 comments (0 inline, 0 general)