diff --git a/src/server.py b/src/server.py --- a/src/server.py +++ b/src/server.py @@ -1,5 +1,6 @@ import hashlib import socket +import ssl import logging as log from hashtree import HashTree @@ -8,8 +9,10 @@ import config as conf class Connection: - def __init__(self,serverSocket): - self._socket, address = serverSocket.accept() + def __init__(self,serverSocket,sslContext): + sock, address = serverSocket.accept() + self._socket=sslContext.wrap_socket(sock,server_side=True) + log.info('Connected by {0}'.format(address)) fr=self._socket.makefile(mode="rb") fw=self._socket.makefile(mode="wb") @@ -21,6 +24,7 @@ class Connection: return self.incoming,self.outcoming def __exit__(self, exc_type, exc_val, exc_tb): + self._socket.shutdown(socket.SHUT_RDWR) self._socket.close() @@ -37,6 +41,10 @@ class Server: self._newLeaves=dict() self.BLOCK_SIZE=self._tree.BLOCK_SIZE + self._ssl=ssl.create_default_context(ssl.Purpose.CLIENT_AUTH,cafile=conf.peers) + self._ssl.verify_mode=ssl.CERT_REQUIRED + self._ssl.load_cert_chain(conf.certfile,conf.keyfile) + self._ss = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._ss.bind(("", conf.port)) self._ss.listen(1) @@ -46,7 +54,7 @@ class Server: def serve(self): while True: - with Connection(self._ss) as (incoming, outcoming): + with Connection(self._ss,self._ssl) as (incoming, outcoming): try: while True: if not self._serveOne(incoming,outcoming): return