# HG changeset patch # User Laman # Date 2017-10-02 19:53:54 # Node ID cd2ba192bf12d60cb4c58b2416d8fae3d47cbfd0 # Parent 44cf81f3b6b85d4a31bc940a0684cc3c1475f9e4 ssl connection diff --git a/src/client.py b/src/client.py --- a/src/client.py +++ b/src/client.py @@ -1,5 +1,6 @@ import collections import socket +import ssl import logging as log from datetime import datetime @@ -11,8 +12,10 @@ from networkers import NetworkReader,Net class Connection: - def __init__(self): - self._socket=socket.socket(socket.AF_INET, socket.SOCK_STREAM) + def __init__(self,sslContext): + sock=socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + self._socket=sslContext.wrap_socket(sock) self._socket.connect((conf.hosts[0], conf.port)) fr=self._socket.makefile(mode="rb") fw=self._socket.makefile(mode="wb") @@ -24,6 +27,7 @@ class Connection: return self.incoming,self.outcoming def __exit__(self, exc_type, exc_val, exc_tb): + self._socket.shutdown(socket.SHUT_RDWR) self._socket.close() @@ -31,6 +35,10 @@ class Client: def __init__(self,filename): self._filename=filename + self._ssl=ssl.create_default_context(cafile=conf.peers) + self._ssl.check_hostname=False + self._ssl.load_cert_chain(conf.certfile,conf.keyfile) + def negotiate(self): print(datetime.now(), "initializing...") localTree=HashTree.fromFile(self._filename) @@ -38,7 +46,7 @@ class Client: nodeStack=collections.deque([0]) # root # initialize session - with Connection() as (incoming,outcoming): + with Connection(self._ssl) as (incoming,outcoming): jsonData={"command":"init", "blockSize":localTree.BLOCK_SIZE, "blockCount":localTree.leafCount, "version":conf.version} outcoming.writeMsg(jsonData) jsonData,binData=incoming.readMsg() @@ -49,7 +57,7 @@ class Client: progress=Progress(localTree.leafCount) while len(nodeStack)>0: i=nodeStack.pop() - outcoming.writeMsg({"command":"req", "index":i}) + outcoming.writeMsg({"command":"req", "index":i, "dataType":"hash"}) jsonData,binData=incoming.readMsg() assert jsonData["index"]==i @@ -73,7 +81,7 @@ class Client: i1=-1 print(datetime.now(), "sending data:") - with Connection() as (incoming,outcoming): + with Connection(self._ssl) as (incoming,outcoming): progress=Progress(len(blocksToTransfer)) for (k,i2) in enumerate(blocksToTransfer): jsonData={"command":"send", "index":i2, "dataType":"data"} @@ -91,7 +99,7 @@ class Client: progress.p(k) progress.done() - with Connection() as (incoming,outcoming): + with Connection(self._ssl) as (incoming,outcoming): outcoming.writeMsg({"command":"end"}) log.info("closing session...") diff --git a/src/config.py b/src/config.py --- a/src/config.py +++ b/src/config.py @@ -1,3 +1,4 @@ +import os import datetime import logging as log @@ -11,4 +12,9 @@ log.basicConfig( version=0 hosts=["127.0.0.1"] -port=50009 +port=9001 + +directory=os.path.join(os.path.dirname(__file__),"..") +certfile=os.path.join(directory,"cert.pem") +keyfile=os.path.join(directory,"key.pem") +peers=os.path.join(directory,"peers.pem") diff --git a/src/server.py b/src/server.py --- a/src/server.py +++ b/src/server.py @@ -1,5 +1,6 @@ import hashlib import socket +import ssl import logging as log from hashtree import HashTree @@ -8,8 +9,10 @@ import config as conf class Connection: - def __init__(self,serverSocket): - self._socket, address = serverSocket.accept() + def __init__(self,serverSocket,sslContext): + sock, address = serverSocket.accept() + self._socket=sslContext.wrap_socket(sock,server_side=True) + log.info('Connected by {0}'.format(address)) fr=self._socket.makefile(mode="rb") fw=self._socket.makefile(mode="wb") @@ -21,6 +24,7 @@ class Connection: return self.incoming,self.outcoming def __exit__(self, exc_type, exc_val, exc_tb): + self._socket.shutdown(socket.SHUT_RDWR) self._socket.close() @@ -37,6 +41,10 @@ class Server: self._newLeaves=dict() self.BLOCK_SIZE=self._tree.BLOCK_SIZE + self._ssl=ssl.create_default_context(ssl.Purpose.CLIENT_AUTH,cafile=conf.peers) + self._ssl.verify_mode=ssl.CERT_REQUIRED + self._ssl.load_cert_chain(conf.certfile,conf.keyfile) + self._ss = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._ss.bind(("", conf.port)) self._ss.listen(1) @@ -46,7 +54,7 @@ class Server: def serve(self): while True: - with Connection(self._ss) as (incoming, outcoming): + with Connection(self._ss,self._ssl) as (incoming, outcoming): try: while True: if not self._serveOne(incoming,outcoming): return