Commit a4f2463a authored by Alexandre's avatar Alexandre
Browse files

[DoH] Store each Request inside the curl handle

parent 359069f9
......@@ -456,8 +456,9 @@ def create_handle(connection):
handle.setopt(pycurl.NOBODY, True)
dns_req = base64.urlsafe_b64encode(request.data).decode('UTF8').rstrip('=')
handle.setopt(pycurl.URL, connection.server + ("?dns=%s" % dns_req))
request.buffer = io.BytesIO()
handle.setopt(pycurl.WRITEDATA, request.buffer)
handle.buffer = io.BytesIO()
handle.setopt(pycurl.WRITEDATA, handle.buffer)
handle.request = request
handle = pycurl.Curl()
# Does not work if pycurl was not compiled with nghttp2 (recent Debian
......@@ -515,14 +516,6 @@ class ConnectionDoH(Connection):
h.close()
self.multi.remove_handle(h)
def perform(self, request):
request.buffer = io.BytesIO()
self.curl_handle.setopt(pycurl.WRITEDATA, request.buffer)
try:
self.curl_handle.perform()
except pycurl.error as e:
error(e.args[1])
def perform_multi(self):
while 1:
ret, num_handles = self.multi.perform()
......@@ -537,12 +530,17 @@ class ConnectionDoH(Connection):
if ret != pycurl.E_CALL_MULTI_PERFORM:
break
def send(self, request):
self.perform(request)
def send(self, handle):
handle.buffer = io.BytesIO()
handle.setopt(pycurl.WRITEDATA, handle.buffer)
try:
handle.perform()
except pycurl.error as e:
error(e.args[1])
def receive(self, request):
handle = request.handle
body = request.buffer.getvalue()
def receive(self, handle):
request = handle.request
body = handle.buffer.getvalue()
body_size = len(body)
http_code = handle.getinfo(pycurl.RESPONSE_CODE)
try:
......@@ -553,23 +551,24 @@ class ConnectionDoH(Connection):
request.response_size = body_size
request.rcode = http_code
request.ctype = content_type
request.buffer.close()
handle.buffer.close()
def send_and_receive(self, request):
self.send(request)
self.receive(request)
def send_and_receive(self, handle):
self.send(handle)
self.receive(handle)
def do_test(self, request, synchronous=True):
if synchronous:
request.handle = self.curl_handle
handle = self.curl_handle
else:
request.handle = create_handle(self)
request.handle.prepare(request.handle, self, request)
handle = create_handle(self)
handle.prepare(handle, self, request)
if synchronous:
self.send_and_receive(request)
self.send_and_receive(handle)
request.check_response()
else:
self.multi.add_handle(request.handle)
self.multi.add_handle(handle)
return handle
def get_next_domain(input_file):
......@@ -628,12 +627,12 @@ def print_result(connection, request, prefix=None, display_err=True):
return ok
# pending_requests must be an array
def read_results(connection, pending_requests):
for i in range(0, len(pending_requests)):
request = pending_requests[i]
connection.receive(request)
request.check_response()
return pending_requests
def read_results(connection, pending_handles):
for i in range(0, len(pending_handles)):
handle = pending_handles[i]
connection.receive(handle)
handle.request.check_response()
return pending_handles
def create_request(qname, qtype=rtype, use_edns=edns, want_dnssec=dnssec, dot=dot, trunc=False):
if dot:
......@@ -682,10 +681,13 @@ def run_check_default(connection):
request.post = True
elif method == DOH_HEAD:
request.head = True
request.handle = connection.curl_handle
request.handle.prepare(request.handle, connection, request)
handle = connection.curl_handle
handle.prepare(handle, connection, request)
bundle = handle
else:
bundle = request
try:
connection.send_and_receive(request)
connection.send_and_receive(bundle)
except CustomException as e:
ok = False
error(e)
......@@ -705,11 +707,11 @@ def run_check_mime(connection, accept="application/dns-message", content_type="a
header = [f"Accept: {accept}", f"Content-type: {content_type}"]
req_args = { 'qname': name, 'qtype': rtype, 'use_edns': edns, 'want_dnssec': dnssec }
request = create_request(**req_args)
connection.curl_handle.setopt(pycurl.HTTPHEADER, header)
request.handle = connection.curl_handle
request.handle.prepare(request.handle, connection, request)
handle = connection.curl_handle
handle.setopt(pycurl.HTTPHEADER, header)
handle.prepare(handle, connection, request)
try:
connection.send_and_receive(request)
connection.send_and_receive(handle)
except CustomException as e:
ok = False
error(e)
......@@ -718,7 +720,7 @@ def run_check_mime(connection, accept="application/dns-message", content_type="a
ok = False
default = "application/dns-message"
default_header = [f"Accept: {default}", f"Content-type: {default}"]
connection.curl_handle.setopt(pycurl.HTTPHEADER, default_header)
handle.setopt(pycurl.HTTPHEADER, default_header)
return ok
def run_check_trunc(connection):
......@@ -729,14 +731,16 @@ def run_check_trunc(connection):
req_args = { 'qname': name, 'qtype': rtype, 'use_edns': edns, 'want_dnssec': dnssec }
if dot:
request = create_request(dot=dot, trunc=True, **req_args)
bundle = request
else:
request = create_request(trunc=True, **req_args)
request.post = True
request.handle = connection.curl_handle
request.handle.prepare(request.handle, connection, request)
handle = connection.curl_handle
handle.prepare(handle, connection, request)
bundle = handle
try:
# 8.8.8.8 replies FORMERR but most DoT servers violently shut down the connection (which is legal)
connection.send_and_receive(request)
connection.send_and_receive(bundle)
except CustomException as e:
ok = False
error(e)
......@@ -1029,7 +1033,10 @@ for connectTo in ip_set:
request.head = head
request.post = post
try:
conn.do_test(request, synchronous = not multistreams)
if dot:
conn.do_test(request, synchronous = not multistreams)
else:
handle = conn.do_test(request, synchronous = not multistreams)
except (OpenSSL.SSL.Error, CustomException) as e:
ok = False
error(e)
......@@ -1038,7 +1045,7 @@ for connectTo in ip_set:
if not print_result(conn, request):
ok = False
else: # We do multistreams
pending[i] = request # No result yet
pending[i] = handle # No result yet
if tests > 1 and i == 0:
start2 = time.time()
if delay is not None:
......@@ -1046,9 +1053,9 @@ for connectTo in ip_set:
if multistreams:
conn.perform_multi()
print("")
result = read_results(conn, pending)
for j in result:
print("Return code %s: %s\n" % (result[j].rcode, result[j].response))
handles = read_results(conn, pending)
for j in handles:
print("Return code %s: %s\n" % (handles[j].request.rcode, handles[j].request.response))
else:
ok = run_check(conn) and ok # need to run run_check first
stop = time.time()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment