homer.py 52.5 KB
Newer Older
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
1
2
#!/usr/bin/env python3

3
4
5
6
7
# Homer is a DoH (DNS-over-HTTPS) and DoT (DNS-over-TLS) client. Its
# main purpose is to test DoH and DoT resolvers. Reference site is
# <https://framagit.org/bortzmeyer/homer/> See author, documentation,
# etc, there, or in the README.md included with the distribution.

Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
8
9
10
11
12
13
# http://pycurl.io/docs/latest
import pycurl

# http://www.dnspython.org/
import dns.message

14
15
16
# https://github.com/drkjam/netaddr/
import netaddr

17
18
19
20
21
# October 2019: the Python GnuTLS bindings don't work with Python 3, so we use OpenSSL.
# https://www.pyopenssl.org/
# https://pyopenssl.readthedocs.io/
import OpenSSL

Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
22
23
24
25
26
27
import io
import sys
import base64
import getopt
import urllib.parse
import time
28
29
import socket
import ctypes
30
import re
31
import os.path
32
33
import hashlib
import base64
34
import signal
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
35

36
# Values that can be changed from the command line
dot = False # DoH by default
verbose = False
debug = False
insecure = False
post = False
head = False

dnssec = False
edns = True
no_ecs = True # Send an ECS option that disables ECS (--ecs turns it off)
sni = True
rtype = 'AAAA'
vhostname = None
tests = 1 # Number of repeated tests
key = None # SPKI pin: base64 of the SHA-256 of the server public key (DoT)
ifile = None # Input file
delay = None # Seconds to wait between synchronous requests (--delay)

forceIPv4 = False
forceIPv6 = False
connectTo = None
pipelining = False
max_in_flight = 20 # Maximum number of pending pipelined DoT requests
multistreams = False
sync = False
display_results = True
show_time = False
check = False

mandatory_level = None
check_additional = True

doh_header = [] # Custom header fields given with --header
doh_header_default = ["Accept: application/dns-message", "Content-type: application/dns-message"]

# Monitoring plugin only:
host = None
path = None
expect = None
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
71

72
73
74
# Do not change these
# A host name (letters, digits, hyphens, dots) or an IP address literal
# (IPv4, or IPv6 with its hexadecimal digits). The alternatives are grouped
# so that ^ and $ anchor every one of them, not just the first and last.
re_host = re.compile(r'^(?:[0-9a-z][0-9a-z.-]*|[0-9a-f:.]+)$')

75
76
77
78
79
80
81
# For the monitoring plugin (exit statuses)
STATE_OK = 0
STATE_WARNING = 1
STATE_CRITICAL = 2
STATE_UNKNOWN = 3
STATE_DEPENDENT = 4

# For the check option
DOH_GET = 0
DOH_POST = 1
DOH_HEAD = 2

# Is the test mandatory?
mandatory_levels = {"legal": 30, "necessary": 20, "nicetohave": 10}

TIMEOUT_CONN = 2 # Seconds allowed to connect (DoT, via signal.alarm)
TIMEOUT_READ = 1 # Socket read timeout in seconds (DoT pipelining)
SLEEP_TIMEOUT = 0.5
MAX_DURATION = 10
93

94
def error(msg=None, exit=True):
    """Report an error and, by default, terminate the program.

    In monitoring mode the message goes to stdout, prefixed by the URL,
    and the exit status is STATE_CRITICAL. Otherwise it goes to stderr,
    followed by 'KO' on stdout when --check is active, and the exit
    status is 1. With exit=False the program continues.
    """
    if msg is None:
        msg = "Unknown error"
    if monitoring:
        print("%s: %s" % (url, msg))
        if exit:
            sys.exit(STATE_CRITICAL)
    else:
        print(msg, file=sys.stderr)
        if check:
            print('KO')
        if exit:
            sys.exit(1)
Alexandre's avatar
Alexandre committed
107

Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
108
109
110
def usage(msg=None):
    """Print the command-line help on stderr, preceded by *msg* if given."""
    if msg:
        print(msg, file=sys.stderr)
    print("Usage: %s [options] url-or-servername [domain-name [DNS type]]" % sys.argv[0], file=sys.stderr)
    print("""Options
    -t --dot            Use DoT (by default use DoH)
    -P --post --POST    Use HTTP POST method for all transfers (DoH only)
    -e --head --HEAD    Use HTTP HEAD method for all transfers (DoH only)
    -H --header <header>
                        Use custom header field defined as 'key: value',
                        define as many --header <header> as required (DoH only)
    -r --repeat <N>     Perform N times the query. If used with -f, read up to
                        <N> rows of the <file>.
    -d --delay <T>      Time to wait in seconds between each synchronous
                        request (only with --repeat)
    -f --file <file>    Read domain names from <file>, one per row with an
                        optional DNS type
    --check             Perform a set of predefined tests.
    --mandatory-level <level>
                        Define the <level> of test to perform (only with
                        --check)
                        Available <level> : legal, necessary, nicetohave
    --multistreams      Use HTTP/2 streams, needs an input file with -f
                        (DoH only)
    --sync              Process received queries synchronously (only with
                        --multistreams)
    --no-display-results
                        Disable output of DNS response (only with
                        --multistreams)
    --time              Display the time elapsed for the query (only with
                        --multistreams)
    --dnssec            Request DNSSEC data (signatures)
    --noedns            Disable EDNS, default is to indicate EDNS support
    --ecs               Send ECS to authoritative servers, default is to
                        refuse it
    --key <key>         Authenticate a DoT resolver with its public <key> in
                        base64 (DoT only)
    --nosni             Do not perform SNI (DoT only)
    -V --vhost <vhost>  Use a specific virtual host
    -k --insecure       Do not check the certificate
    -4 --v4only         Force IPv4 resolution of url-or-servername
    -6 --v6only         Force IPv6 resolution of url-or-servername
    -v --verbose        Make the program more talkative
    --debug             Make the program even more talkative than -v
    -h --help           Print this message

    url-or-servername   The URL or domain name of the DoT/DoH server
    domain-name         The domain name to resolve, not required if -f is
                        provided
    DNS type            The DNS record type to resolve, default AAAA
    """, file=sys.stderr)
    print("See the README.md for more details.", file=sys.stderr)
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
160

161
def is_valid_hostname(name):
    """Return a match object if *name* looks like a host name or an IP
    address literal, None otherwise.

    The name is canonicalized first. A name that cannot be IDNA-encoded
    (e.g. "a..b", with an empty label) is reported as invalid instead of
    letting the UnicodeError escape.
    """
    try:
        name = canonicalize(name)
    except UnicodeError:
        return None
    return re_host.search(name)
164

165
166
167
168
169
170
171
172
173
def canonicalize(hostname):
    """Return *hostname* lower-cased, IDNA-encoded (A-labels) and without
    its trailing dot, if any.

    Raises:
        UnicodeError: if the name cannot be IDNA-encoded (an empty label,
            as in "a..b", for instance).
    """
    result = hostname.lower()
    # TODO handle properly the case where it fails with UnicodeError
    # (two consecutive dots for instance) to get a custom exception
    result = result.encode('idna').decode()
    # endswith() instead of indexing the last character: the name may be empty
    if result.endswith('.'):
        result = result[:-1]
    return result

174
175
176
177
def is_valid_ip_address(addr):
    """Tell whether *addr* is an IP address literal (as parsed by netaddr).

    Returns (True, version), version being 4 or 6, or (False, None).
    """
    try:
        baddr = netaddr.IPAddress(addr)
    except netaddr.core.AddrFormatError:
        return (False, None)
    return (True, baddr.version)
180

181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
def is_valid_url(url):
    """Return True if *url* is an https URL with a non-empty host part.

    This is a very poor validation: many errors (for instance whitespaces,
    IPv6 address litterals without brackets...) are not detected. Only a
    ValueError raised by the parser makes the URL invalid on its own.
    """
    try:
        parsed = urllib.parse.urlparse(url)
    except ValueError:
        return False
    is_https = parsed.scheme == "https"
    has_host = parsed.netloc != ""
    return is_https and has_host

def get_certificate_san(x509cert):
    """Return the text form of the subjectAltName extension of *x509cert*.

    The text is a ", "-separated list such as "DNS:a.example, IP Address:
    192.0.2.1". Returns "" when the certificate has no such extension; if
    it has several, the last one is returned.
    """
    extensions = (x509cert.get_extension(index)
                  for index in range(x509cert.get_extension_count()))
    found = [str(extension) for extension in extensions
             if "subjectAltName" in str(extension.get_short_name())]
    return found[-1] if found else ""

199
200
201
202
203
204
205
206
207
208
209
# Try one possible name. Names must be already canonicalized.
def match_hostname(hostname, possibleMatch):
    """Return True if *hostname* matches the certificate name *possibleMatch*.

    A wildcard name ("*.example.com") stands for exactly one non-empty
    left-most label (RFC 6125, section 6.4.3). It matches
    "foo.example.com" but neither "example.com" nor "bar.foo.example.com".
    RFC 6125 says that we MAY accept left-most labels with wildcards
    included (foo*bar); we don't do it here. Any other name must be equal.
    """
    if possibleMatch.startswith("*."): # Wildcard
        parent = possibleMatch[2:] # Skip "*."
        first, separator, rest = hostname.partition(".")
        # A one-label name, or the parent domain itself, never matches
        return separator == "." and first != "" and rest == parent
    return hostname == possibleMatch

# Try all the names in the certificate
def validate_hostname(hostname, cert):
    """Return True if *hostname* (a host name or an IP address literal) is
    one of the identities presented by the certificate *cert*.

    The subjectAltName entries are tried first. "DNS:" entries are checked
    against a host name (wildcards allowed, see match_hostname()), and
    "IP Address:" entries against an IP address. The Common Name is only a
    fallback, used when the subjectAltName contains no DNS name
    (RFC 6125, section 6.4.4). A certificate without a Common Name does
    not raise.
    """
    # Complete specification is in RFC 6125. It is long and
    # complicated and I'm not sure we do it perfectly.
    (is_addr, _) = is_valid_ip_address(hostname)
    host_i = netaddr.IPAddress(hostname) if is_addr else None
    hostname = canonicalize(hostname)
    has_dns_id = False
    for alt_name in get_certificate_san(cert).split(", "):
        if alt_name.startswith("DNS:"):
            has_dns_id = True
            if is_addr:
                continue
            try:
                base = canonicalize(alt_name[len("DNS:"):])
            except UnicodeError:
                continue # Ignore names that cannot be IDNA-encoded
            if match_hostname(hostname, base):
                return True
        elif alt_name.startswith("IP Address:") and is_addr:
            base = alt_name[len("IP Address:"):]
            if base.endswith("\n"):
                base = base[:-1]
            try:
                base_i = netaddr.IPAddress(base)
            except netaddr.core.AddrFormatError:
                continue # Ignore broken IP addresses in certificates. Are we too liberal?
            if host_i == base_i:
                return True
        # Other alternative name types are ignored. May be accept URI
        # alternative names for DoH?
    # According to RFC 6125, we MUST NOT try the Common Name before the
    # Subject Alternative Names, nor at all when they contain a DNS name.
    if has_dns_id:
        return False
    common_name = cert.get_subject().commonName
    if common_name is None: # No CN in the subject
        return False
    try:
        cn = canonicalize(common_name)
    except UnicodeError:
        return False
    return match_hostname(hostname, cn)

251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
def check_ip_address(addr, dot=dot):
    """Find the socket family to use for *addr* and how to display it.

    *addr* must be an IP address literal unless *dot* is true, in which
    case a host name is accepted too. The global forceIPv4/forceIPv6
    options select the family and conflict with a literal of the other
    family.

    Returns (family, repraddress). family is socket.AF_INET,
    socket.AF_INET6, or 0 when unspecified. repraddress is *addr*, put in
    brackets only if it is an IPv6 literal (e.g. for curl's CONNECT_TO);
    a host name is never bracketed, even when IPv6 is forced.

    Raises:
        CustomException: for a non-address without *dot*, or a literal
            conflicting with the forced family.
    """
    (is_addr, family) = is_valid_ip_address(addr)
    if not is_addr and not dot:
        raise CustomException("%s is not IPv4 and not IPv6" % addr)
    if forceIPv4 and family == 6:
        raise CustomException("You cannot force IPv4 with a litteral IPv6 address (%s)" % addr)
    elif forceIPv6 and family == 4:
        raise CustomException("You cannot force IPv6 with a litteral IPv4 address (%s)" % addr)
    repraddress = f'[{addr}]' if family == 6 else addr
    if forceIPv4 or family == 4:
        return (socket.AF_INET, repraddress)
    if forceIPv6 or family == 6:
        return (socket.AF_INET6, repraddress)
    return (0, repraddress)

270
271
272
273
274
275
def dump_data(data, text="data"):
    """Print *data* as is, then its bytes in hexadecimal and in binary.

    The "hex:" and "bin:" lines are indented so that their labels end in
    the same column as "<text>:" on the first line.
    """
    indent = ' ' * (len(text) - 4)
    print(f'{text}: ', data)
    for label, spec in (('hex:', '02x'), ('bin:', '08b')):
        print(indent, label, " ".join(format(byte, spec) for byte in data))

276
277
278
279
280
281
class TimeoutConnectionError(Exception):
    """Raised when a connection attempt exceeds its time budget."""


def timeout_connection(signum, frame):
    """SIGALRM handler: abort the connection attempt in progress.

    ConnectionDoT installs it with signal.signal(); the exception raised
    here interrupts the pending connect/handshake, where it is caught.
    """
    raise TimeoutConnectionError('Connection timeout')

Alexandre's avatar
Alexandre committed
282
283
284
class CustomException(Exception):
    """Error detected by Homer itself in its arguments or options, e.g.
    forcing both IPv4 and IPv6, or an invalid address literal."""

Alexandre's avatar
Alexandre committed
285
class Request:
    """A DNS query built with dnspython, plus the outcome of sending it."""

    def __init__(self, qname, qtype=rtype, use_edns=edns, want_dnssec=dnssec):
        """Build the query for *qname* and the type name *qtype* ('AAAA'...).

        When the global no_ecs is set, an ECS option disabling ECS is
        added. The AD bit is set to ask for validation.
        """
        if no_ecs:
            # Disable ECS (RFC 7871, section 7.1.2)
            options = [dns.edns.ECSOption(address='', srclen=0)]
        else:
            options = None
        self.message = dns.message.make_query(qname, dns.rdatatype.from_text(qtype),
                                              use_edns=use_edns, want_dnssec=want_dnssec,
                                              options=options)
        self.message.flags |= dns.flags.AD # Ask for validation
        self.ok = True
        self.i = 0 # request's number on the connection (default to the first)

    def trunc_data(self):
        """Set self.data to the first half of the wire-format query
        (presumably to test how servers handle a truncated message)."""
        self.data = self.message.to_wire()
        half = round(len(self.data) / 2)
        self.data = self.data[:half]

    def to_wire(self):
        """Set self.data to the wire-format query, ready to be sent."""
        self.data = self.message.to_wire()

Alexandre's avatar
Alexandre committed
306
307

class RequestDoT(Request):
    """A DNS request sent over DoT."""

    def check_response(self, debug=False):
        """Check the stored response and return whether it is acceptable.

        On failure self.ok is set to False. For an ID mismatch,
        self.response is replaced by an explanation, which includes both
        IDs when *debug* is true.

        Raises:
            Exception: if no response was stored.
        """
        if self.response is None:
            raise Exception("No reply received")
        if not self.rcode:
            self.ok = False
            return False
        if self.response.id != self.message.id:
            # Keep the ID before self.response is replaced by a string
            response_id = self.response.id
            self.response = "The ID in the answer does not match the one in the query"
            if debug:
                self.response += f' (query id: {self.message.id}) (response id: {response_id})'
            self.ok = False
            return False
        return self.ok

    def store_response(self, rcode, response, size):
        """Record *response* (a dns.message.Message, or None if nothing was
        read) and its *size* in bytes.

        NOTE(review): *rcode* is not stored; self.rcode is always True, and
        a missing reply is detected through response being None.
        """
        self.rcode = True
        self.response = response
        self.response_size = size
327

328

Alexandre's avatar
Alexandre committed
329
class RequestDoH(Request):
    """A DNS request sent over DoH (GET, POST or HEAD)."""

    def __init__(self, qname, qtype=rtype, use_edns=edns, want_dnssec=dnssec):
        # Forward the caller's EDNS/DNSSEC choices instead of the globals
        Request.__init__(self, qname, qtype=qtype, use_edns=use_edns, want_dnssec=want_dnssec)
        # RFC 8484, section 4.1: DoH clients SHOULD use a DNS ID of 0, for
        # HTTP cache friendliness
        self.message.id = 0
        self.post = False
        self.head = False

    def check_response(self, debug=False):
        """Check the stored HTTP outcome and return the updated self.ok.

        Here self.rcode is the HTTP status and self.ctype the content type.
        Success requires status 200 and "application/dns-message". Then,
        for GET/POST, the body must parse as DNS (the parsed message
        replaces self.response); for HEAD, the body must be empty. On
        failure self.response becomes an error text (with up to the raw
        data appended when *debug* is true), except for a non-200 status
        with a body, which is kept as is.
        """
        ok = self.ok
        if self.rcode == 200:
            if self.ctype != "application/dns-message":
                self.response = "Content type of the response (\"%s\") invalid" % self.ctype
                ok = False
            elif not self.head:
                raw = self.response # Keep the raw body for the error messages
                try:
                    response = dns.message.from_wire(raw)
                except dns.message.TrailingJunk: # Not DNS. Should
                    # not happen for a content type
                    # application/dns-message but who knows?
                    self.response = "ERROR Not proper DNS data, trailing junk"
                    if debug:
                        self.response += " \"%s\"" % raw
                    ok = False
                except dns.name.BadLabelType: # Not DNS.
                    self.response = "ERROR Not proper DNS data (wrong path in the URL?)"
                    if debug:
                        self.response += " \"%s\"" % raw[:100]
                    ok = False
                else:
                    self.response = response
            elif self.response_size == 0:
                self.response = "HEAD successful"
            else:
                data = self.response
                self.response = "ERROR Body length is not null"
                if debug:
                    self.response += "\"%s\"" % data[:100]
                ok = False
        else:
            ok = False
            if self.response_size == 0:
                self.response = "[No details]"
        self.ok = ok
        return ok
377

378

379
class Connection:
    """Settings shared by DoT and DoH connections to one server."""

    def __init__(self, server, servername=None, connect=None, forceIPv4=False, forceIPv6=False,
                 dot=dot, verbose=verbose, debug=debug, insecure=insecure):
        """Validate *server* (a host name/address for DoT, an HTTPS URL for
        DoH) and store the settings. An invalid server is reported with
        error(), which terminates the program.

        Raises:
            CustomException: if both IPv4 and IPv6 are forced.
        """
        if dot and not is_valid_hostname(server):
            error("DoT requires a host name or IP address, not \"%s\"" % server)
        if not dot and not is_valid_url(server):
            error("DoH requires a valid HTTPS URL, not \"%s\"" % server)
        if forceIPv4 and forceIPv6:
            raise CustomException("Force IPv4 *or* IPv6 but not both")
        self.dot = dot
        self.server = server
        self.servername = servername
        # Name checked against the certificate (and sent as SNI by DoT)
        self.check = servername if servername is not None else server
        self.verbose = verbose
        self.debug = debug
        self.insecure = insecure
        self.forceIPv4 = forceIPv4
        self.forceIPv6 = forceIPv6
        self.connect_to = connect

    def __str__(self):
        return self.server

    def do_test(self, request):
        # Routine doing one actual test. Returns nothing
        pass

410
411
412

class ConnectionDoT(Connection):
    """A DoT connection to one server, on port 853.

    The constructor connects right away, trying in turn each address of
    the server (or of *connect*, when given).
    """

    def __init__(self, server, servername=None, connect=None, forceIPv4=False, forceIPv6=False,
                 pipelining=pipelining, verbose=verbose, debug=debug, insecure=insecure):
        Connection.__init__(self, server, servername=servername, connect=connect,
                            forceIPv4=forceIPv4, forceIPv6=forceIPv6, dot=True,
                            verbose=verbose, debug=debug, insecure=insecure)
        # Set before connecting: connect() reads self.pipelining
        self.pipelining = pipelining
        if pipelining:
            self.all_requests = [] # Currently, we load everything in memory
                                   # since we want to keep everything,
                                   # anyway. May be in the future, if we don't
                                   # want to keep individual results, we'll use
                                   # an iterator to fill a smaller table.
                                   # all_requests is indexed by its rank in the input file.
            self.pending = {} # pending is indexed by the query ID, and its
                              # maximum size is max_in_flight.
        if connect is not None:
            addr = connect
        else:
            addr = self.server
        self.family, self.repraddress = check_ip_address(self.server, dot=True)
        addrinfo_list = socket.getaddrinfo(addr, 853, self.family)
        # Remove duplicates but keep getaddrinfo()'s order, which is the
        # system's order of preference (RFC 6724 on most systems); a set
        # would shuffle the addresses.
        candidates = list(dict.fromkeys((info[4], info[0]) for info in addrinfo_list))
        signal.signal(signal.SIGALRM, timeout_connection)
        self.success = False
        for sockaddr, family in candidates:
            self.hasher = hashlib.sha256()
            if self.connect(sockaddr, family):
                self.success = True
                break
            if self.verbose and connect is None:
                print("Trying another IP address")
        if not self.success:
            if self.verbose and connect is None:
                print("No other IP address")
            if connect is None:
                error(f'Could not connect to "{server}"')
            else:
                print(f'Could not connect to "{server}" on {connect}')

    def connect(self, addr, sock_family):
        """Open a TLS session to the socket address *addr*, then
        authenticate the server.

        Returns False if this address cannot be used (timeout, socket or
        TLS error); the socket is then closed. A certificate name or key
        mismatch is reported through error(). Returns True on success.
        """
        self.addr = addr
        signal.alarm(TIMEOUT_CONN)
        try:
            connected = self._open_session(sock_family)
        finally:
            # Always cancel the alarm: left armed after a failed attempt,
            # it would fire later, anywhere in the program.
            signal.alarm(0)
        if not connected:
            self.sock.close()
            return False
        self._authenticate()
        if self.pipelining:
            self.sock.settimeout(TIMEOUT_READ)
        return True

    def _open_session(self, sock_family):
        """Create the socket and perform the TCP connection and the TLS
        handshake. Returns True on success, False on a handled failure."""
        self.sock = socket.socket(sock_family, socket.SOCK_STREAM)
        if self.verbose:
            print("Connecting to %s ..." % str(self.addr))
        # With typical DoT servers, we *must* use TLS 1.2 (otherwise,
        # do_handshake fails with "OpenSSL.SSL.SysCallError: (-1, 'Unexpected
        # EOF')" Typical HTTP servers are more lax.
        self.context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_2_METHOD)
        if self.insecure:
            self.context.set_verify(OpenSSL.SSL.VERIFY_NONE, lambda *x: True)
        else:
            self.context.set_default_verify_paths()
            self.context.set_verify_depth(4) # Seems ignored
            self.context.set_verify(OpenSSL.SSL.VERIFY_PEER | OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT |
                                    OpenSSL.SSL.VERIFY_CLIENT_ONCE,
                                    lambda conn, cert, errno, depth, preverify_ok: preverify_ok)
        self.session = OpenSSL.SSL.Connection(self.context, self.sock)
        if sni:
            self.session.set_tlsext_host_name(canonicalize(self.check).encode()) # Server Name Indication (SNI)
        try:
            self.session.connect(self.addr)
            # TODO We may here have exceptions such as OpenSSL.SSL.ZeroReturnError
            self.session.do_handshake()
        except TimeoutConnectionError:
            if self.verbose:
                print("Timeout")
            return False
        except OSError:
            if self.verbose:
                print("Cannot connect")
            return False
        except OpenSSL.SSL.Error as e:
            if self.verbose:
                print(f"OpenSSL error: {self._describe_ssl_error(e)}")
            return False
        return True

    @staticmethod
    def _describe_ssl_error(exc):
        """Return a readable text for any OpenSSL.SSL.Error.

        Most errors carry a list of (library, function, reason) tuples, but
        SysCallError carries (errno, message), which cannot be indexed
        like the tuples.
        """
        reasons = []
        for arg in exc.args:
            if isinstance(arg, list):
                reasons.extend(str(entry[2]) if isinstance(entry, tuple) and len(entry) > 2
                               else str(entry)
                               for entry in arg)
            else:
                reasons.append(str(arg))
        return ', '.join(reasons) or repr(exc)

    def _authenticate(self):
        """Check the server certificate (name) or public key (pin)."""
        # RFC 7858, section 4.2 and appendix A
        self.cert = self.session.get_peer_certificate()
        self.publickey = self.cert.get_pubkey()
        key_string = None
        if self.debug or key is not None:
            self.hasher.update(OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1,
                                                             self.publickey))
            self.digest = self.hasher.digest()
            key_string = base64.standard_b64encode(self.digest).decode()
        if self.debug:
            print("Certificate #%x for \"%s\", delivered by \"%s\"" % \
                  (self.cert.get_serial_number(),
                   self.cert.get_subject().commonName,
                   self.cert.get_issuer().commonName))
            print("Public key is pin-sha256=\"%s\"" % \
                  key_string)
        if not self.insecure:
            if key is None:
                if not validate_hostname(self.check, self.cert):
                    error("Certificate error: \"%s\" is not in the certificate" % (self.check))
            elif key_string != key:
                error("Key error: expected \"%s\", got \"%s\"" % (key, key_string))

    def end(self):
        """Shut down the TLS session and close the connection."""
        self.session.shutdown()
        self.session.close()

    def send_data(self, data, dump=False):
        """Send the wire-format message *data* with its two-byte length
        prefix (as for DNS over TCP, RFC 1035 section 4.2.2)."""
        if dump:
            dump_data(data, 'data sent')
        length = len(data)
        # sendall(): send() may write only part of the buffer
        self.session.sendall(length.to_bytes(2, byteorder='big') + data)

    def _recv_exactly(self, size):
        """Read *size* bytes from the TLS session, fewer only if it closes."""
        import select # Only needed when a read must wait for more data
        data = bytearray()
        while len(data) < size:
            try:
                chunk = self.session.recv(size - len(data))
            except OpenSSL.SSL.WantReadError:
                # The socket has a timeout in pipelining mode: wait for the
                # rest of the message, at most TIMEOUT_READ seconds.
                readable, _, _ = select.select([self.sock], [], [], TIMEOUT_READ)
                if not readable:
                    raise
                continue
            if not chunk: # Connection closed
                break
            data += chunk
        return bytes(data)

    def receive_data(self, dump=False):
        """Read one length-prefixed DNS message.

        Returns (True, message, size), or (False, None, None) when the
        length prefix cannot be read yet (OpenSSL.SSL.WantReadError).
        """
        try:
            prefix = self.session.recv(2)
        except OpenSSL.SSL.WantReadError:
            return (False, None, None)
        # TLS may deliver the prefix and the message in several pieces:
        # read until all the announced bytes are there.
        prefix += self._recv_exactly(2 - len(prefix))
        size = int.from_bytes(prefix, byteorder='big')
        buf = self._recv_exactly(size)
        if dump:
            dump_data(buf, 'data recv')
        response = dns.message.from_wire(buf)
        return (True, response, size)

    def send_and_receive(self, request, dump=False):
        """Send *request* and store the response in it (no checking)."""
        self.send_data(request.data, dump=dump)
        rcode, response, size = self.receive_data(dump=dump)
        request.store_response(rcode, response, size)

    def do_test(self, request, synchronous=True):
        """Send *request*; when *synchronous*, also read, store and check
        the response."""
        self.send_data(request.data)
        if synchronous:
            rcode, response, size = self.receive_data()
            request.store_response(rcode, response, size)
            request.check_response(self.debug)

    def pipelining_add_request(self, request):
        """Queue *request* for pipelining; it has no answer yet."""
        self.all_requests.append({'request': request, 'response': None}) # No answer yet

    def pipelining_fill_pending(self, index):
        """Send request number *index*, if it exists, and mark it pending."""
        if index < len(self.all_requests):
            request = self.all_requests[index]['request']
            query_id = request.message.id
            # TODO check there is no duplicate in IDs
            self.pending[query_id] = (False, index, request)
            self.do_test(request, synchronous = False)

    def pipelining_init_pending(self, max_in_flight):
        """Send the first requests, at most *max_in_flight* of them.

        Returns the loop index reached: max_in_flight - 1 when the window
        was filled, len(self.all_requests) when there were fewer requests,
        -1 when max_in_flight <= 0.
        NOTE(review): the first two cases differ by one; check the caller.
        """
        i = -1
        for i in range(0, max_in_flight):
            if i == len(self.all_requests):
                break
            self.pipelining_fill_pending(i)
        return i

    def read_result(self, connection, requests):
        """Read one pipelined response and record it.

        *requests* is the pending table: query ID -> (over, rank, request).
        *connection* is unused. Returns the ID of the answered query, or
        None when nothing could be read (shown as TIMEOUT).

        Raises:
            Exception: if the response ID is not in *requests*.
        """
        rcode, response, size = self.receive_data() # TODO can raise
                                                    # OpenSSL.SSL.ZeroReturnError
                                                    # if the
                                                    # connection was
                                                    # closed
        if not rcode:
            if display_results:
                print("TIMEOUT")
            return None
        response_id = response.id
        if response_id not in requests:
            raise Exception("Received response for ID %s which is unexpected" % response_id)
        over, rank, request = requests[response_id]
        self.all_requests[rank]['response'] = (rcode, response, size)
        requests[response_id] = (True, rank, request)
        if display_results:
            print()
            print(response)
        # TODO a timeout if some responses are lost?
        return response_id
589

590
def create_handle(connection, header=doh_header_default):
    """Create a pycurl handle set up for DoH queries on *connection*.

    *header* is the list of HTTP header fields sent by default. The handle
    gets two helper attributes:
    - reset_opt_default(handle) clears the per-request options;
    - prepare(handle, connection, request) loads *request* (POST, GET or
      HEAD) and attaches a fresh response buffer (handle.buffer) and the
      request itself (handle.request).
    """
    def reset_opt_default(handle):
        opts = {
                pycurl.NOBODY: False,
                pycurl.POST: False,
                pycurl.POSTFIELDS: '',
                pycurl.URL: ''
               }
        for opt, value in opts.items():
            handle.setopt(opt, value)

    def prepare(handle, connection, request):
        if not connection.multistreams:
            handle.reset_opt_default(handle)
        if request.post:
            handle.setopt(pycurl.POST, True)
            handle.setopt(pycurl.POSTFIELDS, request.data)
            handle.setopt(pycurl.URL, connection.server)
        else:
            handle.setopt(pycurl.HTTPGET, True) # automatically sets CURLOPT_NOBODY to 0
            if request.head:
                handle.setopt(pycurl.NOBODY, True)
            # RFC 8484, section 4.1: base64url without padding
            dns_req = base64.urlsafe_b64encode(request.data).decode('UTF8').rstrip('=')
            handle.setopt(pycurl.URL, connection.server + ("?dns=%s" % dns_req))
        if hasattr(request, 'header') and len(request.header) > 0: # overwrite default header
            handle.setopt(pycurl.HTTPHEADER, request.header)
        else:
            # Restore the default, so that a custom header set for a
            # previous request on this handle is not sent again
            handle.setopt(pycurl.HTTPHEADER, header)
        handle.buffer = io.BytesIO()
        handle.setopt(pycurl.WRITEDATA, handle.buffer)
        handle.request = request

    handle = pycurl.Curl()
    # Does not work if pycurl was not compiled with nghttp2 (recent Debian
    # packages are OK) https://github.com/pycurl/pycurl/issues/477
    handle.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
    if connection.debug:
        handle.setopt(pycurl.VERBOSE, True)
    if connection.insecure:
        handle.setopt(pycurl.SSL_VERIFYPEER, False)
        handle.setopt(pycurl.SSL_VERIFYHOST, False)
    if connection.forceIPv4:
        handle.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
    if connection.forceIPv6:
        handle.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V6)
    if connection.connect is not None:
        _, repraddress = check_ip_address(connection.connect, dot=False)
        # Connect to the given address on the URL's port (443 if none)
        try:
            port = urllib.parse.urlparse(connection.server).port or 443
        except ValueError: # Invalid port in the URL, curl will complain
            port = 443
        handle.setopt(pycurl.CONNECT_TO, [f'::{repraddress}:{port}',])
    handle.setopt(pycurl.HTTPHEADER, header)
    handle.reset_opt_default = reset_opt_default
    handle.prepare = prepare
    return handle
640

641
642
643

class ConnectionDoH(Connection):
    """DoH connection driven by pycurl.

    In the default (synchronous) mode a single reusable easy handle is kept
    in ``self.curl_handle``.  With ``multistreams=True`` every query gets its
    own easy handle (created by create_handle()), all attached to one
    ``CurlMulti`` limited to a single connection per host; HTTP return codes
    of completed transfers are counted in ``self.finished['http']``.
    """

    def __init__(self, server, servername=None, connect=None, forceIPv4=False, forceIPv6=False,
                 multistreams=False, verbose=verbose, debug=debug, insecure=insecure):
        # NOTE: the verbose/debug/insecure defaults are the module-level
        # values at the time the class is defined.
        Connection.__init__(self, server, servername=servername, connect=connect,
                forceIPv4=forceIPv4, forceIPv6=forceIPv6, dot=False,
                verbose=verbose, debug=debug, insecure=insecure)
        self.url = server
        self.connect = connect
        self.multistreams = multistreams
        if self.multistreams:
            self.multi = self.create_multi()
            self.all_handles = []
            self.finished = { 'http': {} }
        else:
            self.curl_handle = create_handle(self)

    def create_multi(self):
        """Return a CurlMulti allowed to open at most one connection per host."""
        multi = pycurl.CurlMulti()
        multi.setopt(pycurl.M_MAX_HOST_CONNECTIONS, 1)
        return multi

    def init_multi(self):
        """Warm up the multistreams connection with a single throw-away query.

        A first query is performed alone so that the connection is
        established before the other queries are started; otherwise they
        would all start simultaneously.  The root NS is queried because it
        should not impact the resolver cache.  The bookkeeping
        (all_handles, finished) is reset afterwards so the warm-up query is
        not counted.
        """
        if verbose:
            print("Establishing multistreams connection...")
        request = create_request('.', qtype='NS', dot=False)
        try:
            self.do_test(request, synchronous=False)
        except (OpenSSL.SSL.Error, CustomException) as e:
            error(e)
        self.perform_multi(silent=True)
        self.all_handles = []
        self.finished = { 'http': {} }

    def end(self):
        """Release the curl resources held by this connection."""
        if not self.multistreams:
            self.curl_handle.close()
        else:
            self.remove_handles()
            self.multi.close()

    def remove_handles(self):
        """Detach and close every handle still reported by info_read()."""
        n, handle_success, handle_fail = self.multi.info_read()
        for h in handle_success + handle_fail:
            # libcurl requires removing an easy handle from its multi handle
            # before cleaning it up.
            self.multi.remove_handle(h)
            h.close()

    def _read_completed(self, silent):
        """Process the successfully completed transfers reported by info_read().

        Failed transfers (the third element of info_read()) are not
        processed here, as in the original loop.
        """
        n, handle_pass, handle_fail = self.multi.info_read()
        for handle in handle_pass:
            self.read_result_handle(handle, silent=silent)

    def perform_multi(self, silent=False):
        """Drive the transfers attached to self.multi until none is running.

        Unless ``sync`` is set, completed transfers are read as soon as
        libcurl reports them.  ``sync`` is looked up as a module-level name
        (presumably a command-line option; it is not defined in this class).
        silent: passed to read_result_handle() to suppress output.
        """
        while True:
            ret, num_handles = self.multi.perform()
            if ret != pycurl.E_CALL_MULTI_PERFORM:
                break
        while num_handles:
            ret = self.multi.select(1.0)
            if ret == -1:
                continue
            while True:
                ret, num_handles = self.multi.perform()
                if not sync:
                    self._read_completed(silent)
                if ret != pycurl.E_CALL_MULTI_PERFORM:
                    break
        if not sync:
            self._read_completed(silent)

    def send(self, handle):
        """Perform the transfer for handle, collecting the body in a fresh buffer.

        pycurl errors are reported through error() with their message.
        """
        handle.buffer = io.BytesIO()
        handle.setopt(pycurl.WRITEDATA, handle.buffer)
        try:
            handle.perform()
        except pycurl.error as e:
            error(e.args[1])

    def receive(self, handle):
        """Copy the body, HTTP status, timings and Content-Type onto handle.request."""
        request = handle.request
        body = handle.buffer.getvalue()
        body_size = len(body)
        http_code = handle.getinfo(pycurl.RESPONSE_CODE)
        handle.time = handle.getinfo(pycurl.TOTAL_TIME)
        handle.pretime = handle.getinfo(pycurl.PRETRANSFER_TIME)
        try:
            content_type = handle.getinfo(pycurl.CONTENT_TYPE)
        except TypeError: # This is the exception we get if there is no Content-Type: (for instance in response to HEAD requests)
            content_type = None
        request.response = body
        request.response_size = body_size
        request.rcode = http_code
        request.ctype = content_type
        handle.buffer.close()

    def send_and_receive(self, handle, dump=False):
        """Send the request prepared on handle and store its response.

        dump is accepted for interface compatibility and is not used here.
        """
        self.send(handle)
        self.receive(handle)

    def read_result_handle(self, handle, silent=False):
        """Collect the result of one finished multi transfer, then release it.

        The HTTP return code is counted in self.finished['http']; timing
        and the response are printed unless silent is true (and subject to
        the show_time / display_results module-level options).
        """
        self.receive(handle)
        # Same debug flag as the synchronous path in do_test().
        handle.request.check_response(self.debug)
        if not silent and show_time:
            self.print_time(handle)
        http_counts = self.finished['http']
        rcode = handle.request.rcode
        http_counts[rcode] = http_counts.get(rcode, 0) + 1
        if not silent and display_results:
            print("Return code %s (%.2f ms):" % (handle.request.rcode,
                (handle.time - handle.pretime) * 1000))
            print(f"{handle.request.response}\n")
        # Detach from the multi handle before cleaning up the easy handle.
        self.multi.remove_handle(handle)
        handle.close()

    def read_results(self):
        """Read the result of every handle created by asynchronous do_test() calls.

        NOTE(review): read_result_handle() closes each handle, so this is
        presumably only meant for runs where perform_multi() did not already
        read them (i.e. when sync is set) — confirm.
        """
        for handle in self.all_handles:
            self.read_result_handle(handle)

    def print_time(self, handle):
        """Print query index, HTTP code, pre-transfer, total and transfer times (ms)."""
        print(f'{handle.request.i:3d}', end='   ')
        print(f'({handle.request.rcode})', end='   ')
        print(f'{handle.pretime * 1000:8.3f} ms', end='  ')
        print(f'{handle.time * 1000:8.3f} ms', end='  ')
        print(f'{(handle.time - handle.pretime) * 1000:8.3f} ms')

    def do_test(self, request, synchronous=True):
        """Run request now (synchronous) or queue it on the multi handle.

        Synchronous tests reuse self.curl_handle and check the response
        immediately; asynchronous ones get a new handle, recorded in
        self.all_handles and added to self.multi for perform_multi().
        """
        if synchronous:
            handle = self.curl_handle
        else:
            handle = create_handle(self)
            self.all_handles.append(handle)
        handle.prepare(handle, self, request)
        if synchronous:
            self.send_and_receive(handle)
            request.check_response(self.debug)
        else:
            self.multi.add_handle(handle)
784
785
786
787
788
789
790
791
792
793
794
795
796


def get_next_domain(input_file):
    """Read the next query from input_file and return (name, rtype).

    Each line holds a domain name, optionally followed by whitespace
    (spaces or tabs) and a record type; the type defaults to AAAA.  An
    exhausted or blank line is reported through error().  A line with more
    than two fields raises ValueError.
    """
    line = input_file.readline()
    # split() ignores the trailing newline (if any) and surrounding spaces,
    # so a last line without "\n" keeps all its characters.
    fields = line.split()
    if not fields:
        error("Not enough data in %s for the %i tests" % (ifile, tests))
        # Only reached if error() returns instead of exiting.
        return 'framagit.org', 'AAAA'
    if len(fields) == 1:
        return fields[0], 'AAAA'
    name, rtype = fields
    return name, rtype
Alexandre's avatar
Alexandre committed
797

798
def print_result(connection, request, prefix=None, display_err=True):
    """Report the outcome of one test and tell whether it succeeded.

    Success means a truthy rcode for DoT, or HTTP status 200 for DoH.
    In monitoring mode this never returns: it prints a one-line status and
    exits with STATE_OK, or STATE_CRITICAL on failure or when the
    ``expect`` string is missing from the response.  Otherwise, on success
    it optionally prints timing and the response; on failure, when
    display_err is true, it prints an error description on stderr.

    Returns request.ok on success, False on failure.
    """
    ok = request.ok
    is_dot = connection.dot
    server = connection.server
    rcode = request.rcode
    msg = request.response
    size = request.response_size

    if (is_dot and rcode) or (not is_dot and rcode == 200):
        if monitoring:
            if expect is not None and expect not in str(request.response):
                print("%s Cannot find \"%s\" in response" % (server, expect))
                sys.exit(STATE_CRITICAL)
            if ok and size is not None and size > 0:
                # name / rtype are the module-level query settings.
                detail = "No error for %s/%s, %i bytes received" % (name, rtype, size)
            else:
                detail = "No error"
            print("%s OK - %s" % (server, detail))
            sys.exit(STATE_OK)
        if not is_dot and show_time:
            connection.print_time(connection.curl_handle)
        if display_results and (not check or verbose):
            print(msg)
        return ok

    # Failure path.
    if monitoring:
        template = "%s Error - %i: %s" if is_dot else "%s HTTP error - %i: %s"
        print(template % (server, rcode, msg))
        sys.exit(STATE_CRITICAL)
    if display_err:
        if check:
            print(connection.connect_to, end=': ', file=sys.stderr)
        if prefix:
            print(prefix, end=': ', file=sys.stderr)
        if is_dot:
            print("Error: %s" % msg, file=sys.stderr)
        else:
            try:
                msg = msg.decode()
            except (UnicodeDecodeError, AttributeError):
                pass  # Sometimes the body is binary, Latin-1, or not bytes at all
            print("HTTP error %i: %s" % (rcode, msg), file=sys.stderr)
    return False

845
def create_request(qname, qtype=rtype, use_edns=edns, want_dnssec=dnssec, dot=dot, trunc=False):
    """Build a DoT or DoH request for qname/qtype with its wire data ready.

    The default values are the module-level settings captured when this
    function is defined.  With trunc=True the request's trunc_data() is used
    instead of to_wire(), producing the truncated data used by the
    truncation test.
    """
    request_class = RequestDoT if dot else RequestDoH
    request = request_class(qname, qtype, use_edns, want_dnssec)
    build_wire = request.trunc_data if trunc else request.to_wire
    build_wire()
    return request

def create_requests_list(dot=dot, **req_args):
    """Return the default list of tests to run.

    For DoT each entry is (test_name, request, mandatory_level); for DoH it
    is (test_name, request, method, mandatory_level).  req_args are passed
    to create_request().
    """
    if dot:
        return [
            ('Test 1', create_request(dot=dot, **req_args),
             mandatory_levels["legal"]),
            # RFC 7858, section 3.3: the server SHOULD accept several
            # requests on one connection.
            # TODO we miss the tests of pipelining and out-of-order.
            ('Test 2', create_request(dot=dot, **req_args),
             mandatory_levels["necessary"]),
        ]
    return [
        # GET and POST: RFC 8484, section 4.1.
        ('Test GET', create_request(**req_args), DOH_GET,
         mandatory_levels["legal"]),
        ('Test POST', create_request(**req_args), DOH_POST,
         mandatory_levels["legal"]),
        # The HEAD method is not mentioned in RFC 8484 (see section 4.1),
        # so it is just "nice to have".
        ('Test HEAD', create_request(**req_args), DOH_HEAD,
         mandatory_levels["nicetohave"]),
    ]
875

876
def run_check_default(connection):
    """Run the default test set against connection; return True if it passes.

    Tests come from create_requests_list().  The run stops at the first
    CustomException or at the first test print_result() reports as failed.
    A failed test only makes the result False (and is printed with its
    error) when its mandatory level is at least mandatory_level.
    """
    ok = True
    req_args = { 'qname': name, 'qtype': rtype, 'use_edns': edns, 'want_dnssec': dnssec }
    pending = create_requests_list(dot=dot, **req_args)
    for pack in pending:
        if dot:
            test_name, request, mandatory = pack
        else:
            test_name, request, method, mandatory = pack
        if verbose:
            print(test_name)
        if not dot:
            if method == DOH_POST:
                request.post = True
            elif method == DOH_HEAD:
                request.head = True
            handle = connection.curl_handle
            handle.prepare(handle, connection, request)
        # DoT connections send the request itself, DoH ones the curl handle.
        bundle = request if dot else handle
        try:
            connection.send_and_receive(bundle)
        except CustomException as e:
            ok = False
            error(e)
            break
        request.check_response(debug)
        passed = print_result(connection, request, prefix=test_name, display_err=False)
        if not passed and mandatory >= mandatory_level:
            print_result(connection, request, prefix=test_name, display_err=True)
            ok = False
        if verbose:
            print()
        if not passed:
            break
    return ok
914

915
def run_check_mime(connection, header):
    """Send a DoH query with custom HTTP headers and report the outcome.

    Based on RFC 8484 only application/dns-message must be supported; other
    MIME types may be supported too, but nothing is said about them.  This
    test changes the headers (e.g. the MIME value) and sees what happens.

    header: list of header lines given to pycurl.HTTPHEADER.
    Returns True for DoT (test not applicable), otherwise whether the
    test passed.
    """
    if dot:
        return True
    ok = True
    if verbose:
        test_name = f'Test mime: {", ".join(header)}'
        print(test_name)
    req_args = { 'qname': name, 'qtype': rtype, 'use_edns': edns, 'want_dnssec': dnssec }
    request = create_request(**req_args)
    handle = connection.curl_handle
    handle.setopt(pycurl.HTTPHEADER, header)
    try:
        handle.prepare(handle, connection, request)
        try:
            connection.send_and_receive(handle)
        except CustomException as e:
            ok = False
            error(e)
        request.check_response(debug)
        if not print_result(connection, request, prefix=f"Test Header {', '.join(header)}"):
            ok = False
    finally:
        # connection.curl_handle is reused by the other checks: always put
        # the default headers back, even if something above raised.
        handle.setopt(pycurl.HTTPHEADER, doh_header_default)
    if verbose:
        print()
    return ok

943
def run_check_trunc(connection):
Alexandre's avatar
Alexandre committed
944
945
946
947
    # send truncated DNS request to the server and expect a HTTP return code
    # either equal to 200 or in the 400 range
    # in case the server answers with 200, look for a FORMERR error in the DNS
    # response
948
949
950
951
952
953
954
    ok = True
    test_name = 'Test truncated data'
    if verbose:
        print(test_name)
    req_args = { 'qname': name, 'qtype': rtype, 'use_edns': edns, 'want_dnssec': dnssec }
    if dot:
        request = create_request(dot=dot, trunc=True, **req_args)
955
        bundle = request
956
957
958
    else:
        request = create_request(trunc=True, **req_args)
        request.post = True
959
960
961
        handle = connection.curl_handle
        handle.prepare(handle, connection, request)
        bundle = handle
962
    try:
963
        # 8.8.8.8 replies FORMERR but most DoT servers violently shut down the connection (which is legal)
Alexandre's avatar
Alexandre committed
964
        connection.send_and_receive(bundle, dump=debug)
965
966
967
    except CustomException as e:
        ok = False
        error(e)
968
969
    except OpenSSL.SSL.ZeroReturnError: # This is acceptable
        return ok
970
971
972
973
974
    except dns.exception.FormError: # This is also acceptable
        # Some DSN resolvers will echo mangled requests with
        # the RCODE set to FORMERR
        # so response can not be parsed in this case
        return ok
Alexandre's avatar
Alexandre committed
975
    if request.check_response(debug): # FORMERR is expected
Alexandre's avatar
Alexandre committed