# HG changeset patch
# User Laman
# Date 2022-10-11 22:33:33
# Node ID f8fe5a65e7fc75333b0309b149a1bbb0c4c3e467
# Parent  a27c2661a846ca8df221226875e9d2497d3f6116

Add model saving and loading (gzip-compressed JSON)

diff --git a/languedoc.py b/languedoc.py
--- a/languedoc.py
+++ b/languedoc.py
@@ -1,6 +1,8 @@
 import os
 import random
 import itertools
+import json
+import gzip
 
 from shared import preprocess, identify, extract_ngram_freqs, rank_ngram_freqs, Sample
 
@@ -72,6 +74,7 @@ def cross_validate(sample_sets):
 
 DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
 LANG_DIRS = sorted([x.path for x in os.scandir(DATA_DIR)])
+MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")
 
 if __name__ == "__main__":
 	samples = []
@@ -89,4 +92,7 @@ if __name__ == "__main__":
 
 				lang_samples.add(text)
 
+	with gzip.open(MODEL_PATH, mode="wt", encoding="utf-8") as f:
+		json.dump([sample_set.create_model().export() for sample_set in samples], f, ensure_ascii=False)
+
 	print(cross_validate(samples))
diff --git a/shared.py b/shared.py
--- a/shared.py
+++ b/shared.py
@@ -1,7 +1,11 @@
+import os
 import re
 import itertools
+import json
+import gzip
 
 TOP_NGRAM_COUNT = 3000
+MODEL_PATH = os.path.join(os.path.dirname(__file__), "models.json.gz")
 
 
 def preprocess(text):
@@ -65,10 +69,6 @@ class Sample:
 		}
 
 	def compare(self, other):
-		"""take k most common
-		use frequencies x order
-		use letter, digrams, trigrams
-		use absolute x square"""
 		m = len(other.ranked_ngrams)
 
 		res = sum(
@@ -79,7 +79,26 @@ class Sample:
 		return res
 
 
-def identify(text, models):
+def load_models(model_path):
+	"""Load language models from a gzip-compressed JSON file.
+
+	:param model_path: path of the gzipped, UTF-8 encoded JSON file holding a list of serialized models
+	:return: list with one object built by Sample.load per entry of the JSON list"""
+	with gzip.open(model_path, mode="rt", encoding="utf-8") as f:
+		return [Sample.load(obj) for obj in json.load(f)]
+
+
+def identify(text, models=None):
+	"""Return the language of the model that best matches the text.
+
+	:param text: the text to identify
+	:param models: models to compare against; when None or empty, they are loaded from MODEL_PATH
+	:return: the language attribute of the model with the lowest Sample.compare score"""
+	# None instead of a mutable [] default, which would be one list shared by all calls
+	if not models:
+		models = load_models(MODEL_PATH)
+
 	sample = Sample.extract(text)
 
+	# ascending sort by compare() score; the first model has the lowest score
 	return sorted(models, key=lambda m: sample.compare(m))[0].language