Commit e11fddc7 authored by Alexandre's avatar Alexandre
Browse files

[DoT] Refactor Connection class

parent 22ca31ec
......@@ -428,36 +428,15 @@ class Connection:
class ConnectionDoT(Connection):
def __init__(self, server, servername=None, connect=None, forceIPv4=False, forceIPv6=False,
pipelining=False, verbose=False, debug=False, insecure=False):
Connection.__init__(self, server, servername=servername, connect_to=connect,
forceIPv4=forceIPv4, forceIPv6=forceIPv6, dot=True,
verbose=verbose, debug=debug, insecure=insecure)
if connect is not None:
addr = connect
else:
addr = self.server
self.pipelining = pipelining
self.family = check_ip_address(self.server, dot=True)
addrinfo_list = socket.getaddrinfo(addr, 853, self.family)
addrinfo_set = { (addrinfo[4], addrinfo[0]) for addrinfo in addrinfo_list }
signal.signal(signal.SIGALRM, timeout_connection)
self.success = False
for addrinfo in addrinfo_set:
self.hasher = hashlib.sha256()
if self.connect(addrinfo[0], addrinfo[1]):
self.success = True
break
if self.verbose and connect is None:
print("Trying another IP address")
if not self.success:
if self.verbose and connect is None:
print("No other IP address")
if connect is None:
error(f'Could not connect to "{server}"')
else:
print(f'Could not connect to "{server}" on {connect}')
if self.pipelining:
self.all_requests = [] # Currently, we load everything in memory
# since we want to keep everything,
......@@ -468,9 +447,60 @@ class ConnectionDoT(Connection):
self.pending = {} # pending is indexed by the query ID, and its
# maximum size is max_in_flight.
def connect(self, addr, sock_family):
# establish the connection
self.connect()
def connect(self):
# if connect_to is defined, it means we know the IP address of the
# server and therefore we can establish a connection with it
# otherwise we only have a domain name and we should loop on all
# resolved IPs until a connection can be established
# getaddrinfo provides a list of resolved IPs, when connect_to is
# defined this list will have only one element
# so we can loop on the items until a connection is made
# the list is converted into a set of tuples to avoid duplicates
self.success = False
if self.connect_to is not None: # the server's IP address is known
addr = self.connect_to
else:
addr = self.server # otherwise keep the server name
family = get_addrfamily(addr, dot=True)
addrinfo_list = socket.getaddrinfo(addr, 853, family)
addrinfo_set = { (addrinfo[4], addrinfo[0]) for addrinfo in addrinfo_list }
signal.signal(signal.SIGALRM, timeout_connection)
for addrinfo in addrinfo_set:
if self.establish_session(addrinfo[0], addrinfo[1]):
self.success = True
break
if self.verbose and self.connect_to is None:
print("Trying another IP address")
# we could not establish a connection
if not self.success:
# we tried all the resolved IPs
if self.verbose and self.connect_to is None:
print("No other IP address")
error = "Could not connect to \"%s\"" % self.server
if self.connect_to is not None:
error += " on %s" % self.connect_to
raise CustomException(error)
def establish_session(self, addr, sock_family):
"""Return True if a TLS session is established."""
self.hasher = hashlib.sha256()
# start the timer
signal.alarm(TIMEOUT_CONN)
self.sock = socket.socket(sock_family, socket.SOCK_STREAM)
if self.verbose:
print("Connecting to %s ..." % str(addr))
# With typical DoT servers, we *must* use TLS 1.2 (otherwise,
......@@ -486,8 +516,9 @@ class ConnectionDoT(Connection):
OpenSSL.SSL.VERIFY_CLIENT_ONCE,
lambda conn, cert, errno, depth, preverify_ok: preverify_ok)
self.session = OpenSSL.SSL.Connection(self.context, self.sock)
if opts.sni:
if opts.sni: # remove global variable call
self.session.set_tlsext_host_name(canonicalize(self.check_name_cert).encode())
try:
self.session.connect((addr))
self.session.do_handshake()
......@@ -501,7 +532,7 @@ class ConnectionDoT(Connection):
return False
except OpenSSL.SSL.SysCallError as e:
if self.verbose:
error(f"OpenSSL error: {e.args[1]}", exit=False)
error("OpenSSL error: %s" % e.args[1], exit=False)
return False
except OpenSSL.SSL.ZeroReturnError:
# see #18
......@@ -510,8 +541,9 @@ class ConnectionDoT(Connection):
return False
except OpenSSL.SSL.Error as e:
if self.verbose:
error(f"OpenSSL error: {', '.join(err[0][2] for err in e.args)}", exit=False)
error("OpenSSL error: %s" % ', '.join(err[0][2] for err in e.args), exit=False)
return False
# RFC 7858, section 4.2 and appendix A
self.cert = self.session.get_peer_certificate()
self.publickey = self.cert.get_pubkey()
......@@ -525,8 +557,7 @@ class ConnectionDoT(Connection):
(self.cert.get_serial_number(),
self.cert.get_subject().commonName,
self.cert.get_issuer().commonName))
print("Public key is pin-sha256=\"%s\"" % \
key_string)
print("Public key is pin-sha256=\"%s\"" % key_string)
if not self.insecure:
if opts.key is None:
valid = _validate_hostname(self.check_name_cert, self.cert)
......@@ -537,7 +568,10 @@ class ConnectionDoT(Connection):
if key_string != opts.key:
error("Key error: expected \"%s\", got \"%s\"" % (opts.key, key_string), exit=False)
return False
# restore the timer
signal.alarm(0)
# and start a new timer when pipelining requests
if self.pipelining:
self.sock.settimeout(TIMEOUT_READ)
return True
......
......@@ -319,7 +319,7 @@ tests:
- '--dot'
- 'brok.sources.org'
- 'in'
partstdout: 'Could not connect to'
partstderr: 'Could not connect to'
- exe: './homer.py'
name: '[dot][check] Test all IPs on brok.sources.org, get a KO'
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment