Commit ca994f00 authored by Alexandre's avatar Alexandre
Browse files

Refactor pipelining code

parent 5547974c
......@@ -360,7 +360,7 @@ class Connection:
class ConnectionDoT(Connection):
def __init__(self, server, servername=None, connect=None, forceIPv4=False, forceIPv6=False,
verbose=verbose, debug=debug, insecure=insecure):
pipelining=pipelining, verbose=verbose, debug=debug, insecure=insecure):
Connection.__init__(self, server, servername=servername, connect=connect,
forceIPv4=forceIPv4, forceIPv6=forceIPv6, dot=True,
verbose=verbose, debug=debug, insecure=insecure)
......@@ -387,7 +387,16 @@ class ConnectionDoT(Connection):
error(f'Could not connect to "{server}"')
else:
print(f'Could not connect to "{server}" on {connect}')
self.pipelining = pipelining
if pipelining:
self.all_requests = [] # Currently, we load everything in memory
# since we want to keep everything,
# anyway. Maybe in the future, if we don't
# want to keep individual results, we'll use
# an iterator to fill a smaller table.
# all_requests is indexed by its rank in the input file.
self.pending = {} # pending is indexed by the query ID, and its
# maximum size is max_in_flight.
def connect(self, addr, sock_family):
signal.alarm(TIMEOUT_CONN)
......@@ -488,6 +497,24 @@ class ConnectionDoT(Connection):
request.store_response(rcode, response, size)
request.check_response(self.debug)
def pipelining_add_request(self, request):
    """Record a request for later pipelined sending.

    Each entry keeps the request together with a response slot;
    the slot starts as None and is filled in once the matching
    answer is read back from the connection.
    """
    entry = {'request': request, 'response': None}  # answer not received yet
    self.all_requests.append(entry)
def pipelining_fill_pending(self, index):
    """Submit the stored request at position *index* on the wire.

    Registers the request in self.pending under its DNS message ID
    (as a (done, rank, request) triple, done initially False) and
    sends it asynchronously. Out-of-range indexes are a no-op.
    """
    if index >= len(self.all_requests):
        return
    entry = self.all_requests[index]['request']
    qid = entry.message.id
    # TODO check there is no duplicate in IDs
    self.pending[qid] = (False, index, entry)
    self.do_test(entry, synchronous=False)
def pipelining_init_pending(self, max_in_flight):
    """Submit the initial batch of pipelined queries.

    Sends up to max_in_flight requests from self.all_requests and
    returns the number actually submitted, which is also the index
    of the next request the caller should submit.

    Bug fix vs. the refactored version: the old code returned the
    loop variable i, i.e. count - 1 when max_in_flight requests were
    all submitted (the caller would then re-submit the last one and
    skip nothing past it), and raised NameError when max_in_flight
    was 0 because i was never bound. The pre-refactor inline loop
    kept a counter incremented once per submitted request; this
    restores that contract.
    """
    count = min(max_in_flight, len(self.all_requests))
    for i in range(count):
        self.pipelining_fill_pending(i)
    return count
def read_result(self, connection, requests):
rcode, response, size = self.receive_data() # TODO can raise
# OpenSSL.SSL.ZeroReturnError
......@@ -502,7 +529,7 @@ class ConnectionDoT(Connection):
if id not in requests:
raise Exception("Received response for ID %s which is unexpected" % id)
over, rank, request = requests[id]
all_requests[rank]['response'] = (rcode, response, size)
self.all_requests[rank]['response'] = (rcode, response, size)
requests[id] = (True, rank, request)
if display_results:
print()
......@@ -1163,7 +1190,7 @@ for connectTo in ip_set:
if dot:
conn = ConnectionDoT(url, servername=extracheck, connect=connectTo, verbose=verbose,
debug=debug, forceIPv4=forceIPv4, forceIPv6=forceIPv6,
insecure=insecure)
pipelining=pipelining, insecure=insecure)
else:
conn = ConnectionDoH(url, servername=extracheck, connect=connectTo, verbose=verbose,
debug=debug, forceIPv4=forceIPv4, forceIPv6=forceIPv6,
......@@ -1183,15 +1210,6 @@ for connectTo in ip_set:
continue
if ifile is not None:
input = open(ifile)
if pipelining:
all_requests = [] # Currently, we load everything in memory
# since we want to keep everything,
# anyway. Maybe in the future, if we don't
# want to keep individual results, we'll use
# an iterator to fill a smaller table.
# all_requests is indexed by its rank in the input file.
pending = {} # pending is indexed by the query ID, and its
# maximum size is max_in_flight.
if not check:
for i in range (0, tests):
if tests > 1 and (verbose or display_results):
......@@ -1223,7 +1241,7 @@ for connectTo in ip_set:
if delay is not None:
time.sleep(delay)
else: # We do pipelining
all_requests.append({'request': request, 'response': None}) # No answer yet
conn.pipelining_add_request(request)
if multistreams:
conn.perform_multi()
if sync:
......@@ -1231,35 +1249,23 @@ for connectTo in ip_set:
if dot and pipelining:
print("")
done = 0
current = 0
for i in range(0,max_in_flight):
if i == len(all_requests):
break
request = all_requests[i]['request']
id = request.message.id
current += 1
# TODO check there is no duplicate in IDs
pending[id] = (False, i, request)
conn.do_test(request, synchronous = False)
current = conn.pipelining_init_pending(max_in_flight)
while done < tests:
if time.time() > start + MAX_DURATION:
print("Elapsed time too long, %i requests never got a reply" % (tests-done))
break
id = conn.read_result(conn,pending)
id = conn.read_result(conn, conn.pending)
if id is None: # Probably a timeout
time.sleep(SLEEP_TIMEOUT)
continue
done += 1
over, rank, request = pending[id]
over, rank, request = conn.pending[id]
if not over:
error("Internal error, request %i should be over" % id)
all_requests[rank] = request
if current < len(all_requests):
request = all_requests[current]
id = request['request'].message.id
pending[id] = (False, current, request)
conn.all_requests[rank] = request
if current < len(conn.all_requests):
conn.pipelining_fill_pending(current)
current += 1
conn.do_test(request['request'], synchronous = False)
else:
ok = run_check(conn) and ok # need to run run_check first
stop = time.time()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment