homer.py 50.8 KB
Newer Older
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
1
2
#!/usr/bin/env python3

3
4
5
6
7
# Homer is a DoH (DNS-over-HTTPS) and DoT (DNS-over-TLS) client. Its
# main purpose is to test DoH and DoT resolvers. Reference site is
# <https://framagit.org/bortzmeyer/homer/> See author, documentation,
# etc, there, or in the README.md included with the distribution.

Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
8
9
10
11
12
13
# http://pycurl.io/docs/latest
import pycurl

# http://www.dnspython.org/
import dns.message

14
15
16
# https://github.com/drkjam/netaddr/
import netaddr

17
18
19
20
21
# Octobre 2019: the Python GnuTLS bindings don't work with Python 3. So we use OpenSSL.
# https://www.pyopenssl.org/
# https://pyopenssl.readthedocs.io/
import OpenSSL

Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
22
23
24
25
26
27
import io
import sys
import base64
import getopt
import urllib.parse
import time
28
29
import socket
import ctypes
30
import re
31
import os.path
32
33
import hashlib
import base64
34
import signal
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
35

36
# Values that can be changed from the command line
dot = False # DoH by default
verbose = False
debug = False
insecure = False
post = False
head = False
dnssec = False
edns = True
no_ecs = True
sni = True
rtype = 'AAAA'
vhostname = None
tests = 1 # Number of repeated tests
key = None # SPKI
ifile = None # Input file
delay = None
forceIPv4 = False
forceIPv6 = False
connectTo = None
pipelining = False
max_in_flight = 20
multistreams = False
sync = False
display_results = True
show_time = False
check = False
mandatory_level = None
check_additional = True

# Monitoring plugin only:
host = None
path = None
expect = None

# Do not change these
# Syntactic check of a canonicalized DoT server name. The WHOLE string must be
# either a host name (letters, digits, hyphens and dots, starting with a
# letter or a digit) or an IP address literal (digits, hex letters, colons
# and dots, which covers IPv4, IPv6 and IPv4-mapped IPv6). The alternation
# is grouped so that both anchors apply to every alternative.
re_host = re.compile(r'^(?:[0-9a-z][0-9a-z\-\.]*|[0-9a-f:\.]+)\Z')

# For the monitoring plugin (exit statuses)
STATE_OK = 0
STATE_WARNING = 1
STATE_CRITICAL = 2
STATE_UNKNOWN = 3
STATE_DEPENDENT = 4

# For the check option
DOH_GET = 0
DOH_POST = 1
DOH_HEAD = 2

# Is the test mandatory?
mandatory_levels = {"legal": 30, "necessary": 20, "nicetohave": 10}

TIMEOUT_CONN = 2
TIMEOUT_READ = 1
SLEEP_TIMEOUT = 0.5
MAX_DURATION = 10
91

92
def error(msg=None, exit=True):
    """Report an error and, unless told otherwise, terminate the program.

    msg: text to display; "Unknown error" when omitted.
    exit: when False, only report and return to the caller.

    In monitoring-plugin mode (global `monitoring`), the message goes to
    standard output prefixed with the global `url`, and the exit status is
    STATE_CRITICAL. Otherwise it goes to standard error, followed by 'KO'
    on standard output when the global `check` is set, and the exit status
    is 1.
    """
    text = "Unknown error" if msg is None else msg
    if not monitoring:
        print(text, file=sys.stderr)
        if check:
            print('KO')
        if exit:
            sys.exit(1)
        return
    print("%s: %s" % (url, text))
    if exit:
        sys.exit(STATE_CRITICAL)
Alexandre's avatar
Alexandre committed
105

Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
106
107
108
def usage(msg=None):
    """Print the command-line synopsis and the list of options on stderr.

    msg: optional explanation (for instance what was wrong with the command
    line), printed first when it is not empty.
    """
    err = sys.stderr
    if msg:
        print(msg, file=err)
    synopsis = "Usage: %s [options] url-or-servername [domain-name [DNS type]]" % sys.argv[0]
    print(synopsis, file=err)
    print("""Options
    -t --dot            Use DoT (by default use DoH)
    -P --post --POST    Use HTTP POST method for all transfers (DoH only)
    -e --head --HEAD    Use HTTP HEAD method for all transfers (DoH only)
    -r --repeat <N>     Perform N times the query. If used with -f, read up to
                        <N> rows of the <file>.
    -d --delay <T>      Time to wait in seconds between each synchronous
                        request (only with --repeat)
    -f --file <file>    Read domain names from <file>, one per row with an
                        optional DNS type
    --check             Perform a set of predefined tests.
    --mandatory-level <level>
                        Define the <level> of test to perform (only with
                        --check)
                        Available <level> : legal, necessary, nicetohave
    --multistreams      Use HTTP/2 streams, needs an input file with -f
                        (DoH only)
    --sync              Process received queries synchronously (only with
                        --multistreams)
    --no-display-results
                        Disable output of DNS response (only with
                        --multistreams)
    --time              Display the time elapsed for the query (only with
                        --multistreams)
    --dnssec            Request DNSSEC data (signatures)
    --noedns            Disable EDNS, default is to indicate EDNS support
    --ecs               Send ECS to authoritative servers, default is to
                        refuse it
    --key <key>         Authenticate a DoT resolver with its public <key> in
                        base64 (DoT only)
    --nosni             Do not perform SNI (DoT only)
    -V --vhost <vhost>  Use a specific virtual host
    -k --insecure       Do not check the certificate
    -4 --v4only         Force IPv4 resolution of url-or-servername
    -6 --v6only         Force IPv6 resolution of url-or-servername
    -v --verbose        Make the program more talkative
    --debug             Make the program even more talkative than -v
    -h --help           Print this message

    url-or-servername   The URL or domain name of the DoT/DoH server
    domain-name         The domain name to resolve, not required if -f is
                        provided
    DNS type            The DNS record type to resolve, default AAAA
    """, file=err)
    print("See the README.md for more details.", file=err)
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
155

156
def is_valid_hostname(name):
    """Check that *name* looks like a host name or an IP address literal.

    The name is canonicalized first (see canonicalize()) and then matched
    against the module-level re_host pattern.

    Returns the match object (truthy) when the name is acceptable. Returns
    None when it is not, including when canonicalization fails with
    UnicodeError (for instance two consecutive dots): such a name is simply
    invalid and must not crash the caller.
    """
    try:
        name = canonicalize(name)
    except UnicodeError:
        return None
    return re_host.search(name)
159

160
161
162
163
164
165
166
167
168
def canonicalize(hostname):
    """Return the canonical form of a host name, for comparisons.

    The name is lower-cased, converted by the 'idna' codec (non-ASCII
    labels become "xn--" A-labels), and a trailing dot, if any, is removed.
    The empty string is returned unchanged.

    Raises UnicodeError when the 'idna' codec rejects the name (two
    consecutive dots for instance).
    """
    result = hostname.lower()
    # TODO handle properly the case where it fails with UnicodeError
    # (two consecutive dots for instance) to get a custom exception
    result = result.encode('idna').decode()
    # endswith() is safe on the empty string, unlike indexing the last char.
    if result.endswith('.'):
        result = result[:-1]
    return result

169
170
171
172
def is_valid_ip_address(addr):
    """Tell whether *addr* parses as an IP address, and which version.

    Returns (True, version), where version is the netaddr IP version
    number (4 or 6), or (False, None) when netaddr refuses the address.
    """
    try:
        parsed = netaddr.IPAddress(addr)
    except netaddr.core.AddrFormatError:
        return (False, None)
    else:
        return (True, parsed.version)
175

176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def is_valid_url(url):
    """Tell whether *url* is acceptable as a DoH endpoint.

    A very poor validation: the URL must parse, use the "https" scheme and
    have a non-empty network location. Many errors (for instance
    whitespaces, IPv6 address litterals without brackets...) are ignored.
    """
    try:
        parts = urllib.parse.urlparse(url)
    except ValueError:
        return False
    return parts.netloc != "" and parts.scheme == "https"

def get_certificate_san(x509cert):
    """Return the Subject Alternative Name extension of a certificate, as text.

    x509cert: a pyOpenSSL X509 object.

    The result is the str() of the extension, a ", "-separated list such as
    "DNS:dns.example, IP Address:192.0.2.1". It is the empty string when no
    extension's short name contains "subjectAltName". If several do, the
    last one wins.
    """
    extensions = (x509cert.get_extension(rank)
                  for rank in range(x509cert.get_extension_count()))
    rendered = [str(ext) for ext in extensions
                if "subjectAltName" in str(ext.get_short_name())]
    return rendered[-1] if rendered else ""

194
195
196
197
198
199
200
201
202
203
204
def match_hostname(hostname, possibleMatch):
    """Try one possible name. Names must be already canonicalized.

    hostname: the name we expect the server to have.
    possibleMatch: one name presented in the certificate.

    A presented name "*.example.com" is a wildcard. As in RFC 6125, section
    6.4.3, it matches exactly one extra left-most label ("foo.example.com").
    It does not match the parent name itself ("example.com") nor deeper
    names ("bar.foo.example.com"). RFC 6125 says that we MAY accept
    left-most labels with wildcards included (foo*bar). We don't do it here.

    Returns True on a match, False otherwise.
    """
    if not possibleMatch.startswith("*."):
        return hostname == possibleMatch
    base = possibleMatch[2:]  # Skip the star and the dot
    first, separator, rest = hostname.partition(".")
    if not separator or first == "":
        # A one-label name, or an empty first label, cannot stand for the
        # wildcard label.
        return False
    return rest == base

def validate_hostname(hostname, cert):
    """Try all the names in the certificate against *hostname*.

    hostname: the host name, or IP address literal, we expect.
    cert: the peer certificate (pyOpenSSL X509 object).

    The Subject Alternative Names are tried first. "DNS:" entries are used
    when hostname is a name; "IP Address:" entries when it is an IP
    address. The subject Common Name is tried last. Complete specification
    is in RFC 6125. It is long and complicated and I'm not sure we do it
    perfectly.

    Returns True when one of the names matches, False otherwise.
    """
    (is_addr, _family) = is_valid_ip_address(hostname)
    hostname = canonicalize(hostname)
    # Parsed once, rather than for every "IP Address:" entry.
    host_i = netaddr.IPAddress(hostname) if is_addr else None
    for alt_name in get_certificate_san(cert).split(", "):
        if alt_name.startswith("DNS:") and not is_addr:
            try:
                base = canonicalize(alt_name[len("DNS:"):])
            except UnicodeError:
                continue # A name the IDNA codec rejects cannot match
            if match_hostname(hostname, base):
                return True
        elif alt_name.startswith("IP Address:") and is_addr:
            base = alt_name[len("IP Address:"):]
            if base.endswith("\n"):
                base = base[:-1]
            try:
                base_i = netaddr.IPAddress(base)
            except netaddr.core.AddrFormatError:
                continue # Ignore broken IP addresses in certificates. Are we too liberal?
            if host_i == base_i:
                return True
        # Other alternative name types are ignored. May be accept URI
        # alternative names for DoH?
    # According to RFC 6125, we MUST NOT try the Common Name before the
    # Subject Alternative Names.
    cn = cert.get_subject().commonName
    if cn is None: # The subject has no Common Name at all
        return False
    try:
        return match_hostname(hostname, canonicalize(cn))
    except UnicodeError:
        return False

246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
def check_ip_address(addr, dot=dot):
    """Work out the socket family to use for *addr*, and its printable form.

    addr: an IP address literal, or (only when dot is true) a host name.
    dot: when false, addr must be an IP address literal.

    Honours the global forceIPv4 / forceIPv6 settings. Returns
    (family, repraddress). family is socket.AF_INET, socket.AF_INET6 or 0
    (let the resolver choose). repraddress is addr, in brackets for IPv6.

    Raises CustomException when addr is not an IP address (and dot is
    false), or when it contradicts the forced IP version.
    """
    is_addr, version = is_valid_ip_address(addr)
    if not (is_addr or dot):
        raise CustomException("%s is not IPv4 and not IPv6" % addr)
    if forceIPv4 and version == 6:
        raise CustomException("You cannot force IPv4 with a litteral IPv6 address (%s)" % addr)
    if forceIPv6 and version == 4:
        raise CustomException("You cannot force IPv6 with a litteral IPv4 address (%s)" % addr)
    if forceIPv4 or version == 4:
        return (socket.AF_INET, addr)
    if forceIPv6 or version == 6:
        return (socket.AF_INET6, f'[{addr}]')
    return (0, addr)

