Commit 8cba1a31 authored by Alexandre's avatar Alexandre
Browse files

Split CustomException with better exception names

parent 8e3ad9a4
...@@ -261,9 +261,9 @@ def get_addrfamily(addr, forceIPv4=False, forceIPv6=False): ...@@ -261,9 +261,9 @@ def get_addrfamily(addr, forceIPv4=False, forceIPv6=False):
# thoses checks between the IP family and the command line option # thoses checks between the IP family and the command line option
# might need to land somewhere else # might need to land somewhere else
if forceIPv4 and family == 6: if forceIPv4 and family == 6:
raise CustomException("You cannot force IPv4 with a litteral IPv6 address (%s)" % addr) raise FamilyException("You cannot force IPv4 with a litteral IPv6 address (%s)" % addr)
elif forceIPv6 and family == 4: elif forceIPv6 and family == 4:
raise CustomException("You cannot force IPv6 with a litteral IPv4 address (%s)" % addr) raise FamilyException("You cannot force IPv6 with a litteral IPv4 address (%s)" % addr)
if forceIPv4 or family == 4: if forceIPv4 or family == 4:
family = socket.AF_INET family = socket.AF_INET
...@@ -289,7 +289,16 @@ def timeout_connection(signum, frame): ...@@ -289,7 +289,16 @@ def timeout_connection(signum, frame):
class TimeoutConnectionError(Exception): class TimeoutConnectionError(Exception):
pass pass
class CustomException(Exception): class ConnectionException(Exception):
pass
class ConnectionDOTException(ConnectionException):
pass
class ConnectionDOHException(ConnectionException):
pass
class FamilyException(ConnectionException):
pass pass
class Request: class Request:
...@@ -389,18 +398,19 @@ class RequestDOH(Request): ...@@ -389,18 +398,19 @@ class RequestDOH(Request):
class Connection: class Connection:
def __init__(self, server, servername=None, connect_to=None, def __init__(self, server, servername=None, connect_to=None,
forceIPv4=False, forceIPv6=False, insecure=False, forceIPv4=False, forceIPv6=False, insecure=False,
verbose=False, debug=False, dot=False): verbose=False, debug=False, dot=False):
if dot and not is_valid_hostname(server): if dot and not is_valid_hostname(server):
raise CustomException("DoT requires a host name or IP address, not \"%s\"" % server) raise ConnectionDOTException("DoT requires a host name or IP address, not \"%s\"" % server)
if not dot and not is_valid_url(server): if not dot and not is_valid_url(server):
raise CustomException("DoH requires a valid HTTPS URL, not \"%s\"" % server) raise ConnectionDOHException("DoH requires a valid HTTPS URL, not \"%s\"" % server)
if forceIPv4 and forceIPv6: if forceIPv4 and forceIPv6:
raise CustomException("Force IPv4 *or* IPv6 but not both") raise ConnectionException("Force IPv4 *or* IPv6 but not both")
self.dot = dot self.dot = dot
self.server = server self.server = server
...@@ -491,7 +501,7 @@ class ConnectionDOT(Connection): ...@@ -491,7 +501,7 @@ class ConnectionDOT(Connection):
error = "Could not connect to \"%s\"" % self.server error = "Could not connect to \"%s\"" % self.server
if self.connect_to is not None: if self.connect_to is not None:
error += " on %s" % self.connect_to error += " on %s" % self.connect_to
raise CustomException(error) raise ConnectionDOTException(error)
def establish_session(self, addr, sock_family): def establish_session(self, addr, sock_family):
"""Return True if a TLS session is established.""" """Return True if a TLS session is established."""
...@@ -749,7 +759,7 @@ class ConnectionDOH(Connection): ...@@ -749,7 +759,7 @@ class ConnectionDOH(Connection):
request = create_request('.', qtype='NS', dot=False) request = create_request('.', qtype='NS', dot=False)
try: try:
self.do_test(request, synchronous=False) self.do_test(request, synchronous=False)
except (OpenSSL.SSL.Error, CustomException) as e: except OpenSSL.SSL.Error as e:
ok = False ok = False
error(e) error(e)
self.perform_multi(silent=True, display_results=False, show_time=False) self.perform_multi(silent=True, display_results=False, show_time=False)
...@@ -971,12 +981,7 @@ def run_check_default(connection): ...@@ -971,12 +981,7 @@ def run_check_default(connection):
handle = connection.curl_handle handle = connection.curl_handle
handle.prepare(handle, connection, request) handle.prepare(handle, connection, request)
bundle = handle bundle = handle
try: connection.send_and_receive(bundle)
connection.send_and_receive(bundle)
except CustomException as e:
ok = False
error(e)
break
request.check_response(connection.debug) request.check_response(connection.debug)
if not print_result(connection, request, prefix=test_name, display_err=False): if not print_result(connection, request, prefix=test_name, display_err=False):
if level >= opts.mandatory_level: if level >= opts.mandatory_level:
...@@ -1005,11 +1010,7 @@ def run_check_mime(connection, accept="application/dns-message", content_type="a ...@@ -1005,11 +1010,7 @@ def run_check_mime(connection, accept="application/dns-message", content_type="a
handle = connection.curl_handle handle = connection.curl_handle
handle.setopt(pycurl.HTTPHEADER, header) handle.setopt(pycurl.HTTPHEADER, header)
handle.prepare(handle, connection, request) handle.prepare(handle, connection, request)
try: connection.send_and_receive(handle)
connection.send_and_receive(handle)
except CustomException as e:
ok = False
error(e)
request.check_response(connection.debug) request.check_response(connection.debug)
if not print_result(connection, request, prefix=f"Test Header {', '.join(header)}"): if not print_result(connection, request, prefix=f"Test Header {', '.join(header)}"):
ok = False ok = False
...@@ -1042,9 +1043,6 @@ def run_check_trunc(connection): ...@@ -1042,9 +1043,6 @@ def run_check_trunc(connection):
try: try:
# 8.8.8.8 replies FORMERR but most DoT servers violently shut down the connection (which is legal) # 8.8.8.8 replies FORMERR but most DoT servers violently shut down the connection (which is legal)
connection.send_and_receive(bundle, dump=connection.debug) connection.send_and_receive(bundle, dump=connection.debug)
except CustomException as e:
ok = False
error(e)
except OpenSSL.SSL.ZeroReturnError: # This is acceptable except OpenSSL.SSL.ZeroReturnError: # This is acceptable
return ok return ok
except dns.exception.FormError: # This is also acceptable except dns.exception.FormError: # This is also acceptable
...@@ -1328,7 +1326,7 @@ def run_default(name, connection, opts): ...@@ -1328,7 +1326,7 @@ def run_default(name, connection, opts):
if not opts.pipelining: if not opts.pipelining:
try: try:
connection.do_test(request, synchronous = not opts.multistreams) connection.do_test(request, synchronous = not opts.multistreams)
except (OpenSSL.SSL.Error, CustomException) as e: except OpenSSL.SSL.Error as e:
ok = False ok = False
error(e) error(e)
break break
...@@ -1440,7 +1438,7 @@ for ip in ip_set: ...@@ -1440,7 +1438,7 @@ for ip in ip_set:
error("\"%s\" not a name or an IP address" % url) error("\"%s\" not a name or an IP address" % url)
except socket.gaierror: except socket.gaierror:
error("Could not resolve \"%s\"" % url) error("Could not resolve \"%s\"" % url)
except CustomException as e: except ConnectionException as e:
error(e) error(e)
if conn.dot and not conn.success: if conn.dot and not conn.success:
ok = False ok = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment