Commit 7e72380f authored by Alexandre's avatar Alexandre
Browse files

Merge branch 'partitioning'

Partition code into modules
Remove --sync option

See merge request bortzmeyer/homer!23
parents 2b9408b1 d17a9aa2
......@@ -102,7 +102,6 @@ with `-v` to keep only part of the output
* `-P, --post, --POST` : use HTTP POST method for all the transfers
* `-e, --head, --HEAD` : use HTTP HEAD method for all the transfers
* `--multistreams` : use HTTP/2 streams, needs an input file with `-f`
* `--sync` : process received queries synchronously (only with --multistreams)
* `--time` : display the time elapsed for the query (only with --multistreams)
#### DoT only options
......@@ -235,10 +234,7 @@ the multi interface, see the branch [homer-perf](-/tree/homer-perf).
As soon as a response is received, it is displayed with the HTTP
return code and the elapsed time for this specific query. This output
can be suppressed with `--no-display-results`. It is possible to delay
the output of the answers after the last transfer finishes with `--sync`.
In that case the DNS responses are displayed in the same order as they
were sent.
can be suppressed with `--no-display-results`.
It is also possible to focus on the elapsed time only with the use
of the `--time` option combined with `--no-display-results`. This shows
......
This diff is collapsed.
from .utils import dump_data
from .utils import is_valid_hostname
from .utils import is_valid_ip_address
from .utils import is_valid_url
from .utils import get_addrfamily
from .utils import check_ip_address
from .utils import canonicalize
from .utils import validate_hostname
from .connection import Connection
from .connection import ConnectionDOT
from .connection import ConnectionDOH
from .request import RequestDOT
from .request import RequestDOH
from .request import create_request
from .request import create_requests_list
from .exceptions import TimeoutConnectionError
from .exceptions import ConnectionException
from .exceptions import ConnectionDOTException
from .exceptions import ConnectionDOHException
from .exceptions import FamilyException
from .exceptions import RequestException
from .exceptions import RequestDOTException
from .exceptions import PipeliningException
from .exceptions import DOHException
# Time limits. TIMEOUT_CONN is the value handed to signal.alarm() while a DoT
# session is being established; TIMEOUT_READ is the socket timeout applied to
# a DoT connection when pipelining. Both are expressed in seconds.
TIMEOUT_CONN = 2
TIMEOUT_READ = 1
# NOTE(review): the two values below are not used in the code visible here;
# presumably also durations in seconds -- confirm against their callers.
SLEEP_TIMEOUT = 0.5
MAX_DURATION = 10
# For the check option
# (identifiers for the three HTTP methods a DoH request can use)
DOH_GET = 0
DOH_POST = 1
DOH_HEAD = 2
# Is the test mandatory?
# Maps a requirement level name to a numeric weight; a higher value means a
# stricter requirement.
mandatory_levels = {"legal": 30, "necessary": 20, "nicetohave": 10}
import base64
import hashlib
import io
import signal
import socket
import sys
try:
# http://pycurl.io/docs/latest
import pycurl
# October 2019: the Python GnuTLS bindings don't work with Python 3. So we use OpenSSL.
# https://www.pyopenssl.org/
# https://pyopenssl.readthedocs.io/
import OpenSSL
# http://www.dnspython.org/
import dns.message
except ImportError as e:
print("Error: missing module")
print(e)
sys.exit(1)
import homer.utils
import homer.exceptions
class Connection:
    """Options and validation shared by the DoT and DoH connections.

    The constructor checks that *server* has the form required by the chosen
    transport (a host name for DoT, an HTTPS URL for DoH), refuses
    contradictory address-family options and then records every option on
    the instance.
    """

    def __init__(self, server, servername=None, connect_to=None,
                 forceIPv4=False, forceIPv6=False, insecure=False,
                 verbose=False, debug=False, dot=False):
        """Validate the arguments and store them as attributes.

        Raises ConnectionDOTException / ConnectionDOHException when *server*
        is not acceptable for the transport, and ConnectionException when
        both forceIPv4 and forceIPv6 are requested.
        """
        if dot:
            if not homer.is_valid_hostname(server):
                raise homer.ConnectionDOTException("DoT requires a host name or IP address, not \"%s\"" % server)
        elif not homer.is_valid_url(server):
            raise homer.ConnectionDOHException("DoH requires a valid HTTPS URL, not \"%s\"" % server)
        if forceIPv4 and forceIPv6:
            raise homer.ConnectionException("Force IPv4 *or* IPv6 but not both")

        self.dot = dot
        self.server = server
        self.servername = servername
        # Name looked up in the certificate: the explicit server name when
        # one was given, the server designation otherwise.
        self.check_name_cert = servername if servername is not None else server
        self.verbose = verbose
        self.debug = debug
        self.insecure = insecure
        self.forceIPv4 = forceIPv4
        self.forceIPv6 = forceIPv6
        self.connect_to = connect_to

    def __str__(self):
        """Return the server designation given to the constructor."""
        return self.server

    def do_test(self, request):
        """Routine doing one actual test. Returns nothing.

        This base implementation is a placeholder that does nothing.
        """
        return None
class ConnectionDOT(Connection):
def __init__(self, server, servername=None, connect_to=None,
             forceIPv4=False, forceIPv6=False, insecure=False,
             verbose=False, debug=False,
             sni=True, key=None, pipelining=False):
    """Record the DoT-specific options, then open the connection.

    The generic options are handled by the parent constructor (called with
    dot=True). The constructor ends by calling connect().
    """
    super().__init__(server, servername=servername, connect_to=connect_to,
                     forceIPv4=forceIPv4, forceIPv6=forceIPv6, insecure=insecure,
                     verbose=verbose, debug=debug, dot=True)
    self.sni, self.key, self.pipelining = sni, key, pipelining
    if pipelining:
        # Currently, we load everything in memory since we want to keep
        # everything, anyway. Maybe in the future, if we don't want to keep
        # individual results, we'll use an iterator to fill a smaller table.
        # all_requests is indexed by its rank in the input file.
        self.all_requests = []
        # pending is indexed by the query ID, and its maximum size is
        # max_in_flight.
        self.pending = {}
    # establish the connection
    self.connect()
def connect(self):
    """Open a TLS session with the server, trying each candidate address.

    If connect_to is defined, the IP address of the server is known and is
    the only candidate. Otherwise the server name is resolved and every
    resolved address is tried in turn until a session is established.

    Sets self.success and raises ConnectionDOTException when no address
    could be reached.
    """
    self.success = False
    # the server's IP address when known, otherwise the server name
    addr = self.connect_to if self.connect_to is not None else self.server
    family = homer.get_addrfamily(addr, forceIPv4=self.forceIPv4, forceIPv6=self.forceIPv6)
    addrinfo_list = socket.getaddrinfo(addr, 853, family)
    # getaddrinfo() may return the same address several times. Remove the
    # duplicates with dict.fromkeys() rather than a set, so that the
    # addresses are still tried in the order chosen by the resolver.
    candidates = list(dict.fromkeys((info[4], info[0]) for info in addrinfo_list))
    signal.signal(signal.SIGALRM, homer.exceptions.timeout_connection)
    for sockaddr, sock_family in candidates:
        if self.establish_session(sockaddr, sock_family):
            self.success = True
            break
        if self.verbose and self.connect_to is None:
            print("Could not connect to %s" % sockaddr[0])
            print("Trying another IP address")
    if self.success:
        return
    # we could not establish a connection: we tried all the resolved IPs
    if self.verbose and self.connect_to is None:
        print("No other IP address")
    error = "Could not connect to \"%s\"" % self.server
    if self.connect_to is not None:
        error += " on %s" % self.connect_to
    raise homer.ConnectionDOTException(error)
def establish_session(self, addr, sock_family):
    """Return True if a TLS session is established.

    addr is the socket address to connect to and sock_family the matching
    address family. On success self.sock, self.context, self.session,
    self.cert and self.publickey describe the session. On failure the
    socket is closed and False is returned.

    The SIGALRM timer armed here is always cancelled before returning, so
    that a failed attempt cannot leave a pending alarm behind.
    """
    self.hasher = hashlib.sha256()
    # start the timer
    signal.alarm(homer.TIMEOUT_CONN)
    sock = None
    established = False
    try:
        sock = self.sock = socket.socket(sock_family, socket.SOCK_STREAM)
        if self.verbose:
            print("Connecting to %s ..." % addr[0])
        # With typical DoT servers, we *must* use TLS 1.2 (otherwise,
        # do_handshake fails with "OpenSSL.SSL.SysCallError: (-1, 'Unexpected
        # EOF')" Typical HTTP servers are more lax.
        self.context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_2_METHOD)
        if self.insecure:
            self.context.set_verify(OpenSSL.SSL.VERIFY_NONE, lambda *x: True)
        else:
            self.context.set_default_verify_paths()
            self.context.set_verify_depth(4)  # Seems ignored
            self.context.set_verify(OpenSSL.SSL.VERIFY_PEER | OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT |
                                    OpenSSL.SSL.VERIFY_CLIENT_ONCE,
                                    lambda conn, cert, errno, depth, preverify_ok: preverify_ok)
        self.session = OpenSSL.SSL.Connection(self.context, self.sock)
        if self.sni:
            self.session.set_tlsext_host_name(homer.canonicalize(self.check_name_cert).encode())
        try:
            self.session.connect((addr))
            self.session.do_handshake()
        except homer.exceptions.TimeoutConnectionError:
            if self.verbose:
                print("Timeout")
            return False
        except OSError:
            if self.verbose:
                print("Cannot connect")
            return False
        except OpenSSL.SSL.SysCallError as e:
            if self.verbose:
                print("OpenSSL error: %s" % e.args[1], file=sys.stderr)
            return False
        except OpenSSL.SSL.ZeroReturnError:
            # see #18
            if self.verbose:
                print("Error: The SSL connection has been closed (try with --nosni to avoid sending SNI ?)", file=sys.stderr)
            return False
        except OpenSSL.SSL.Error as e:
            if self.verbose:
                print("OpenSSL error: %s" % ', '.join(err[0][2] for err in e.args), file=sys.stderr)
            return False
        # RFC 7858, section 4.2 and appendix A
        self.cert = self.session.get_peer_certificate()
        self.publickey = self.cert.get_pubkey()
        if self.debug or self.key is not None:
            self.hasher.update(OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1,
                                                             self.publickey))
            self.digest = self.hasher.digest()
            key_string = base64.standard_b64encode(self.digest).decode()
        if self.debug:
            print("Certificate #%x for \"%s\", delivered by \"%s\"" %
                  (self.cert.get_serial_number(),
                   self.cert.get_subject().commonName,
                   self.cert.get_issuer().commonName))
            print("Public key is pin-sha256=\"%s\"" % key_string)
        if not self.insecure:
            # The messages below are written to stderr directly: no error()
            # helper is defined in this module.
            if self.key is None:
                if not homer.validate_hostname(self.check_name_cert, self.cert):
                    print("Certificate error: \"%s\" is not in the certificate" % (self.check_name_cert),
                          file=sys.stderr)
                    return False
            elif key_string != self.key:
                print("Key error: expected \"%s\", got \"%s\"" % (self.key, key_string),
                      file=sys.stderr)
                return False
        established = True
    finally:
        # stop the timer on every path, successful or not
        signal.alarm(0)
        # do not leak the socket of a failed attempt
        if not established and sock is not None:
            sock.close()
    # and start a new timer when pipelining requests
    if self.pipelining:
        self.sock.settimeout(homer.TIMEOUT_READ)
    return True
def end(self):
    """Shut down the TLS session and close it.

    close() is called even when shutdown() raises (for instance because
    the peer already dropped the connection), so the session is never left
    open. The exception from shutdown(), if any, still propagates.
    """
    try:
        self.session.shutdown()
    finally:
        self.session.close()
def send_data(self, data, dump=False):
    """Send *data* on the TLS session, preceded by its two-byte length.

    The length is encoded as a big-endian 16-bit integer, so messages
    longer than 65535 bytes make int.to_bytes() raise OverflowError.
    When *dump* is true the payload is first passed to homer.dump_data().
    """
    if dump:
        homer.dump_data(data, 'data sent')
    prefix = len(data).to_bytes(2, byteorder='big')
    # sendall() keeps writing until the whole buffer is out; send() may
    # return after transmitting only part of it.
    self.session.sendall(prefix + data)
def receive_data(self, dump=False):
    """Read one length-prefixed message from the TLS session.

    Returns (True, data, size), where size is the length announced by the
    two-byte big-endian prefix and data is the message body. Returns
    (False, None, None) when the session raises WantReadError while the
    prefix is being read. When *dump* is true the body is passed to
    homer.dump_data().
    """
    def read_exactly(count):
        # One recv() may hand back fewer bytes than requested, so keep
        # reading until the announced amount has arrived. An empty read
        # stops the loop to avoid spinning on a closed stream.
        pieces = []
        missing = count
        while missing > 0:
            piece = self.session.recv(missing)
            if not piece:
                break
            pieces.append(piece)
            missing -= len(piece)
        return b"".join(pieces)

    try:
        prefix = read_exactly(2)
    except OpenSSL.SSL.WantReadError:
        return (False, None, None)
    size = int.from_bytes(prefix, byteorder='big')
    data = read_exactly(size)
    if dump:
        homer.dump_data(data, 'data recv')
    return (True, data, size)
def send_and_receive(self, request, dump=False):
    """Send request.data, read one answer and store it on *request*.

    *dump* is forwarded to both send_data() and receive_data().
    """
    self.send_data(request.data, dump=dump)
    status, payload, length = self.receive_data(dump=dump)
    request.store_response(status, payload, length)
# this function might need to be moved outside
def do_test(self, request, synchronous=True):
    """Send *request*; when synchronous, also read and check the answer.

    In asynchronous mode only the query is sent; reading the answer is
    left to the caller.
    """
    self.send_data(request.data)
    if not synchronous:
        return
    status, payload, length = self.receive_data()
    request.store_response(status, payload, length)
    request.check_response(self.debug)
# should the pipelining methods be part of ConnectionDOT ?
def pipelining_add_request(self, request):
    """Append *request* to all_requests, with no response recorded yet."""
    entry = {'request': request, 'response': None}  # No answer yet
    self.all_requests.append(entry)
def pipelining_fill_pending(self, index):
    """Send the request of rank *index* and register it as pending.

    Nothing happens when *index* is past the end of all_requests. The
    pending entry, keyed by the DNS query ID, is the tuple
    (False, index, request).
    """
    if index >= len(self.all_requests):
        return
    request = self.all_requests[index]['request']
    query_id = request.message.id
    # TODO check there is no duplicate in IDs
    self.pending[query_id] = (False, index, request)
    self.do_test(request, synchronous=False)
def pipelining_init_pending(self, max_in_flight):
    """Send the first requests, at most *max_in_flight* of them.

    Returns the last value taken by the loop counter: the rank of the last
    request sent when max_in_flight requests were sent, or
    len(all_requests) when the list was exhausted first.

    NOTE(review): with max_in_flight <= 0 the loop never runs and the
    return statement raises UnboundLocalError.
    """
    total = len(self.all_requests)
    for rank in range(max_in_flight):
        if rank == total:
            break
        self.pipelining_fill_pending(rank)
    return rank
# this method might need to be moved somewhere else in order to avoid
# calling dns.message.from_wire()
def read_result(self, connection, requests, display_results=True):
    """Read one pipelined answer and match it with its pending request.

    *requests* maps query IDs to (done, rank, request) tuples; the entry
    of the answered query is replaced by (True, rank, request) and the
    answer is stored in all_requests[rank]['response'].

    Returns the query ID of the answer, or None when nothing could be
    read. Raises PipeliningException when the ID is not in *requests*.
    The *connection* argument is not used.
    """
    # TODO can raise OpenSSL.SSL.ZeroReturnError if the connection was
    # closed
    status, payload, length = self.receive_data()
    if not status:
        if display_results:
            print("TIMEOUT")
        return None
    # TODO remove call to dns.message (use abstraction instead)
    response = dns.message.from_wire(payload)
    query_id = response.id
    if query_id not in requests:
        raise homer.PipeliningException("Received response for ID %s which is unexpected" % query_id)
    _, rank, request = requests[query_id]
    self.all_requests[rank]['response'] = (status, response, length)
    requests[query_id] = (True, rank, request)
    if display_results:
        print()
        print(response)
    # TODO a timeout if some responses are lost?
    return query_id
def create_handle(connection):
    """Return a pycurl easy handle configured from *connection*.

    The handle requests HTTP/2 and carries the DoH Accept / Content-type
    headers. Two helpers are attached to it as attributes:
    reset_opt_default(handle) and prepare(handle, connection, request).
    """

    def reset_opt_default(handle):
        # Put the per-request options back to neutral values, in this order.
        defaults = (
            (pycurl.NOBODY, False),
            (pycurl.POST, False),
            (pycurl.POSTFIELDS, ''),
            (pycurl.URL, ''),
        )
        for option, value in defaults:
            handle.setopt(option, value)

    def prepare(handle, connection, request):
        # A handle that is not used for multistreams is reused from one
        # request to the next, so clear what the previous request set.
        if not connection.multistreams:
            handle.reset_opt_default(handle)
        if request.post:
            handle.setopt(pycurl.POST, True)
            handle.setopt(pycurl.POSTFIELDS, request.data)
            target = connection.server
        else:
            handle.setopt(pycurl.HTTPGET, True)  # automatically sets CURLOPT_NOBODY to 0
            if request.head:
                handle.setopt(pycurl.NOBODY, True)
            # The query goes in the URL, base64url-encoded without padding
            encoded = base64.urlsafe_b64encode(request.data).decode('UTF8').rstrip('=')
            target = connection.server + ("?dns=%s" % encoded)
        handle.setopt(pycurl.URL, target)
        handle.buffer = io.BytesIO()
        handle.setopt(pycurl.WRITEDATA, handle.buffer)
        handle.request = request

    curl = pycurl.Curl()
    # Does not work if pycurl was not compiled with nghttp2 (recent Debian
    # packages are OK) https://github.com/pycurl/pycurl/issues/477
    curl.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
    if connection.debug:
        curl.setopt(pycurl.VERBOSE, True)
    if connection.insecure:
        for check in (pycurl.SSL_VERIFYPEER, pycurl.SSL_VERIFYHOST):
            curl.setopt(check, False)
    if connection.forceIPv4:
        curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
    if connection.forceIPv6:
        curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V6)
    if connection.connect_to is not None:
        curl.setopt(pycurl.CONNECT_TO, ["::[%s]:443" % connection.connect_to])
    headers = ["Accept: application/dns-message", "Content-type: application/dns-message"]
    curl.setopt(pycurl.HTTPHEADER, headers)
    curl.reset_opt_default = reset_opt_default
    curl.prepare = prepare
    return curl
class ConnectionDOH(Connection):
def __init__(self, server, servername=None, connect_to=None,
             forceIPv4=False, forceIPv6=False,
             insecure=False, verbose=False, debug=False,
             multistreams=False):
    """Record the DoH options and create the curl object(s).

    The generic options are handled by the parent constructor (called with
    dot=False). In multistreams mode a CurlMulti object and empty
    bookkeeping structures are created; otherwise a single reusable easy
    handle is created.
    """
    super().__init__(server, servername=servername, connect_to=connect_to,
                     forceIPv4=forceIPv4, forceIPv6=forceIPv6, insecure=insecure,
                     verbose=verbose, debug=debug, dot=False)
    self.url = server
    self.multistreams = multistreams
    # temporary tweak to check that the ip family is coherent with
    # user choice on forced IP
    if self.connect_to:
        homer.check_ip_address(self.connect_to, forceIPv4=self.forceIPv4, forceIPv6=self.forceIPv6)
    if not self.multistreams:
        self.curl_handle = create_handle(self)
        return
    self.multi = self.create_multi()
    self.all_handles = []
    self.finished = {'http': {}}
def create_multi(self):
    """Return a new CurlMulti limited to one connection per host."""
    curl_multi = pycurl.CurlMulti()
    curl_multi.setopt(pycurl.M_MAX_HOST_CONNECTIONS, 1)
    return curl_multi
def init_multi(self):
    """Establish the multistreams connection with one warm-up query.

    A first query is performed alone, to establish the connection and
    hence avoid starting the transfer of all the other queries
    simultaneously. The root NS is queried because this should not impact
    the resolver cache. The bookkeeping is reset afterwards so the warm-up
    query does not appear in the results.
    """
    if self.verbose:
        print("Establishing multistreams connection...")
    warmup = homer.create_request('.', qtype='NS', dot=False)
    self.do_test(warmup, synchronous=False)
    self.perform_multi(silent=True, display_results=False, show_time=False)
    self.all_handles, self.finished = [], {'http': {}}
def end(self):
    """Release the curl resources held by this connection.

    In multistreams mode the handles still reported by the multi object
    are removed and the multi object is closed; otherwise the single easy
    handle is closed.
    """
    if self.multistreams:
        self.remove_handles()
        self.multi.close()
    else:
        self.curl_handle.close()
def remove_handles(self):
    """Detach and close every handle reported by the multi object.

    CurlMulti.info_read() returns the successful transfers as a list of
    handles, but the failed ones as a list of
    (handle, error number, error message) tuples, so the handle has to be
    extracted from each failure entry before it can be closed.
    """
    _, succeeded, failed = self.multi.info_read()
    handles = list(succeeded) + [failure[0] for failure in failed]
    for handle in handles:
        # detach the handle from the multi object while it is still valid,
        # then release it
        self.multi.remove_handle(handle)
        handle.close()
def perform_multi(self, silent=False, display_results=True, show_time=False):
    """Drive the multi object until every transfer has finished.

    Each successfully completed handle is passed to read_result_handle()
    together with the three display options.

    NOTE(review): the handles that info_read() reports as failed are
    ignored here; confirm this is intended.
    """
    def collect_finished():
        # Hand every successfully completed transfer to the reader
        _, succeeded, _ = self.multi.info_read()
        for handle in succeeded:
            self.read_result_handle(handle, silent=silent, display_results=display_results, show_time=show_time)

    # start the transfers
    while True:
        status, active = self.multi.perform()
        if status != pycurl.E_CALL_MULTI_PERFORM:
            break
    # then wait for activity until no transfer is left running
    while active:
        if self.multi.select(1.0) == -1:
            continue
        while True:
            status, active = self.multi.perform()
            collect_finished()
            if status != pycurl.E_CALL_MULTI_PERFORM:
                break
    # pick up whatever finished without going through the loop above
    collect_finished()
def send(self, handle):
    """Perform the transfer of *handle*, collecting the body in a new buffer.

    Raises DOHException, carrying the curl error message, when pycurl
    reports an error.
    """
    sink = io.BytesIO()
    handle.buffer = sink
    handle.setopt(pycurl.WRITEDATA, sink)
    try:
        handle.perform()
    except pycurl.error as exc:
        raise homer.DOHException(exc.args[1])
def receive(self, handle):
    """Copy the outcome of a finished transfer onto handle.request.

    The request receives the body, its size, the HTTP status code (as
    rcode) and the content type (as ctype). The total and pre-transfer
    times are stored on the handle itself, and the handle's buffer is
    closed.
    """
    request = handle.request
    payload = handle.buffer.getvalue()
    payload_size = len(payload)
    status = handle.getinfo(pycurl.RESPONSE_CODE)
    handle.time = handle.getinfo(pycurl.TOTAL_TIME)
    handle.pretime = handle.getinfo(pycurl.PRETRANSFER_TIME)
    try:
        ctype = handle.getinfo(pycurl.CONTENT_TYPE)
    except TypeError:
        # This is the exception we get if there is no Content-Type: (for
        # instance in response to HEAD requests)
        ctype = None
    request.response, request.response_size = payload, payload_size
    request.rcode, request.ctype = status, ctype
    handle.buffer.close()
def send_and_receive(self, handle, dump=False):
    """Perform the transfer of *handle*, then collect its result.

    The *dump* argument is accepted but not used.
    """
    self.send(handle)
    self.receive(handle)
    return None
def read_result_handle(self, handle, silent=False, display_results=True, show_time=False):
    """Process one finished multistreams transfer.

    The result is collected and checked, the per-status-code counter in
    self.finished['http'] is incremented, and the handle is closed and
    removed from the multi object. Unless *silent*, the timings and/or the
    response are printed according to *show_time* and *display_results*.
    """
    self.receive(handle)
    handle.request.check_response()
    verbose = not silent
    if verbose and show_time:
        self.print_time(handle)
    counters = self.finished['http']
    code = handle.request.rcode
    counters[code] = counters.get(code, 0) + 1
    if verbose and display_results:
        elapsed_ms = (handle.time - handle.pretime) * 1000
        print("Return code %s (%.2f ms):" % (handle.request.rcode, elapsed_ms))
        print(f"{handle.request.response}\n")
    handle.close()
    self.multi.remove_handle(handle)
def read_results(self, display_results=True, show_time=False):
    """Run read_result_handle() on every handle in all_handles."""
    options = {'display_results': display_results, 'show_time': show_time}
    for handle in self.all_handles:
        self.read_result_handle(handle, **options)
def print_time(self, handle):
    """Print one line of timings for a finished transfer.

    The line holds the request number (handle.request.i), the return code,
    the pre-transfer time, the total time and their difference, the three
    durations being converted to milliseconds.
    """
    request = handle.request
    pretransfer_ms = handle.pretime * 1000
    total_ms = handle.time * 1000
    transfer_ms = (handle.time - handle.pretime) * 1000
    print(f'{request.i:3d} ({request.rcode}) '
          f'{pretransfer_ms:8.3f} ms {total_ms:8.3f} ms {transfer_ms:8.3f} ms')
def do_test(self, request, synchronous=True):
    """Run *request* now, or queue it on the multi object.

    Synchronous mode reuses the connection's single handle, performs the
    transfer and checks the response. Asynchronous mode creates a new
    handle, records it in all_handles and adds it to the multi object;
    the transfer itself is left to perform_multi().
    """
    if synchronous:
        handle = self.curl_handle
        handle.prepare(handle, self, request)
        self.send_and_receive(handle)
        request.check_response(self.debug)
        return
    handle = create_handle(self)
    self.all_handles.append(handle)
    handle.prepare(handle, self, request)
    self.multi.add_handle(handle)
def timeout_connection(signum, frame):
    """Signal handler that turns an alarm into a TimeoutConnectionError.

    Both arguments are the ones every signal handler receives; neither is
    used.
    """
    message = 'Connection timeout'
    raise TimeoutConnectionError(message)
class TimeoutConnectionError(Exception):
    """Raised by the alarm handler when a connection attempt takes too long."""


class ConnectionException(Exception):
    """Base class of the errors raised while setting up a connection."""


class ConnectionDOTException(ConnectionException):
    """Connection error specific to DoT."""


class ConnectionDOHException(ConnectionException):
    """Connection error specific to DoH."""


class FamilyException(ConnectionException):
    """Connection error related to the address family."""


class RequestException(Exception):
    """Base class of the errors related to a request."""


class RequestDOTException(RequestException):
    """Request error specific to DoT."""


class PipeliningException(Exception):
    """Error raised while pipelining requests."""


class DOHException(Exception):
    """Error raised when a DoH transfer fails."""
# request.py
try:
# http://www.dnspython.org/
import dns.message
except ImportError as e:
print("Error: missing module")
print(e)
sys.exit(1)
import homer
class Request:
def __init__(self, qname, qtype='AAAA', use_edns=True, want_dnssec=False, no_ecs=True):
if no_ecs:
opt = dns.edns.ECSOption(address='', srclen=0) # Disable ECS (RFC 7871, section 7.1.2)
options = [opt]
else:
options = None
self.message = dns.message.make_query(qname, dns.rdatatype.from_text(qtype),
use_edns