265
266
267
268
269
270
def dump_data(data, text="data"):
    """Print *data* (bytes) as its repr, then in hexadecimal and in binary.

    text: label of the first line. The "hex:" and "bin:" lines are padded
    so that their colons line up with the colon after the label.
    """
    print(f'{text}: ', data)
    indent = ' ' * (len(text) - 4)
    as_hex = " ".join(f'{byte:02x}' for byte in data)
    print(indent, 'hex:', as_hex)
    as_bin = " ".join(f'{byte:08b}' for byte in data)
    print(indent, 'bin:', as_bin)

271
272
def timeout_connection(signum, frame):
    """SIGALRM handler: interrupt the connection attempt in progress.

    signum and frame are passed by the signal module and are not used.
    """
    timeout = TimeoutConnectionError('Connection timeout')
    raise timeout
273

274
275
class TimeoutConnectionError(Exception):
    """Raised by timeout_connection() when the connection alarm fires."""
276

Alexandre's avatar
Alexandre committed
277
278
class CustomException(Exception):
    """Raised for inconsistent settings, e.g. forcing IPv4 with an IPv6 literal."""
279

Alexandre's avatar
Alexandre committed
280
class Request:
    """A DNS query plus the state homer keeps about it.

    Attributes:
      message: the dns.message query.
      ok: result of the checks, True until one of them fails.
      i: request's number on the connection (0 for the first).
      data: the bytes to send, set by to_wire() or trunc_data().
    """

    def __init__(self, qname, qtype=rtype, use_edns=edns, want_dnssec=dnssec):
        # The defaults are the module-level settings as they were when this
        # class was defined.
        options = None
        if no_ecs:
            # Disable ECS (RFC 7871, section 7.1.2)
            options = [dns.edns.ECSOption(address='', srclen=0)]
        query_type = dns.rdatatype.from_text(qtype)
        self.message = dns.message.make_query(qname, query_type,
                                              use_edns=use_edns, want_dnssec=want_dnssec,
                                              options=options)
        self.message.flags |= dns.flags.AD # Ask for validation
        self.ok = True
        self.i = 0 # request's number on the connection (default to the first)

    def trunc_data(self):
        """Set self.data to the first half (rounded) of the wire-format query.

        Presumably used to send deliberately truncated queries in tests.
        """
        wire = self.message.to_wire()
        self.data = wire[:round(len(wire) / 2)]

    def to_wire(self):
        """Set self.data to the complete wire-format query."""
        self.data = self.message.to_wire()
300
301


Alexandre's avatar
Alexandre committed
302
class RequestDoT(Request):
    """A DNS request sent over DoT."""

    def check_response(self, debug=False):
        """Validate the response recorded by store_response().

        Fails when the response is flagged as not received, or when its ID
        differs from the query's. In that last case self.response is
        replaced by an error message. With debug, both IDs are appended to
        that message.

        Sets self.ok to False on failure and returns self.ok.

        Raises Exception when no response (None) was recorded.
        """
        if self.response is None:
            raise Exception("No reply received")
        if not self.rcode:
            self.ok = False
            return False
        if self.response.id != self.message.id:
            # Read the ID before self.response is replaced by the message text.
            response_id = self.response.id
            self.response = "The ID in the answer does not match the one in the query"
            if debug:
                self.response += f' (query id: {self.message.id}) (response id: {response_id})'
            self.ok = False
            return False
        return self.ok

    def store_response(self, rcode, response, size):
        """Record the parsed response and its size in bytes.

        NOTE(review): the rcode argument is ignored and self.rcode is always
        set to True; check_response() relies on self.response being None to
        detect a missing reply.
        """
        self.rcode = True
        self.response = response
        self.response_size = size
322

323

Alexandre's avatar
Alexandre committed
324
class RequestDoH(Request):
    """A DNS request sent over DoH (RFC 8484), with GET, POST or HEAD."""

    def __init__(self, qname, qtype=rtype, use_edns=edns, want_dnssec=dnssec):
        # Forward our own arguments, not the module-level rtype / edns /
        # dnssec, otherwise the caller's choices would be silently ignored.
        Request.__init__(self, qname, qtype=qtype, use_edns=use_edns, want_dnssec=want_dnssec)
        # RFC 8484, section 4.1: DoH clients SHOULD use a DNS ID of 0 in
        # every request, for HTTP cache friendliness.
        self.message.id = 0
        self.post = False
        self.head = False

    def check_response(self, debug=False):
        """Check the HTTP reply recorded on this request and decode it.

        Reads self.rcode (HTTP status), self.ctype (Content-Type),
        self.response (body), self.response_size and self.head. These are
        set elsewhere, presumably after the HTTP transfer. On success with
        a body, self.response becomes the parsed dns.message. Otherwise
        it is replaced by a human-readable description of the problem.

        debug: append up to 100 bytes of the offending data to the error
        messages.

        Sets self.ok and returns it.
        """
        ok = self.ok
        if self.rcode == 200:
            if self.ctype != "application/dns-message":
                self.response = "Content type of the response (\"%s\") invalid" % self.ctype
                ok = False
            elif not self.head:
                raw = self.response
                try:
                    response = dns.message.from_wire(raw)
                except dns.message.TrailingJunk: # Not DNS. Should
                    # not happen for a content type
                    # application/dns-message but who knows?
                    self.response = "ERROR Not proper DNS data, trailing junk"
                    if debug:
                        self.response += " \"%s\"" % raw[:100]
                    ok = False
                except dns.name.BadLabelType: # Not DNS.
                    self.response = "ERROR Not proper DNS data (wrong path in the URL?)"
                    if debug:
                        self.response += " \"%s\"" % raw[:100]
                    ok = False
                except dns.exception.DNSException: # Any other parse failure (empty body, short header...)
                    self.response = "ERROR Not proper DNS data"
                    if debug:
                        self.response += " \"%s\"" % raw[:100]
                    ok = False
                else:
                    self.response = response
            elif self.response_size == 0:
                self.response = "HEAD successful"
            else:
                data = self.response
                self.response = "ERROR Body length is not null"
                if debug:
                    self.response += "\"%s\"" % data[:100]
                ok = False
        else:
            ok = False
            if self.response_size == 0:
                self.response = "[No details]"
            # Otherwise the body is kept as the error details.
        self.ok = ok
        return ok
372

373

374
class Connection:
    """Settings shared by the DoT and DoH connections.

    server: for DoT, a host name or IP address; for DoH, an HTTPS URL.
    servername: name to check in the certificate instead of server.
    connect: address to connect to instead of resolving server (subclasses).
    forceIPv4 / forceIPv6: restrict the IP version (not both).
    dot: True for DoT, False for DoH.

    An invalid server is fatal (error()). Asking for both IP versions
    raises CustomException.
    """

    def __init__(self, server, servername=None, connect=None, forceIPv4=False, forceIPv6=False,
                 dot=dot, verbose=verbose, debug=debug, insecure=insecure):
        if dot:
            if not is_valid_hostname(server):
                error("DoT requires a host name or IP address, not \"%s\"" % server)
        elif not is_valid_url(server):
            error("DoH requires a valid HTTPS URL, not \"%s\"" % server)
        if forceIPv4 and forceIPv6:
            raise CustomException("Force IPv4 *or* IPv6 but not both")
        self.dot = dot
        self.server = server
        self.servername = servername
        # The name to validate: the explicit servername when given.
        self.check = server if servername is None else servername
        self.verbose = verbose
        self.debug = debug
        self.insecure = insecure
        self.forceIPv4 = forceIPv4
        self.forceIPv6 = forceIPv6
        self.connect_to = connect

    def __str__(self):
        return self.server

    def do_test(self, request):
        """Run one actual test; overridden by subclasses. Returns nothing."""
        pass

405
406
407

class ConnectionDoT(Connection):
    """A DNS-over-TLS connection to one resolver, on port 853.

    The constructor resolves the server (or uses *connect* as the address
    instead) and tries each address returned until a TLS session is
    established. With *pipelining*, several queries can be in flight at the
    same time (see the pipelining_* methods).
    """

    def __init__(self, server, servername=None, connect=None, forceIPv4=False, forceIPv6=False,
                 pipelining=pipelining, verbose=verbose, debug=debug, insecure=insecure):
        Connection.__init__(self, server, servername=servername, connect=connect,
                forceIPv4=forceIPv4, forceIPv6=forceIPv6, dot=True,
                verbose=verbose, debug=debug, insecure=insecure)
        # Set before connecting: connect() reads it to choose the socket mode.
        self.pipelining = pipelining
        if connect is not None:
            addr = connect
        else:
            addr = self.server
        self.family, self.repraddress = check_ip_address(self.server, dot=True)
        try:
            addrinfo_list = socket.getaddrinfo(addr, 853, self.family)
        except socket.gaierror as e:
            error(f'Could not resolve "{addr}": {e}')
        # getaddrinfo() may return the same address several times (once per
        # socket type). Keep each (address, family) pair once, in the order
        # getaddrinfo() gave them (a set would try them in arbitrary order).
        addrinfo_pairs = dict.fromkeys((addrinfo[4], addrinfo[0]) for addrinfo in addrinfo_list)
        signal.signal(signal.SIGALRM, timeout_connection)
        self.success = False
        for sockaddr, sock_family in addrinfo_pairs:
            self.hasher = hashlib.sha256()
            if self.connect(sockaddr, sock_family):
                self.success = True
                break
            if self.verbose and connect is None:
                print("Trying another IP address")
        if not self.success:
            if self.verbose and connect is None:
                print("No other IP address")
            if connect is None:
                error(f'Could not connect to "{server}"')
            else:
                print(f'Could not connect to "{server}" on {connect}')
        if pipelining:
            self.all_requests = [] # Currently, we load everything in memory
                                   # since we want to keep everything,
                                   # anyway. May be in the future, if we don't
                                   # want to keep individual results, we'll use
                                   # an iterator to fill a smaller table.
                                   # all_requests is indexed by its rank in the input file.
            self.pending = {} # pending is indexed by the query ID, and its
                              # maximum size is max_in_flight.

    @staticmethod
    def _openssl_error_details(e):
        """Format an OpenSSL.SSL.Error for display.

        Most errors carry a list of (library, function, reason) tuples, but
        SysCallError carries (errno, message), e.g. (-1, 'Unexpected EOF').
        For those, fall back to str(e) instead of crashing.
        """
        try:
            return ', '.join(err[0][2] for err in e.args)
        except (TypeError, IndexError, KeyError):
            return str(e)

    def _abort_connect(self):
        """Clean up a failed connection attempt: cancel the alarm, close the socket. Returns False."""
        signal.alarm(0)
        self.sock.close()
        return False

    def connect(self, addr, sock_family):
        """Try to establish the TLS session with one address.

        addr: socket address, as returned by getaddrinfo().
        sock_family: its address family.

        TCP connection and TLS handshake are bounded by a SIGALRM of
        TIMEOUT_CONN seconds. Returns True on success. On a connection or
        handshake failure, the alarm is cancelled, the socket closed and
        False returned. A certificate or key mismatch is fatal (error()).
        """
        signal.alarm(TIMEOUT_CONN)
        self.addr = addr
        self.sock = socket.socket(sock_family, socket.SOCK_STREAM)
        if self.verbose:
            print("Connecting to %s ..." % str(self.addr))
        # With typical DoT servers, we *must* use TLS 1.2 (otherwise,
        # do_handshake fails with "OpenSSL.SSL.SysCallError: (-1, 'Unexpected
        # EOF')" Typical HTTP servers are more lax.
        self.context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_2_METHOD)
        if self.insecure:
            self.context.set_verify(OpenSSL.SSL.VERIFY_NONE, lambda *x: True)
        else:
            self.context.set_default_verify_paths()
            self.context.set_verify_depth(4) # Seems ignored
            self.context.set_verify(OpenSSL.SSL.VERIFY_PEER | OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT | \
                                    OpenSSL.SSL.VERIFY_CLIENT_ONCE,
                                    lambda conn, cert, errno, depth, preverify_ok: preverify_ok)
        self.session = OpenSSL.SSL.Connection(self.context, self.sock)
        if sni:
            self.session.set_tlsext_host_name(canonicalize(self.check).encode()) # Server Name Indication (SNI)
        try:
            self.session.connect(self.addr)
            # TODO We may here have exceptions such as OpenSSL.SSL.ZeroReturnError
            self.session.do_handshake()
        except TimeoutConnectionError:
            if self.verbose:
                print("Timeout")
            return self._abort_connect()
        except OSError:
            if self.verbose:
                print("Cannot connect")
            return self._abort_connect()
        except OpenSSL.SSL.Error as e:
            if self.verbose:
                print(f"OpenSSL error: {self._openssl_error_details(e)}")
            return self._abort_connect()
        # RFC 7858, section 4.2 and appendix A
        self.cert = self.session.get_peer_certificate()
        self.publickey = self.cert.get_pubkey()
        if self.debug or key is not None:
            self.hasher.update(OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1,
                                                             self.publickey))
            self.digest = self.hasher.digest()
            key_string = base64.standard_b64encode(self.digest).decode()
        if self.debug:
            print("Certificate #%x for \"%s\", delivered by \"%s\"" % \
                  (self.cert.get_serial_number(),
                   self.cert.get_subject().commonName,
                   self.cert.get_issuer().commonName))
            print("Public key is pin-sha256=\"%s\"" % \
                  key_string)
        if not self.insecure:
            if key is None:
                valid = validate_hostname(self.check, self.cert)
                if not valid:
                    error("Certificate error: \"%s\" is not in the certificate" % (self.check))
            else:
                if key_string != key:
                    error("Key error: expected \"%s\", got \"%s\"" % (key, key_string))
        signal.alarm(0)
        if self.pipelining:
            self.sock.settimeout(TIMEOUT_READ)
        return True

    def end(self):
        """Shut the TLS session down and close it."""
        self.session.shutdown()
        self.session.close()

    def send_data(self, data, dump=False):
        """Send one DNS message, preceded by its two-byte big-endian length.

        sendall() is used because send() may write only part of the buffer.
        """
        if dump:
            dump_data(data, 'data sent')
        length = len(data)
        self.session.sendall(length.to_bytes(2, byteorder='big') + data)

    def _recv_exactly(self, count):
        """Read *count* bytes, looping on short reads; stops early if recv() returns nothing."""
        chunks = []
        while count > 0:
            chunk = self.session.recv(count)
            if not chunk:
                break
            chunks.append(chunk)
            count -= len(chunk)
        return b"".join(chunks)

    def receive_data(self, dump=False):
        """Read one length-prefixed DNS message and parse it.

        Returns (True, message, size) or, when no data is available yet
        (OpenSSL.SSL.WantReadError on the length field), (False, None, None).
        """
        try:
            buf = self.session.recv(2)
        except OpenSSL.SSL.WantReadError:
            return (False, None, None)
        # recv() may return fewer bytes than requested (a large message can
        # span several TLS records), so complete both reads explicitly.
        buf += self._recv_exactly(2 - len(buf))
        size = int.from_bytes(buf, byteorder='big')
        buf = self._recv_exactly(size)
        if dump:
            dump_data(buf, 'data recv')
        response = dns.message.from_wire(buf)
        return (True, response, size)

    def send_and_receive(self, request, dump=False):
        """Send request.data, wait for the reply and store it on the request."""
        self.send_data(request.data, dump=dump)
        rcode, response, size = self.receive_data(dump=dump)
        request.store_response(rcode, response, size)

    def do_test(self, request, synchronous=True):
        """Send request.data; when synchronous, also read and check the reply."""
        self.send_data(request.data)
        if synchronous:
            rcode, response, size = self.receive_data()
            request.store_response(rcode, response, size)
            request.check_response(self.debug)

    def pipelining_add_request(self, request):
        """Queue a request for pipelining; its response is not known yet."""
        self.all_requests.append({'request': request, 'response': None}) # No answer yet

    def pipelining_fill_pending(self, index):
        """Send the request of rank *index*, if any, and record it as pending."""
        if index < len(self.all_requests):
            request = self.all_requests[index]['request']
            query_id = request.message.id
            # TODO check there is no duplicate in IDs
            self.pending[query_id] = (False, index, request)
            self.do_test(request, synchronous = False)

    def pipelining_init_pending(self, max_in_flight):
        """Send the first requests, at most *max_in_flight* of them.

        Returns len(self.all_requests) when there are fewer requests than
        max_in_flight, otherwise max_in_flight - 1 (the last index sent).
        NOTE(review): raises UnboundLocalError if max_in_flight <= 0.
        """
        for i in range(0, max_in_flight):
            if i == len(self.all_requests):
                break
            self.pipelining_fill_pending(i)
        return i

    def read_result(self, connection, requests):
        """Read one pipelined response and record it.

        requests: the pending table, indexed by query ID, whose entries are
        (over, rank, request) tuples; the matching entry is marked over.
        The *connection* argument is not used.

        Returns the query ID of the response, or None when nothing was
        available. Raises Exception for a response with an unknown ID.
        """
        rcode, response, size = self.receive_data() # TODO can raise
                                                    # OpenSSL.SSL.ZeroReturnError
                                                    # if the
                                                    # connection was
                                                    # closed
        if not rcode:
            if display_results:
                print("TIMEOUT")
            return None
        query_id = response.id
        if query_id not in requests:
            raise Exception("Received response for ID %s which is unexpected" % query_id)
        over, rank, request = requests[query_id]
        self.all_requests[rank]['response'] = (rcode, response, size)
        requests[query_id] = (True, rank, request)
        if display_results:
            print()
            print(response)
        # TODO a timeout if some responses are lost?
        return query_id
584

585
586
def create_handle(connection):
    """Create a pycurl easy handle configured for a DoH *connection*.

    The handle asks for HTTP/2 and carries the options common to every
    query: verbosity, certificate checking, IP version, forced address to
    connect to, and the application/dns-message headers. Two helpers are
    attached to it as plain attributes, so they take the handle explicitly:
    reset_opt_default(handle) and prepare(handle, connection, request).
    """
    def reset_opt_default(handle):
        """Reset the per-query options (NOBODY, POST, POSTFIELDS, URL)."""
        opts = {
                pycurl.NOBODY: False,
                pycurl.POST: False,
                pycurl.POSTFIELDS: '',
                pycurl.URL: ''
               }
        for opt, value in opts.items():
            handle.setopt(opt, value)

    def prepare(handle, connection, request):
        """Set the handle up for one request: method, URL and fresh output buffer."""
        if not connection.multistreams:
            handle.reset_opt_default(handle)
        if request.post:
            handle.setopt(pycurl.POST, True)
            handle.setopt(pycurl.POSTFIELDS, request.data)
            handle.setopt(pycurl.URL, connection.server)
        else:
            handle.setopt(pycurl.HTTPGET, True) # automatically sets CURLOPT_NOBODY to 0
            if request.head:
                handle.setopt(pycurl.NOBODY, True)
            # RFC 8484 GET: query in base64url, without padding
            dns_req = base64.urlsafe_b64encode(request.data).decode('UTF8').rstrip('=')
            handle.setopt(pycurl.URL, connection.server + ("?dns=%s" % dns_req))
        handle.buffer = io.BytesIO()
        handle.setopt(pycurl.WRITEDATA, handle.buffer)
        handle.request = request

    handle = pycurl.Curl()
    # Does not work if pycurl was not compiled with nghttp2 (recent Debian
    # packages are OK) https://github.com/pycurl/pycurl/issues/477
    handle.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
    if connection.debug:
        handle.setopt(pycurl.VERBOSE, True)
    if connection.insecure:
        handle.setopt(pycurl.SSL_VERIFYPEER, False)
        handle.setopt(pycurl.SSL_VERIFYHOST, False)
    if connection.forceIPv4:
        handle.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
    if connection.forceIPv6:
        handle.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V6)
    if connection.connect is not None:
        family, repraddress = check_ip_address(connection.connect, dot=False)
        # Connect to the forced address but on the port of the URL, which
        # is not necessarily 443.
        try:
            port = urllib.parse.urlparse(connection.server).port or 443
        except ValueError: # Invalid port in the URL; curl will complain anyway
            port = 443
        handle.setopt(pycurl.CONNECT_TO, [f'::{repraddress}:{port}',])
    handle.setopt(pycurl.HTTPHEADER,
            ["Accept: application/dns-message", "Content-type: application/dns-message"])
    handle.reset_opt_default = reset_opt_default
    handle.prepare = prepare
    return handle
634
635


636
637
class ConnectionDoH(Connection):
    """A DNS-over-HTTPS connection, implemented with pycurl.

    Two modes are supported:

    * default: a single curl handle (``self.curl_handle``, built by
      create_handle()) is reused for every request, synchronously;
    * ``multistreams=True``: every request gets its own curl handle, added
      to a ``pycurl.CurlMulti`` that allows only one connection per host,
      so the requests share that connection (create_handle() asks for
      HTTP/2). Transfers are driven by perform_multi() and results are
      collected either there or by read_results(), depending on the
      module-level ``sync`` flag.
    """

    def __init__(self, server, servername=None, connect=None, forceIPv4=False, forceIPv6=False,
                 multistreams=False, verbose=verbose, debug=debug, insecure=insecure):
        # The verbose/debug/insecure defaults are the module-level values
        # at class-definition time.
        Connection.__init__(self, server, servername=servername, connect=connect,
                forceIPv4=forceIPv4, forceIPv6=forceIPv6, dot=False,
                verbose=verbose, debug=debug, insecure=insecure)
        self.url = server
        self.connect = connect  # read by create_handle() to set CURLOPT_CONNECT_TO
        self.multistreams = multistreams
        if self.multistreams:
            self.multi = self.create_multi()
            self.all_handles = []  # every handle created by do_test(synchronous=False)
        else:
            self.curl_handle = create_handle(self)

    def create_multi(self):
        """Return a CurlMulti limited to one connection per host."""
        multi = pycurl.CurlMulti()
        multi.setopt(pycurl.M_MAX_HOST_CONNECTIONS, 1)
        return multi

    def end(self):
        """Release the curl resources held by this connection."""
        if not self.multistreams:
            self.curl_handle.close()
        else:
            self.remove_handles()
            self.multi.close()

    def remove_handles(self):
        """Detach and close the handles whose completion is still queued in the multi."""
        _, handle_success, handle_fail = self.multi.info_read()
        # pycurl reports failed transfers as (curl, errno, errmsg) tuples,
        # not as bare Curl objects: keep only the handle.
        handles = handle_success + [curl for curl, _errno, _errmsg in handle_fail]
        for h in handles:
            # libcurl requires the easy handle to be removed from the multi
            # before it is cleaned up.
            self.multi.remove_handle(h)
            h.close()

    def _read_completed(self):
        """Process every transfer the multi reports as successfully finished.

        Failed transfers reported by info_read() are ignored here.
        """
        _, handle_pass, _ = self.multi.info_read()
        for handle in handle_pass:
            self.read_result_handle(handle)

    def perform_multi(self):
        """Run all transfers attached to the multi until none is left.

        When the module-level ``sync`` flag is false, results are processed
        (and their handles released) as soon as transfers complete;
        otherwise they are left for read_results().
        """
        while True:
            ret, num_handles = self.multi.perform()
            if ret != pycurl.E_CALL_MULTI_PERFORM:
                break
        while num_handles:
            ret = self.multi.select(1.0)
            if ret == -1:
                continue
            while True:
                ret, num_handles = self.multi.perform()
                if not sync:
                    self._read_completed()
                if ret != pycurl.E_CALL_MULTI_PERFORM:
                    break
        if not sync:
            self._read_completed()

    def send(self, handle):
        """Perform the transfer of ``handle``, writing the body into a fresh buffer.

        A pycurl error is reported through error() with its message.
        """
        handle.buffer = io.BytesIO()
        handle.setopt(pycurl.WRITEDATA, handle.buffer)
        try:
            handle.perform()
        except pycurl.error as e:
            error(e.args[1])

    def receive(self, handle):
        """Copy the results of a finished transfer into ``handle.request``.

        Sets ``request.response`` (raw body bytes), ``response_size``,
        ``rcode`` (the HTTP status code) and ``ctype`` (the Content-Type, or
        None), stores the total and pre-transfer times (in seconds) on the
        handle, and closes the body buffer.
        """
        request = handle.request
        body = handle.buffer.getvalue()
        body_size = len(body)
        http_code = handle.getinfo(pycurl.RESPONSE_CODE)
        handle.time = handle.getinfo(pycurl.TOTAL_TIME)
        handle.pretime = handle.getinfo(pycurl.PRETRANSFER_TIME)
        try:
            content_type = handle.getinfo(pycurl.CONTENT_TYPE)
        except TypeError: # This is the exception we get if there is no Content-Type: (for intance in response to HEAD requests)
            content_type = None
        request.response = body
        request.response_size = body_size
        request.rcode = http_code
        request.ctype = content_type
        handle.buffer.close()

    def send_and_receive(self, handle, dump=False):
        """Perform ``handle`` synchronously and store its results in its request.

        ``dump`` is accepted for interface compatibility but unused here.
        """
        self.send(handle)
        self.receive(handle)

    def read_result_handle(self, handle):
        """Process one finished multi transfer, then detach and close its handle.

        Updates the per-HTTP-status counter in ``self.finished['http']``
        (``self.finished`` is not set in this class; presumably initialised
        elsewhere — confirm) and optionally prints timing and the response.
        """
        self.receive(handle)
        handle.request.check_response()
        if show_time:
            self.print_time(handle)
        try:
            self.finished['http'][handle.request.rcode] += 1
        except KeyError:
            self.finished['http'][handle.request.rcode] = 1
        if display_results:
            print("Return code %s (%.2f ms):" % (handle.request.rcode,
                (handle.time - handle.pretime) * 1000))
            print(f"{handle.request.response}\n")
        # Detach from the multi before cleaning up the easy handle.
        self.multi.remove_handle(handle)
        handle.close()

    def read_results(self):
        """Process the results of every handle created in multistreams mode."""
        for handle in self.all_handles:
            self.read_result_handle(handle)

    def print_time(self, handle):
        """Print request number, HTTP status, pre-transfer, total and transfer times in ms."""
        print(f'{handle.request.i:3d}', end='   ')
        print(f'({handle.request.rcode})', end='   ')
        print(f'{handle.pretime * 1000:8.3f} ms', end='  ')
        print(f'{handle.time * 1000:8.3f} ms', end='  ')
        print(f'{(handle.time - handle.pretime) * 1000:8.3f} ms')

    def do_test(self, request, synchronous=True):
        """Send ``request``.

        Synchronously, the shared handle is used and the response is read
        and checked immediately. Otherwise a new handle is created, recorded
        in ``self.all_handles`` and added to the multi; the transfer then
        happens in perform_multi().
        """
        if synchronous:
            handle = self.curl_handle
        else:
            handle = create_handle(self)
            self.all_handles.append(handle)
        handle.prepare(handle, self, request)
        if synchronous:
            self.send_and_receive(handle)
            request.check_response(self.debug)
        else:
            self.multi.add_handle(handle)
Alexandre's avatar
Alexandre committed
760

Alexandre's avatar
Alexandre committed
761
762
763
764
765
766
767
768
769
770
771
772
773

def get_next_domain(input_file):
    """Read the next query from ``input_file``.

    Each line holds a domain name, optionally followed by a query type,
    separated by whitespace (e.g. ``example.org MX``); the type defaults
    to AAAA. A last line without a trailing newline is handled like any
    other. When no usable line is left (end of file or blank line), error()
    is called to report that the input holds fewer queries than the number
    of tests requested. A line with more than two fields raises ValueError.

    Returns a ``(name, rtype)`` tuple of strings.
    """
    # split() with no argument ignores the newline, surrounding blanks and
    # tabs, so the name is never truncated.
    fields = input_file.readline().split()
    if not fields:
        error("Not enough data in %s for the %i tests" % (ifile, tests))
    if len(fields) == 1:
        return fields[0], 'AAAA'
    name, rtype = fields
    return name, rtype

774
def print_result(connection, request, prefix=None, display_err=True):
    """Report the outcome of ``request`` sent over ``connection``.

    The request succeeded when ``request.rcode`` is truthy (DoT) or equals
    200 (DoH, where it holds the HTTP status code).

    Outside monitoring mode, a success may print timing and the response,
    and a failure writes an error line to stderr when ``display_err`` is
    true, prefixed by the checked server (in check mode) and by ``prefix``.
    In monitoring mode, one status line is printed on stdout and the
    process exits with STATE_OK or STATE_CRITICAL.

    Returns ``request.ok`` on success and False on failure.
    """
    verdict = request.ok
    is_dot = connection.dot
    server = connection.server
    status = request.rcode
    body = request.response
    received = request.response_size

    succeeded = bool(status) if is_dot else status == 200

    if succeeded and not monitoring:
        if not is_dot and show_time:
            connection.print_time(connection.curl_handle)
        if display_results and (not check or verbose):
            print(body)
        return verdict

    if succeeded:
        # Monitoring mode: a single status line, then exit.
        if expect is not None and expect not in str(body):
            print("%s Cannot find \"%s\" in response" % (server, expect))
            sys.exit(STATE_CRITICAL)
        if verdict and received is not None and received > 0:
            print("%s OK - %s" % (server, "No error for %s/%s, %i bytes received" % (name, rtype, received)))
        else:
            print("%s OK - %s" % (server, "No error"))
        sys.exit(STATE_OK)

    if monitoring:
        if is_dot:
            print("%s Error - %i: %s" % (server, status, body))
        else:
            print("%s HTTP error - %i: %s" % (server, status, body))
        sys.exit(STATE_CRITICAL)

    if display_err:
        if check:
            print(connection.connect_to, end=': ', file=sys.stderr)
        if prefix:
            print(prefix, end=': ', file=sys.stderr)
        if is_dot:
            print("Error: %s" % body, file=sys.stderr)
        else:
            try:
                body = body.decode()
            except (UnicodeDecodeError, AttributeError):
                pass  # The body may be undecodable (binary, Latin-1) or not bytes at all
            print("HTTP error %i: %s" % (status, body), file=sys.stderr)
    return False

821
def create_request(qname, qtype=rtype, use_edns=edns, want_dnssec=dnssec, dot=dot, trunc=False):
    """Build a DoT or DoH request for ``qname``/``qtype`` and serialise it.

    The defaults of ``qtype``, ``use_edns``, ``want_dnssec`` and ``dot`` are
    the module-level values at definition time, so later changes to those
    globals do not affect them; callers wanting current values must pass
    them explicitly.

    When ``trunc`` is true, ``request.trunc_data()`` is called instead of
    ``request.to_wire()`` (presumably producing a deliberately truncated
    message for run_check_trunc — confirm in the Request classes).

    Returns the RequestDoT or RequestDoH instance.
    """
    request_class = RequestDoT if dot else RequestDoH
    # Pass the requested query type, not the global ``rtype``.
    request = request_class(qname, qtype, use_edns, want_dnssec)
    if trunc:
        request.trunc_data()
    else:
        request.to_wire()
    return request

def create_requests_list(dot=dot, **req_args):
    """Return the default tests for a DoT or DoH server.

    DoT entries are ``(test name, request, mandatory level)``; DoH entries
    are ``(test name, request, HTTP method, mandatory level)``. ``req_args``
    is forwarded to create_request().
    """
    if not dot:
        # RFC 8484, section 4.1: GET and POST are required ("legal").
        # HEAD is not mentioned there, so it is only "nice to have".
        return [
            ('Test GET', create_request(**req_args), DOH_GET,
             mandatory_levels["legal"]),
            ('Test POST', create_request(**req_args), DOH_POST,
             mandatory_levels["legal"]),
            ('Test HEAD', create_request(**req_args), DOH_HEAD,
             mandatory_levels["nicetohave"]),
        ]
    # RFC 7858, section 3.3: a server SHOULD accept several requests on one
    # connection, hence the second request.
    # TODO we miss the tests of pipelining and out-of-order.
    return [
        ('Test 1', create_request(dot=dot, **req_args),
         mandatory_levels["legal"]),
        ('Test 2', create_request(dot=dot, **req_args),
         mandatory_levels["necessary"]),
    ]
851

852
def run_check_default(connection):
    """Run the default tests from create_requests_list() on ``connection``.

    For DoT, each request object is sent as is. For DoH, the request is
    flagged for POST or HEAD as needed and sent through the connection's
    shared curl handle. Testing stops at the first failure (or at a
    CustomException). A failed test only makes the result False when its
    mandatory level is at least the global ``mandatory_level``; in that case
    the error details are printed.

    Returns True when no sufficiently mandatory test failed.
    """
    overall = True
    query = {'qname': name, 'qtype': rtype, 'use_edns': edns, 'want_dnssec': dnssec}
    for entry in create_requests_list(dot=dot, **query):
        if dot:
            test_name, request, mandatory = entry
        else:
            test_name, request, method, mandatory = entry
        if verbose:
            print(test_name)

        if dot:
            target = request
        else:
            if method == DOH_POST:
                request.post = True
            elif method == DOH_HEAD:
                request.head = True
            target = connection.curl_handle
            target.prepare(target, connection, request)

        try:
            connection.send_and_receive(target)
        except CustomException as exc:
            overall = False
            error(exc)
            break

        request.check_response(debug)
        passed = print_result(connection, request, prefix=test_name, display_err=False)
        if not passed and mandatory >= mandatory_level:
            # Report again, this time showing the error.
            print_result(connection, request, prefix=test_name, display_err=True)
            overall = False
        if verbose:
            print()
        if not passed:
            break
    return overall
890

891
892
893
894
895
def run_check_mime(connection, accept="application/dns-message", content_type="application/dns-message"):
    """Check that the DoH server answers a request carrying the given media types.

    One request is sent on the connection's shared curl handle with the
    headers ``Accept: <accept>`` and ``Content-type: <content_type>``. The
    default ``application/dns-message`` headers are then put back on the
    handle, even if the check raised, so later checks are not affected.
    For DoT there are no HTTP headers and the check trivially passes.

    Returns False when sending raised a CustomException or print_result()
    reported a failure, True otherwise.
    """
    if dot:
        return True
    ok = True
    header = [f"Accept: {accept}", f"Content-type: {content_type}"]
    if verbose:
        print(f'Test mime: {", ".join(header)}')
    req_args = { 'qname': name, 'qtype': rtype, 'use_edns': edns, 'want_dnssec': dnssec }
    request = create_request(**req_args)
    handle = connection.curl_handle
    default = "application/dns-message"
    default_header = [f"Accept: {default}", f"Content-type: {default}"]
    handle.setopt(pycurl.HTTPHEADER, header)
    try:
        handle.prepare(handle, connection, request)
        try:
            connection.send_and_receive(handle)
        except CustomException as e:
            ok = False
            error(e)
        request.check_response(debug)
        if not print_result(connection, request, prefix=f"Test Header {', '.join(header)}"):
            ok = False
    finally:
        # The handle is shared with the other checks: always restore the
        # standard DoH headers.
        handle.setopt(pycurl.HTTPHEADER, default_header)
    if verbose:
        print()
    return ok

919
920
921
922
923
924
925
926
def run_check_trunc(connection):
    ok = True
    test_name = 'Test truncated data'
    if verbose:
        print(test_name)
    req_args = { 'qname': name, 'qtype': rtype, 'use_edns': edns, 'want_dnssec': dnssec }
    if dot:
        request = create_request(dot=dot, trunc=True, **req_args)
927
        bundle = request
928
929
930
    else:
        request = create_request(trunc=True, **req_args)
        request.post = True
931
932
933
        handle = connection.curl_handle
        handle.prepare(handle, connection, request)
        bundle = handle
934
    try:
935
        # 8.8.8.8 replies FORMERR but most DoT servers violently shut down the connection (which is legal)
Alexandre's avatar
Alexandre committed
936
        connection.send_and_receive(bundle, dump=debug)
937
938
939
    except CustomException as e:
        ok = False
        error(e)
940
941
    except OpenSSL.SSL.ZeroReturnError: # This is acceptable
        return ok
942
943
944
945
946
    except dns.exception.FormError: # This is also acceptable
        # Some DSN resolvers will echo mangled requests with
        # the RCODE set to FORMERR
        # so response can not be parsed in this case
        return ok
Alexandre's avatar
Alexandre committed
947
    if request.check_response(debug): # FORMERR is expected
Alexandre's avatar
Alexandre committed
948
949
950
951
        if dot:
            ok = request.rcode == dns.rcode.FORMERR
        else:
            ok = (request.response.rcode() == dns.rcode.FORMERR)
952
    else:
Alexandre's avatar
Alexandre committed
953
954
955
956
957
        if dot:
            ok = False
        else: # a 400 response's status is acceptable
            ok = (request.rcode >= 400 and request.rcode < 500)
    print_result(connection, request, prefix=test_name, display_err=not ok)
Alexandre's avatar
Alexandre committed
958
959
    if verbose:
        print()