diff --git a/src/client.py b/src/client.py --- a/src/client.py +++ b/src/client.py @@ -8,11 +8,12 @@ import config as conf import stats from util import Progress from hashtree import HashTree,hashBlock -from networkers import NetworkReader,NetworkWriter +from netnode import BaseConnection,NetNode -class Connection: +class Connection(BaseConnection): def __init__(self): + super().__init__() sock=socket.socket(socket.AF_INET, socket.SOCK_STREAM) sslContext=ssl.create_default_context(cafile=conf.peers) @@ -21,34 +22,14 @@ class Connection: self._socket=sslContext.wrap_socket(sock) self._socket.connect((conf.hosts[0], conf.port)) - fr=self._socket.makefile(mode="rb") - fw=self._socket.makefile(mode="wb") - self.incoming=NetworkReader(fr) - self.outcoming=NetworkWriter(fw) - - def __enter__(self): - return self.incoming,self.outcoming - - def __exit__(self, exc_type, exc_val, exc_tb): - self._socket.shutdown(socket.SHUT_RDWR) - self._socket.close() + self.createNetworkers() -class Client: +class Client(NetNode): def __init__(self,filename,treeFile=""): - self._incoming=None - self._outcoming=None - self._filename=filename - self._treeFile=treeFile - print(datetime.now(), "initializing...") - if treeFile: - self._tree=HashTree.load(treeFile) - else: - self._tree=HashTree.fromFile(filename) - - self._newLeaves=dict() + super().__init__(filename,treeFile) ## Asks server for node hashes to determine which are to be transferred. # @@ -157,10 +138,7 @@ class Client: dataFile.close() if self._treeFile: - log.info("updating hash tree...") - for (k,v) in self._newLeaves.items(): - self._tree.updateLeaf(k, v) - self._tree.save(self._treeFile) + self._updateTree() def setConnection(self,connection): (self._incoming,self._outcoming)=connection diff --git a/src/netnode.py b/src/netnode.py new file mode 100644 --- /dev/null +++ b/src/netnode.py @@ -0,0 +1,48 @@ +import socket +import logging as log + +from networkers import NetworkReader,NetworkWriter +from hashtree import HashTree + + +class BaseConnection: # abstract + def __init__(self): + self._socket=None + self.incoming=None + self.outcoming=None + + def createNetworkers(self): + fr=self._socket.makefile(mode="rb") + fw=self._socket.makefile(mode="wb") + + self.incoming=NetworkReader(fr) + self.outcoming=NetworkWriter(fw) + + def __enter__(self): + return self.incoming,self.outcoming + + def __exit__(self, exc_type, exc_val, exc_tb): + self._socket.shutdown(socket.SHUT_RDWR) + self._socket.close() + + +class NetNode: + def __init__(self,filename,treeFile=""): + self._incoming=None + self._outcoming=None + + self._filename=filename + self._treeFile=treeFile + + if treeFile: + self._tree=HashTree.load(treeFile) + else: + self._tree=HashTree.fromFile(filename) + + self._newLeaves=dict() + + def _updateTree(self): + log.info("updating hash tree...") + for (k,v) in self._newLeaves.items(): + self._tree.updateLeaf(k, v) + self._tree.save(self._treeFile) diff --git a/src/server.py b/src/server.py --- a/src/server.py +++ b/src/server.py @@ -3,29 +3,20 @@ import ssl import multiprocessing import logging as log -from hashtree import HashTree,hashBlock -from networkers import NetworkReader,NetworkWriter +from hashtree import hashBlock +from netnode import BaseConnection,NetNode import config as conf -class Connection: +class Connection(BaseConnection): def __init__(self,serverSocket,sslContext): + super().__init__() + sock, address = serverSocket.accept() self._socket=sslContext.wrap_socket(sock,server_side=True) log.info('Connected by {0}'.format(address)) - fr=self._socket.makefile(mode="rb") - fw=self._socket.makefile(mode="wb") - - self.incoming=NetworkReader(fr) - self.outcoming=NetworkWriter(fw) - - def __enter__(self): - return self.incoming,self.outcoming - - def __exit__(self, exc_type, exc_val, exc_tb): - self._socket.shutdown(socket.SHUT_RDWR) - self._socket.close() + self.createNetworkers() class Miniserver: @@ -49,18 +40,11 @@ class Miniserver: p.join() -class Server: +class Server(NetNode): def __init__(self,connection,filename,treeFile=""): + super().__init__(filename,treeFile) (self._incoming,self._outcoming)=connection - self._filename=filename - self._treeFile=treeFile - if treeFile: - self._tree=HashTree.load(treeFile) - else: - self._tree=HashTree.fromFile(filename) - - self._newLeaves=dict() self.BLOCK_SIZE=self._tree.BLOCK_SIZE self._lastIndex=-1 @@ -148,8 +132,5 @@ class Server: self._dataFile.close() self._dataFileHandle=None if self._treeFile: - log.info("updating hash tree...") - for (k,v) in self._newLeaves.items(): - self._tree.updateLeaf(k, v) - self._tree.save(self._treeFile) + self._updateTree() log.info("done")