Commit 17962686 authored by Alexandre's avatar Alexandre
Browse files

Pass option values as function parameter

parent d86512a4
......@@ -253,29 +253,29 @@ def _validate_hostname(hostname, cert):
return True
return False
def get_addrfamily(addr, dot=False):
def get_addrfamily(addr, forceIPv4=False, forceIPv6=False):
"""Return the family as a socket object of the address."""
(is_ip, family) = is_valid_ip_address(addr)
# thoses checks between the IP family and the command line option
# might need to land somewhere else
if opts.forceIPv4 and family == 6:
if forceIPv4 and family == 6:
raise CustomException("You cannot force IPv4 with a litteral IPv6 address (%s)" % addr)
elif opts.forceIPv6 and family == 4:
elif forceIPv6 and family == 4:
raise CustomException("You cannot force IPv6 with a litteral IPv4 address (%s)" % addr)
if opts.forceIPv4 or family == 4:
if forceIPv4 or family == 4:
family = socket.AF_INET
elif opts.forceIPv6 or family == 6:
elif forceIPv6 or family == 6:
family = socket.AF_INET6
else:
family = 0
return family
def check_ip_address(addr, dot=False):
return get_addrfamily(addr, dot)
def check_ip_address(addr, forceIPv4=False, forceIPv6=False):
return get_addrfamily(addr, forceIPv4, forceIPv6)
def dump_data(data, text="data"):
pref = ' ' * (len(text) - 4)
......@@ -389,8 +389,9 @@ class RequestDoH(Request):
class Connection:
def __init__(self, server, servername=None, connect_to=None, forceIPv4=False, forceIPv6=False,
dot=False, verbose=False, debug=False, insecure=False):
def __init__(self, server, servername=None, connect_to=None,
forceIPv4=False, forceIPv6=False, insecure=False,
verbose=False, debug=False, dot=False):
if dot and not is_valid_hostname(server):
raise CustomException("DoT requires a host name or IP address, not \"%s\"" % server)
......@@ -425,14 +426,17 @@ class Connection:
class ConnectionDoT(Connection):
def __init__(self, server, servername=None, connect_to=None, forceIPv4=False, forceIPv6=False,
sni=True, pipelining=False, verbose=False, debug=False, insecure=False):
def __init__(self, server, servername=None, connect_to=None,
forceIPv4=False, forceIPv6=False, insecure=False,
verbose=False, debug=False,
sni=True, key=None, pipelining=False):
Connection.__init__(self, server, servername=servername, connect_to=connect_to,
forceIPv4=forceIPv4, forceIPv6=forceIPv6, dot=True,
verbose=verbose, debug=debug, insecure=insecure)
forceIPv4=forceIPv4, forceIPv6=forceIPv6, insecure=insecure,
verbose=verbose, debug=debug, dot=True)
self.sni = sni
self.key = key
self.pipelining = pipelining
if self.pipelining:
self.all_requests = [] # Currently, we load everything in memory
......@@ -464,7 +468,7 @@ class ConnectionDoT(Connection):
else:
addr = self.server # otherwise keep the server name
family = get_addrfamily(addr, dot=True)
family = get_addrfamily(addr, forceIPv4=self.forceIPv4, forceIPv6=self.forceIPv6)
addrinfo_list = socket.getaddrinfo(addr, 853, family)
addrinfo_set = { (addrinfo[4], addrinfo[0]) for addrinfo in addrinfo_list }
......@@ -546,7 +550,7 @@ class ConnectionDoT(Connection):
# RFC 7858, section 4.2 and appendix A
self.cert = self.session.get_peer_certificate()
self.publickey = self.cert.get_pubkey()
if self.debug or opts.key is not None:
if self.debug or self.key is not None:
self.hasher.update(OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1,
self.publickey))
self.digest = self.hasher.digest()
......@@ -558,14 +562,14 @@ class ConnectionDoT(Connection):
self.cert.get_issuer().commonName))
print("Public key is pin-sha256=\"%s\"" % key_string)
if not self.insecure:
if opts.key is None:
if self.key is None:
valid = _validate_hostname(self.check_name_cert, self.cert)
if not valid:
error("Certificate error: \"%s\" is not in the certificate" % (self.check_name_cert), exit=False)
return False
else:
if key_string != opts.key:
error("Key error: expected \"%s\", got \"%s\"" % (opts.key, key_string), exit=False)
if key_string != self.key:
error("Key error: expected \"%s\", got \"%s\"" % (self.key, key_string), exit=False)
return False
# restore the timer
......@@ -628,14 +632,14 @@ class ConnectionDoT(Connection):
self.pipelining_fill_pending(i)
return i
def read_result(self, connection, requests):
def read_result(self, connection, requests, display_results=True):
rcode, response, size = self.receive_data() # TODO can raise
# OpenSSL.SSL.ZeroReturnError
# if the
# conenction was
# closed
if not rcode:
if opts.display_results:
if display_results:
print("TIMEOUT")
return None
id = response.id
......@@ -644,7 +648,7 @@ class ConnectionDoT(Connection):
over, rank, request = requests[id]
self.all_requests[rank]['response'] = (rcode, response, size)
requests[id] = (True, rank, request)
if opts.display_results:
if display_results:
print()
print(response)
# TODO a timeout if some responses are lost?
......@@ -701,18 +705,23 @@ def create_handle(connection):
class ConnectionDoH(Connection):
def __init__(self, server, servername=None, connect_to=None, forceIPv4=False, forceIPv6=False,
multistreams=False, verbose=False, debug=False, insecure=False):
def __init__(self, server, servername=None, connect_to=None,
forceIPv4=False, forceIPv6=False,
insecure=False, verbose=False, debug=False,
multistreams=False):
Connection.__init__(self, server, servername=servername, connect_to=connect_to,
forceIPv4=forceIPv4, forceIPv6=forceIPv6, dot=False,
verbose=verbose, debug=debug, insecure=insecure)
forceIPv4=forceIPv4, forceIPv6=forceIPv6, insecure=insecure,
verbose=verbose, debug=debug, dot=False)
self.url = server
self.multistreams = multistreams
# temporary tweak to check that the ip family is coherent with
# user choice on forced IP
if self.connect_to:
check_ip_address(self.connect_to)
check_ip_address(self.connect_to, forceIPv4=self.forceIPv4, forceIPv6=self.forceIPv6)
if self.multistreams:
self.multi = self.create_multi()
......@@ -739,7 +748,7 @@ class ConnectionDoH(Connection):
except (OpenSSL.SSL.Error, CustomException) as e:
ok = False
error(e)
self.perform_multi(silent=True)
self.perform_multi(silent=True, display_results=False, show_time=False)
self.all_handles = []
self.finished = { 'http': {} }
......@@ -757,7 +766,7 @@ class ConnectionDoH(Connection):
h.close()
self.multi.remove_handle(h)
def perform_multi(self, silent=False):
def perform_multi(self, silent=False, display_results=True, show_time=False):
while 1:
ret, num_handles = self.multi.perform()
if ret != pycurl.E_CALL_MULTI_PERFORM:
......@@ -770,12 +779,12 @@ class ConnectionDoH(Connection):
ret, num_handles = self.multi.perform()
n, handle_pass, handle_fail = self.multi.info_read()
for handle in handle_pass:
self.read_result_handle(handle, silent=silent)
self.read_result_handle(handle, silent=silent, display_results=display_results, show_time=show_time)
if ret != pycurl.E_CALL_MULTI_PERFORM:
break
n, handle_pass, handle_fail = self.multi.info_read()
for handle in handle_pass:
self.read_result_handle(handle, silent=silent)
self.read_result_handle(handle, silent=silent, display_results=display_results, show_time=show_time)
def send(self, handle):
handle.buffer = io.BytesIO()
......@@ -806,25 +815,25 @@ class ConnectionDoH(Connection):
self.send(handle)
self.receive(handle)
def read_result_handle(self, handle, silent=False):
def read_result_handle(self, handle, silent=False, display_results=True, show_time=False):
self.receive(handle)
handle.request.check_response()
if not silent and opts.show_time:
if not silent and show_time:
self.print_time(handle)
try:
self.finished['http'][handle.request.rcode] += 1
except KeyError:
self.finished['http'][handle.request.rcode] = 1
if not silent and opts.display_results:
if not silent and display_results:
print("Return code %s (%.2f ms):" % (handle.request.rcode,
(handle.time - handle.pretime) * 1000))
print(f"{handle.request.response}\n")
handle.close()
self.multi.remove_handle(handle)
def read_results(self):
def read_results(self, display_results=True, show_time=False):
for handle in self.all_handles:
self.read_result_handle(handle)
self.read_result_handle(handle, display_results=display_results, show_time=show_time)
def print_time(self, handle):
print(f'{handle.request.i:3d}', end=' ')
......@@ -1329,7 +1338,7 @@ def run_default(name, connection, opts):
else: # We do pipelining
connection.pipelining_add_request(request)
if opts.multistreams:
connection.perform_multi()
connection.perform_multi(opts.show_time, display_results=opts.display_results)
if opts.dot and opts.pipelining:
print("")
done = 0
......@@ -1344,7 +1353,7 @@ def run_default(name, connection, opts):
print("Elapsed time too long, %i requests never got a reply" % (opts.tests-done))
ok = False
break
id = connection.read_result(connection, connection.pending)
id = connection.read_result(connection, connection.pending, display_results=opts.display_results)
if id is None: # Probably a timeout
time.sleep(SLEEP_TIMEOUT)
continue
......@@ -1410,13 +1419,15 @@ for ip in ip_set:
print("Checking \"%s\" on %s ..." % (url, ip))
try:
if opts.dot:
conn = ConnectionDoT(url, servername=extracheck, connect_to=ip, verbose=opts.verbose,
debug=opts.debug, forceIPv4=opts.forceIPv4, forceIPv6=opts.forceIPv6,
sni=opts.sni, pipelining=opts.pipelining, insecure=opts.insecure)
conn = ConnectionDoT(url, servername=extracheck, connect_to=ip,
forceIPv4=opts.forceIPv4, forceIPv6=opts.forceIPv6,
insecure=opts.insecure, verbose=opts.verbose, debug=opts.debug,
sni=opts.sni, key=opts.key, pipelining=opts.pipelining)
else:
conn = ConnectionDoH(url, servername=extracheck, connect_to=ip, verbose=opts.verbose,
debug=opts.debug, forceIPv4=opts.forceIPv4, forceIPv6=opts.forceIPv6,
multistreams=opts.multistreams, insecure=opts.insecure)
conn = ConnectionDoH(url, servername=extracheck, connect_to=ip,
forceIPv4=opts.forceIPv4, forceIPv6=opts.forceIPv6,
insecure=opts.insecure, verbose=opts.verbose, debug=opts.debug,
multistreams=opts.multistreams)
except TimeoutError:
error("timeout")
except ConnectionRefusedError:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment