diff --git a/src/hashtree.py b/src/hashtree.py
--- a/src/hashtree.py
+++ b/src/hashtree.py
@@ -1,4 +1,5 @@
-import hashlib
+import collections
+import hashlib
 import os
 from datetime import datetime
 
@@ -85,3 +86,26 @@ class HashTree:
 			self.store[i]=hashBlock(self.store[i*2+1]+self.store[i*2+2])
 			progress.p(i)
 		progress.done()
+
+	## Update faster than repeated insertLeaf.
+	## keysHashes: iterable of (k,hash) pairs where k indexes self.store directly (the tests pass leafIndex+leafStart).
+	## NOTE(review): NetNode._updateTree now passes _newLeaves.items() here instead of to updateLeaf; confirm those keys are store indices too.
+	def batchUpdate(self,keysHashes):
+		queue=collections.deque()
+		# sorted keys yield parents in ascending order, so comparing with queue[-1] drops sibling duplicates
+		for (k,v) in sorted(keysHashes):
+			self.store[k]=v
+			if k==0: # a lone leaf is also the root; (0-1)//2==-1 would be enqueued and loop forever
+				continue
+			parentK=(k-1)//2
+			if len(queue)==0 or queue[-1]!=parentK:
+				queue.append(parentK)
+
+		# FIFO (pop right, push left): a rehashed node queues its parent unless that parent is already pending at the left end,
+		# so every ancestor's last rehash comes after the last rehash of its children
+		while len(queue)>0:
+			k=queue.pop()
+			self.store[k]=hashBlock(self.store[k*2+1]+self.store[k*2+2])
+			parentK=(k-1)//2
+			if (len(queue)==0 or queue[0]!=parentK) and k!=0:
+				queue.appendleft(parentK)
diff --git a/src/netnode.py b/src/netnode.py
--- a/src/netnode.py
+++ b/src/netnode.py
@@ -49,6 +49,5 @@ class NetNode:
 
 	def _updateTree(self):
 		log.info("updating hash tree...")
-		for (k,v) in self._newLeaves.items():
-			self._tree.updateLeaf(k, v)
+		self._tree.batchUpdate(self._newLeaves.items())
 		self._tree.save(self._treeFile)
diff --git a/src/tests/__init__.py b/src/tests/__init__.py
--- a/src/tests/__init__.py
+++ b/src/tests/__init__.py
@@ -0,0 +1,19 @@
+import sys
+
+
+## Test mixin: redirects sys.stdout to /tmp/morevna-stdout.log (append mode) for the duration of a test class.
+## List it before unittest.TestCase in the bases so the super() calls below reach TestCase's class fixtures.
+class RedirectedOutput():
+	_stdout=None # the sys.stdout that was active before setUpClass
+
+	@classmethod
+	def setUpClass(cls):
+		super().setUpClass()
+		cls._stdout=sys.stdout
+		sys.stdout=open("/tmp/morevna-stdout.log",mode="a")
+
+	@classmethod
+	def tearDownClass(cls):
+		sys.stdout.close()
+		sys.stdout=cls._stdout
+		super().tearDownClass()
diff --git a/src/tests/test_hashtree.py b/src/tests/test_hashtree.py
new file mode 100644
--- /dev/null
+++ b/src/tests/test_hashtree.py
@@ -0,0 +1,31 @@
+import random
+from unittest import TestCase
+
+from hashtree import HashTree
+from .
import RedirectedOutput
+
+
+rng=random.Random(17) # own generator: deterministic without reseeding the shared module-level random state
+
+
+def buildTree(leaves):
+	tree=HashTree(len(leaves))
+	for l in leaves:
+		tree.insertLeaf(l)
+	tree.buildTree()
+	return tree
+
+
+class TestHashTree(RedirectedOutput,TestCase):
+	def test_batchUpdate(self):
+		leaves=[b"a" for i in range(8)]
+		t1=buildTree(leaves)
+		keys=list(range(8))
+
+		for i in range(8):
+			rng.shuffle(keys)
+			for k in keys[:i+1]:
+				leaves[k]=bytes([rng.randrange(256)])
+			t2=buildTree(leaves)
+			t1.batchUpdate((k+t1.leafStart,leaves[k]) for k in keys[:i+1])
+			self.assertEqual(t1.store,t2.store)
diff --git a/src/tests/test_overall.py b/src/tests/test_overall.py
--- a/src/tests/test_overall.py
+++ b/src/tests/test_overall.py
@@ -10,6 +10,7 @@ import config
 from hashtree import HashTree
 from client import Client, Connection as ClientConnection
 from server import Miniserver
+from . import RedirectedOutput
 
 config.logger.removeHandler(config.handler)
 
@@ -31,23 +32,15 @@ def compareFiles(f1,f2):
 	return (h,h2)
 
 
-class TestMorevna(TestCase):
-	_stdout=None
-
-	@classmethod
-	def setUpClass(cls):
-		cls._stdout=sys.stdout
-		sys.stdout=open("/tmp/morevna-stdout.log",mode="a")
-
+class TestMorevna(RedirectedOutput,TestCase):
 	def setUp(self):
 		src=os.path.join(dataDir,"test1.img")
 		shutil.copyfile(src,filename)
 
 	@classmethod
 	def tearDownClass(cls):
+		super().tearDownClass()
 		os.remove(filename)
-		sys.stdout.close()
-		sys.stdout=cls._stdout
 
 	def test_build(self):
 		treeFile=os.path.join(dataDir,"test.bin")