Changeset - ba1303bfd58c
[Not reviewed]
default
0 1 1
Laman - 23 months ago 2023-05-06 17:19:39

made the model generation deterministic
2 files changed with 3 insertions and 2 deletions:
0 comments (0 inline, 0 general)
src/languedoc/models.json.gz
Show inline comments
 
new file 100644
 
binary diff not shown
src/languedoc/train.py
Show inline comments
 
@@ -68,59 +68,60 @@ def cross_validate(sample_sets: list[Sam
 

	
 
	:param sample_sets: sample sets of all target languages
 
	:return: ratio of correctly predicted samples, its absolute number and the theoretical maximum"""
 
	models = [s.create_model() for s in sample_sets]
 
	score = 0
 
	max_score = 0
 

	
 
	for s in sample_sets:
 
		for (test_text, partial_model) in s.generate_tests(CROSSVALIDATION_SOURCE_COUNT):
 
			real_lang = partial_model.language
 
			test_models = [partial_model] + [m for m in models if m.language != real_lang]
 

	
 
			for k in TEST_LENS:
 
				for i in range(10):
 
					j = random.randrange(0, len(test_text)-k)
 
					t = test_text[j:j+k]
 
					predicted_lang = identify(t, test_models)
 
					if predicted_lang == real_lang:
 
						score += 1
 
					else:
 
						print(real_lang, predicted_lang, t)
 
					max_score += 1
 

	
 
	return score/max_score, score, max_score
 

	
 

	
 
def train(data_dir: str, model_path: str):
	"""Run the training and create a prediction model.

	:param data_dir: path to the data directory, with one subdirectory for each language
		containing several text files as separate sources.
	:param model_path: where to save the result language model as a .json.gz"""
	samples = []
	# sort the language directories so the model order is deterministic,
	# independent of the filesystem's directory-listing order
	lang_dirs = sorted([x.path for x in os.scandir(data_dir)])

	for d in lang_dirs:
		lang = os.path.basename(d)
		lang_samples = SampleSet(lang)
		samples.append(lang_samples)

		# sort the source files as well, so samples are added in a stable order
		for file in sorted(os.scandir(d), key=lambda f: f.name):
			with open(file, encoding="utf-8") as f:
				text = preprocess(f.read())
				print(f"{lang}: {file.name} ({len(text)})")
				lang_samples.add(text)

	# GzipFile with mtime=0 omits the timestamp from the gzip header, and
	# sort_keys=True fixes the JSON key order, so the output file is
	# byte-for-byte reproducible across runs
	with gzip.GzipFile(model_path, mode="wb", mtime=0) as f:
		s = json.dumps(
			[sample_set.create_model().export() for sample_set in samples],
			ensure_ascii=False, sort_keys=True
		)
		f.write(s.encode("utf-8"))

	print(cross_validate(samples))
 

	
 

	
 
# resolve sibling paths relative to this module's directory
_MODULE_DIR = os.path.dirname(__file__)
DATA_DIR = os.path.join(_MODULE_DIR, "../../data")
MODEL_PATH = os.path.join(_MODULE_DIR, "models.json.gz")

if __name__ == "__main__":
	train(DATA_DIR, MODEL_PATH)
0 comments (0 inline, 0 general)