#!/usr/bin/env python3

# Homer is a DoH (DNS-over-HTTPS) and DoT (DNS-over-TLS) client. Its
# main purpose is to test DoH and DoT resolvers. The reference site is
# <https://framagit.org/bortzmeyer/homer/>. See the author, documentation,
# etc. there, or in the README.md included with the distribution.

Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
8
9
10
11
12
13
import io
import sys
import base64
import getopt
import urllib.parse
import time
14
15
import socket
import ctypes
16
import re
17
import os.path
18
19
import hashlib
import base64
20
import signal
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
21

Alexandre's avatar
Alexandre committed
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
try:
    # http://pycurl.io/docs/latest
    import pycurl

    # http://www.dnspython.org/
    import dns.message

    # https://github.com/drkjam/netaddr/
    import netaddr

    # Octobre 2019: the Python GnuTLS bindings don't work with Python 3. So we use OpenSSL.
    # https://www.pyopenssl.org/
    # https://pyopenssl.readthedocs.io/
    import OpenSSL
except ImportError as e:
    print("Error: missing module")
    print(e)
    sys.exit(1)

41
# Values that can be changed from the command line
Alexandre's avatar
Alexandre committed
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class opts:
    # Settings that can be changed from the command line. They are stored
    # as class attributes and read through the class itself (opts.check...).

    # Transport and general output
    dot = False  # DoH by default
    verbose = False
    debug = False
    insecure = False

    # DoH HTTP method selection
    post = False
    head = False

    # DNS query construction
    dnssec = False
    edns = True
    no_ecs = True
    rtype = 'AAAA'

    # TLS and connection parameters
    sni = True
    vhostname = None
    key = None  # SPKI
    forceIPv4 = False
    forceIPv6 = False
    connectTo = None

    # Query scheduling
    tests = 1  # Number of repeated tests
    ifile = None  # Input file
    delay = None
    pipelining = False
    max_in_flight = 20
    multistreams = False

    # Output and checks
    display_results = True
    show_time = False
    check = False
    mandatory_level = None
    check_additional = True

    # Monitoring plugin only:
    host = None
    path = None
    expect = None
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
74

75
76
77
# Do not change these
# Accepted server names, matched after canonicalize() (lower-cased, IDNA
# encoded, trailing dot removed): either a host name made of letters,
# digits, dots and hyphens (this also covers IPv4 literals), or an IPv6
# literal (hex digits, colons, and dots for an embedded IPv4 part).
# The alternation is grouped so that "^" and "$" apply to every
# alternative: the whole string must match, not just a prefix or suffix.
re_host = re.compile(r'^(?:[0-9a-z][0-9a-z.-]*|[0-9a-f:.]+)$')

# For the monitoring plugin
STATE_OK = 0
STATE_WARNING = 1
STATE_CRITICAL = 2
STATE_UNKNOWN = 3
STATE_DEPENDENT = 4

# For the check option
DOH_GET = 0
DOH_POST = 1
DOH_HEAD = 2

# Is the test mandatory?
mandatory_levels = {"legal": 30, "necessary": 20, "nicetohave": 10}

TIMEOUT_CONN = 2  # seconds, SIGALRM budget for one DoT connection attempt
TIMEOUT_READ = 1  # seconds, socket read timeout when pipelining
SLEEP_TIMEOUT = 0.5
MAX_DURATION = 10
96

97
def error(msg=None, exit=True):
    """Report an error and, by default, terminate the program.

    In monitoring-plugin mode (global ``monitoring``) the message goes to
    stdout, prefixed by the global ``url``, and the exit status is
    STATE_CRITICAL. Otherwise the message goes to stderr, 'KO' is printed
    on stdout when opts.check is set, and the exit status is 1.

    :param msg: text to report; "Unknown error" when None.
    :param exit: when false, only print and return to the caller.
    """
    text = "Unknown error" if msg is None else msg
    if not monitoring:
        print(text, file=sys.stderr)
        if opts.check:
            print('KO')
        if exit:
            sys.exit(1)
        return
    print("%s: %s" % (url, text))
    if exit:
        sys.exit(STATE_CRITICAL)
Alexandre's avatar
Alexandre committed
110

Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
111
112
113
def usage(msg=None):
    """Print the command-line help on stderr, preceded by *msg* if given."""
    out = sys.stderr
    if msg:
        print(msg, file=out)
    print(f"Usage: {sys.argv[0]} [options] url-or-servername [domain-name [DNS type]]", file=out)
    options_text = """Options
    -t --dot            Use DoT (by default use DoH)
    -k --insecure       Do not check the certificate
    -4 --v4only         Force IPv4 resolution of url-or-servername
    -6 --v6only         Force IPv6 resolution of url-or-servername
    -v --verbose        Make the program more talkative
    --debug             Make the program even more talkative than -v
    -r --repeat <N>     Perform N times the query. If used with -f, read up to
                        <N> lines of the <file>
    -d --delay <T>      Time to wait in seconds between each synchronous
                        request (only with --repeat)
    -f --file <file>    Read domain names from <file>, one per row with an
                        optional DNS type. Read the first line only, use
                        --repeat N to read up to N lines of the file
    --check             Perform a set of predefined tests
    --mandatory-level <level>
                        Define the <level> of test to perform (only with
                        --check)
                        Available <level> : legal, necessary, nicetohave
    --no-display-results
                        Disable output of DNS response
    --dnssec            Request DNSSEC data (signatures)
    --noedns            Disable EDNS, default is to indicate EDNS support
    --ecs               Send ECS to authoritative servers, default is to
                        refuse it
    -V --vhost <vhost>  Use a specific virtual host
    -h --help           Print this message

  DoH only options:
    -P --post --POST    Use HTTP POST method for all the transfers
    -e --head --HEAD    Use HTTP HEAD method for all the transfers
    --multistreams      Use HTTP/2 streams, needs an input file with -f
    --time              Display the time elapsed for the query (only with
                        --multistreams)

  DoT only options:
    --key <key>         Authenticate a DoT resolver with its public <key> in
                        base64
    --nosni             Do not perform SNI
    --pipelining        Pipeline the requests, needs an input file with -f
    --max-in-flight <M> Maximum number of concurrent requests in parallel (only
                        with --pipelining)

    url-or-servername   The URL or domain name of the DoT/DoH server
    domain-name         The domain name to resolve, not required if -f is
                        provided
    DNS type            The DNS record type to resolve, default AAAA
    """
    print(options_text, file=out)
    print("See the README.md for more details.", file=out)
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
164

165
def is_valid_hostname(name):
    """Check *name*, once canonicalized, against the re_host pattern.

    Returns the match object (truthy) or None. canonicalize() may raise
    UnicodeError for names that cannot be IDNA-encoded.
    """
    canonical = canonicalize(name)
    return re_host.search(canonical)
168

169
170
171
172
173
174
175
176
177
def canonicalize(hostname):
    """Return *hostname* in the canonical form used for name comparisons.

    The name is lower-cased, converted to its IDNA (ASCII) form, and one
    trailing dot, if any, is removed. An empty string is returned as is
    (it used to raise IndexError).
    """
    result = hostname.lower()
    # TODO handle properly the case where it fails with UnicodeError
    # (two consecutive dots for instance) to get a custom exception
    result = result.encode('idna').decode()
    # endswith() is safe on an empty string, unlike result[len(result)-1]
    if result.endswith('.'):
        result = result[:-1]
    return result

178
def is_valid_ip_address(addr):
    """Tell whether *addr* parses as an IP address.

    Returns a pair (valid, version): (True, 4) or (True, 6) when netaddr
    accepts the address, (False, None) when it raises AddrFormatError.
    """
    try:
        version = netaddr.IPAddress(addr).version
    except netaddr.core.AddrFormatError:
        return (False, None)
    return (True, version)
185

186
187
188
189
190
191
192
193
194
def is_valid_url(url):
    """Return True if *url* looks like an HTTPS URL with a host part.

    This is a very poor validation: many errors (for instance whitespace,
    IPv6 address literals without brackets...) are not detected. A URL
    that urllib cannot parse at all (ValueError) gives False.
    """
    try:
        parsed = urllib.parse.urlparse(url)
    except ValueError:
        return False
    return parsed.scheme == "https" and parsed.netloc != ""

Alexandre's avatar
Alexandre committed
195
def _get_certificate_san(x509cert):
196
197
198
199
200
201
202
203
    san = ""
    ext_count = x509cert.get_extension_count()
    for i in range(0, ext_count):
        ext = x509cert.get_extension(i)
        if "subjectAltName" in str(ext.get_short_name()):
            san = str(ext)
    return san

204
# Try one possible name. Names must be already canonicalized.
Alexandre's avatar
Alexandre committed
205
def _match_hostname(hostname, possibleMatch):
206
207
208
209
210
211
212
213
214
    if possibleMatch.startswith("*."): # Wildcard
        base = possibleMatch[1:] # Skip the star
        # RFC 6125 says that we MAY accept left-most labels with
        # wildcards included (foo*bar). We don't do it here.
        try:
            (first, rest) = hostname.split(".", maxsplit=1)
        except ValueError: # One-label name
            rest = hostname
        if rest == base[1:]:
215
            return True
216
        if hostname == base[1:]:
217
            return True
218
219
220
221
222
        return False
    else:
        return hostname == possibleMatch

# Try all the names in the certificate
def _validate_hostname(hostname, cert):
    """Return True if *hostname* is one of the identities of *cert*.

    *hostname* may be a host name or an IP address literal. *cert* is the
    peer certificate object (get_subject(), extensions). Host names are
    compared with the "DNS:" entries of the subjectAltName, and IP
    addresses with its "IP Address:" entries.

    Following RFC 6125, section 6.4.4, the subject Common Name is used only
    when the subjectAltName contains no DNS entry at all. A certificate
    without a Common Name yields False instead of crashing.
    """
    # Complete specification is in RFC 6125. It is long and
    # complicated and I'm not sure we do it perfectly.
    (is_addr, _family) = is_valid_ip_address(hostname)
    hostname = canonicalize(hostname)
    host_ip = netaddr.IPAddress(hostname) if is_addr else None
    has_dns_id = False
    for alt_name in _get_certificate_san(cert).split(", "):
        alt_name = alt_name.strip()  # The last entry may end with "\n"
        if alt_name.startswith("DNS:"):
            has_dns_id = True
            if not is_addr and _match_hostname(hostname, canonicalize(alt_name[len("DNS:"):])):
                return True
        elif alt_name.startswith("IP Address:") and is_addr:
            try:
                base_ip = netaddr.IPAddress(alt_name[len("IP Address:"):])
            except netaddr.core.AddrFormatError:
                continue  # Ignore broken IP addresses in certificates. Are we too liberal?
            if host_ip == base_ip:
                return True
        # Other alternative name types are ignored. May be accept URI
        # alternative names for DoH?
    # RFC 6125: the Common Name MUST NOT be used when the certificate
    # presents DNS identifiers in its subjectAltName.
    if has_dns_id:
        return False
    cn = cert.get_subject().commonName
    if cn is None:  # No Common Name in the subject
        return False
    return _match_hostname(hostname, canonicalize(cn))

256
def get_addrfamily(addr, forceIPv4=False, forceIPv6=False):
    """Return the socket address family to use for *addr*.

    Returns socket.AF_INET when IPv4 is forced or *addr* is an IPv4
    literal. Returns socket.AF_INET6 when IPv6 is forced or *addr* is an
    IPv6 literal. Otherwise returns 0, which lets socket.getaddrinfo()
    return any family.

    Raises FamilyException when the forced family contradicts a literal
    address.
    """
    _, version = is_valid_ip_address(addr)

    # thoses checks between the IP family and the command line option
    # might need to land somewhere else
    if forceIPv4 and version == 6:
        raise FamilyException("You cannot force IPv4 with a litteral IPv6 address (%s)" % addr)
    if forceIPv6 and version == 4:
        raise FamilyException("You cannot force IPv6 with a litteral IPv4 address (%s)" % addr)

    if forceIPv4 or version == 4:
        return socket.AF_INET
    if forceIPv6 or version == 6:
        return socket.AF_INET6
    return 0

277
278
def check_ip_address(addr, forceIPv4=False, forceIPv6=False):
    """Same as get_addrfamily(): return the socket family to use for *addr*."""
    return get_addrfamily(addr, forceIPv4=forceIPv4, forceIPv6=forceIPv6)
279

280
281
282
283
284
285
def dump_data(data, text="data"):
    """Print *data* (a bytes-like object) with its hex and binary dumps.

    The first line is labelled with *text*. The "hex:" and "bin:" labels
    are indented so their colons line up with the one after *text*.
    """
    indent = " " * (len(text) - 4)
    hex_repr = " ".join(f"{byte:02x}" for byte in data)
    bin_repr = " ".join(f"{byte:08b}" for byte in data)
    print(f"{text}: ", data)
    print(indent, "hex:", hex_repr)
    print(indent, "bin:", bin_repr)

286
287
def timeout_connection(signum, frame):
    """SIGALRM handler: abort the pending DoT connection attempt.

    ConnectionDOT.connect() installs it. The exception it raises interrupts
    the blocking connect/handshake, where it is caught as a timeout.
    """
    message = 'Connection timeout'
    raise TimeoutConnectionError(message)
288

289
290
class TimeoutConnectionError(Exception):
    """Raised by the SIGALRM handler when a connection attempt takes too long."""
291

292
293
294
295
296
297
298
299
300
301
class ConnectionException(Exception):
    """Base class of the errors raised while setting up a connection."""


class ConnectionDOTException(ConnectionException):
    """Connection error specific to DNS-over-TLS."""


class ConnectionDOHException(ConnectionException):
    """Connection error specific to DNS-over-HTTPS."""


class FamilyException(ConnectionException):
    """The forced IP family contradicts a literal address."""
303

Alexandre's avatar
Alexandre committed
304
class Request:
    """A DNS query built with dnspython, plus its wire-format data.

    Subclasses add the transport-specific response handling.
    """

    def __init__(self, qname, qtype='AAAA', use_edns=True, want_dnssec=False, no_ecs=True):
        # An ECS option with an empty address and a source prefix length of
        # 0 disables ECS (RFC 7871, section 7.1.2). Otherwise no EDNS option
        # is passed to make_query().
        options = [dns.edns.ECSOption(address='', srclen=0)] if no_ecs else None
        rdtype = dns.rdatatype.from_text(qtype)
        self.message = dns.message.make_query(qname, rdtype,
                                              use_edns=use_edns, want_dnssec=want_dnssec,
                                              options=options)
        self.message.flags |= dns.flags.AD  # Ask for validation
        self.ok = True
        self.i = 0  # request's number on the connection (default to the first)

    def trunc_data(self):
        """Store only the first half (rounded) of the wire query in self.data."""
        wire = self.message.to_wire()
        self.data = wire[:round(len(wire) / 2)]

    def to_wire(self):
        """Store the complete wire-format query in self.data."""
        self.data = self.message.to_wire()
324
325


Alexandre's avatar
Alexandre committed
326
class RequestDOT(Request):
    """A DNS-over-TLS request. A valid response is stored as a dns.message."""

    # raising custom exception for each unexpected response might be a good idea
    def check_response(self, debug=False):
        """Validate the stored response and update self.ok.

        Returns self.ok. When the response ID does not match the query ID,
        self.response is replaced by an error message, which includes both
        IDs when *debug* is true.

        Raises Exception when no reply was stored.
        """
        if self.response is None:
            raise Exception("No reply received")
        if not self.rcode:
            self.ok = False
            return False
        # Read the ID before self.response is replaced by an error string.
        response_id = self.response.id
        if response_id != self.message.id:
            self.response = "The ID in the answer does not match the one in the query"
            if debug:
                self.response += f' (query id: {self.message.id}) (response id: {response_id})'
            self.ok = False
            return False
        return self.ok

    def store_response(self, rcode, data, size):
        """Record the result of a read made by the connection.

        *rcode* is the success flag returned with *data* and *size* by the
        connection's receive_data(). When it is false, nothing was read, so
        nothing is parsed and self.response is None.
        """
        self.rcode = rcode
        self.response = dns.message.from_wire(data) if rcode else None
        self.response_size = size
347

348

Alexandre's avatar
Alexandre committed
349
class RequestDOH(Request):
    """A DNS-over-HTTPS request (GET by default, POST or HEAD on demand)."""

    def __init__(self, qname, qtype='AAAA', use_edns=True, want_dnssec=False, no_ecs=True):
        Request.__init__(self, qname, qtype=qtype, use_edns=use_edns, want_dnssec=want_dnssec, no_ecs=no_ecs)
        # RFC 8484, section 4.1: DoH clients SHOULD use a DNS ID of 0 in
        # every request (HTTP cache friendliness).
        self.message.id = 0
        self.post = False
        self.head = False

    # raising custom exception for each unexpected response might be a good idea
    def check_response(self, debug=False):
        """Validate the HTTP reply stored in this request and update self.ok.

        Reads self.rcode (HTTP status), self.ctype (Content-Type),
        self.response (raw body) and self.response_size, which are set
        outside this class.

        On a 200 reply to GET/POST with the right content type,
        self.response becomes the parsed dns.message. In the other cases it
        becomes a human-readable string, except for non-200 replies with a
        body, which are left untouched. With *debug*, error messages
        include (part of) the raw body. Returns the new value of self.ok.
        """
        ok = self.ok
        if self.rcode == 200:
            if self.ctype != "application/dns-message":
                self.response = "Content type of the response (\"%s\") invalid" % self.ctype
                ok = False
            elif not self.head:
                raw = self.response  # Kept for the debug output if parsing fails
                try:
                    response = dns.message.from_wire(raw)
                except dns.message.TrailingJunk:  # Not DNS. Should
                    # not happen for a content type
                    # application/dns-message but who knows?
                    self.response = "ERROR Not proper DNS data, trailing junk"
                    if debug:
                        self.response += " \"%s\"" % raw
                    ok = False
                except dns.name.BadLabelType:  # Not DNS.
                    self.response = "ERROR Not proper DNS data (wrong path in the URL?)"
                    if debug:
                        self.response += " \"%s\"" % raw[:100]
                    ok = False
                except dns.exception.DNSException as exc:  # Any other malformed message
                    self.response = "ERROR Not proper DNS data (%s)" % exc
                    if debug:
                        self.response += " \"%s\"" % raw[:100]
                    ok = False
                else:
                    self.response = response
            elif self.response_size == 0:
                self.response = "HEAD successful"
            else:
                data = self.response
                self.response = "ERROR Body length is not null"
                if debug:
                    self.response += "\"%s\"" % data[:100]
                ok = False
        else:
            ok = False
            if self.response_size == 0:
                self.response = "[No details]"
            # Otherwise the body is kept as the error details.
        self.ok = ok
        return ok
398

399

400
class Connection:
    """State shared by DoT and DoH connections to a resolver.

    The constructor validates the server designation and stores the
    connection options. For DoT this is a host name or IP address, for DoH
    an HTTPS URL. The transport itself is left to subclasses.
    """

    def __init__(self, server, servername=None, connect_to=None,
                 forceIPv4=False, forceIPv6=False, insecure=False,
                 verbose=False, debug=False, dot=False):
        if dot:
            if not is_valid_hostname(server):
                raise ConnectionDOTException("DoT requires a host name or IP address, not \"%s\"" % server)
        elif not is_valid_url(server):
            raise ConnectionDOHException("DoH requires a valid HTTPS URL, not \"%s\"" % server)
        if forceIPv4 and forceIPv6:
            raise ConnectionException("Force IPv4 *or* IPv6 but not both")

        self.dot = dot
        self.server = server
        self.servername = servername
        # Name checked against the server certificate (and sent as SNI by
        # DoT): the explicit servername when given, the server otherwise.
        self.check_name_cert = server if servername is None else servername
        self.verbose = verbose
        self.debug = debug
        self.insecure = insecure
        self.forceIPv4 = forceIPv4
        self.forceIPv6 = forceIPv6
        self.connect_to = connect_to

    def __str__(self):
        return self.server

    def do_test(self, request):
        """Routine doing one actual test; overridden by subclasses. Returns nothing."""
        pass

436

Alexandre's avatar
Alexandre committed
437
class ConnectionDOT(Connection):
    """A DNS-over-TLS connection (TCP port 853) built on pyOpenSSL.

    The TLS session is established by the constructor. Queries are sent
    with a two-byte big-endian length prefix. With *pipelining*, several
    requests can be in flight at once (pipelining_* methods and
    read_result()).
    """

    def __init__(self, server, servername=None, connect_to=None,
                 forceIPv4=False, forceIPv6=False, insecure=False,
                 verbose=False, debug=False,
                 sni=True, key=None, pipelining=False):

        Connection.__init__(self, server, servername=servername, connect_to=connect_to,
                forceIPv4=forceIPv4, forceIPv6=forceIPv6, insecure=insecure,
                verbose=verbose, debug=debug, dot=True)

        self.sni = sni
        self.key = key
        self.pipelining = pipelining
        if self.pipelining:
            self.all_requests = [] # Currently, we load everything in memory
                                   # since we want to keep everything,
                                   # anyway. May be in the future, if we don't
                                   # want to keep individual results, we'll use
                                   # an iterator to fill a smaller table.
                                   # all_requests is indexed by its rank in the input file.
            self.pending = {} # pending is indexed by the query ID, and its
                              # maximum size is max_in_flight.

        # establish the connection
        self.connect()

    def connect(self):
        """Open a TLS session with the server, trying each resolved address.

        Raises ConnectionDOTException when no address could be reached.
        """
        # if connect_to is defined, it means we know the IP address of the
        # server and therefore we can establish a connection with it
        # otherwise we only have a domain name and we should loop on all
        # resolved IPs until a connection can be established
        # getaddrinfo provides a list of resolved IPs, when connect_to is
        # defined this list will have only one element
        # so we can loop on the items until a connection is made
        # duplicates are removed while keeping the getaddrinfo() order

        self.success = False

        if self.connect_to is not None: # the server's IP address is known
            addr = self.connect_to
        else:
            addr = self.server # otherwise keep the server name

        family = get_addrfamily(addr, forceIPv4=self.forceIPv4, forceIPv6=self.forceIPv6)

        addrinfo_list = socket.getaddrinfo(addr, 853, family)
        # (sockaddr, family) pairs, deduplicated in order
        candidates = dict.fromkeys((addrinfo[4], addrinfo[0]) for addrinfo in addrinfo_list)

        signal.signal(signal.SIGALRM, timeout_connection)

        for sockaddr, sock_family in candidates:
            if self.establish_session(sockaddr, sock_family):
                self.success = True
                break
            if self.verbose and self.connect_to is None:
                print("Could not connect to %s" % sockaddr[0])
                print("Trying another IP address")

        # we could not establish a connection
        if not self.success:
            # we tried all the resolved IPs
            if self.verbose and self.connect_to is None:
                print("No other IP address")

            message = "Could not connect to \"%s\"" % self.server
            if self.connect_to is not None:
                message += " on %s" % self.connect_to
            raise ConnectionDOTException(message)

    def establish_session(self, addr, sock_family):
        """Return True if a TLS session is established with *addr*.

        *addr* is a socket address from getaddrinfo() and *sock_family* its
        family. The attempt (TCP connect, TLS handshake, certificate or key
        check) is bounded by a SIGALRM timer of TIMEOUT_CONN seconds. The
        timer is always cancelled before returning, and the socket of a
        failed attempt is closed, so neither outlives the attempt.
        """
        self.hasher = hashlib.sha256()

        # start the timer
        signal.alarm(TIMEOUT_CONN)
        try:
            established = self._open_session(addr, sock_family)
        finally:
            # restore the timer, whatever happened
            signal.alarm(0)
        if not established:
            self.sock.close()
            return False
        # and start a new timer when pipelining requests
        if self.pipelining:
            self.sock.settimeout(TIMEOUT_READ)
        return True

    def _open_session(self, addr, sock_family):
        """Connect, perform the TLS handshake and authenticate the server.

        Returns True on success, or False (after an optional diagnostic).
        Called by establish_session() while the timer is running.
        """
        self.sock = socket.socket(sock_family, socket.SOCK_STREAM)

        if self.verbose:
            print("Connecting to %s ..." % addr[0])

        # With typical DoT servers, we *must* use TLS 1.2 (otherwise,
        # do_handshake fails with "OpenSSL.SSL.SysCallError: (-1, 'Unexpected
        # EOF')" Typical HTTP servers are more lax.
        self.context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_2_METHOD)
        if self.insecure:
            self.context.set_verify(OpenSSL.SSL.VERIFY_NONE, lambda *x: True)
        else:
            self.context.set_default_verify_paths()
            self.context.set_verify_depth(4) # Seems ignored
            self.context.set_verify(OpenSSL.SSL.VERIFY_PEER | OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT |
                                    OpenSSL.SSL.VERIFY_CLIENT_ONCE,
                                    lambda conn, cert, errno, depth, preverify_ok: preverify_ok)
        self.session = OpenSSL.SSL.Connection(self.context, self.sock)
        if self.sni:
            self.session.set_tlsext_host_name(canonicalize(self.check_name_cert).encode())

        try:
            self.session.connect(addr)
            self.session.do_handshake()
        except TimeoutConnectionError:
            if self.verbose:
                print("Timeout")
            return False
        except OSError:
            if self.verbose:
                print("Cannot connect")
            return False
        except OpenSSL.SSL.SysCallError as e:
            if self.verbose:
                error("OpenSSL error: %s" % e.args[1], exit=False)
            return False
        except OpenSSL.SSL.ZeroReturnError:
            # see #18
            if self.verbose:
                error("Error: The SSL connection has been closed (try with --nosni to avoid sending SNI ?)", exit=False)
            return False
        except OpenSSL.SSL.Error as e:
            if self.verbose:
                error("OpenSSL error: %s" % ', '.join(err[0][2] for err in e.args), exit=False)
            return False

        # RFC 7858, section 4.2 and appendix A
        self.cert = self.session.get_peer_certificate()
        self.publickey = self.cert.get_pubkey()
        key_string = None
        if self.debug or self.key is not None:
            # SPKI pin: base64 of the SHA-256 of the DER-encoded public key
            self.hasher.update(OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1,
                                                             self.publickey))
            self.digest = self.hasher.digest()
            key_string = base64.standard_b64encode(self.digest).decode()
        if self.debug:
            print("Certificate #%x for \"%s\", delivered by \"%s\"" % \
                  (self.cert.get_serial_number(),
                   self.cert.get_subject().commonName,
                   self.cert.get_issuer().commonName))
            print("Public key is pin-sha256=\"%s\"" % key_string)
        if not self.insecure:
            if self.key is None:
                if not _validate_hostname(self.check_name_cert, self.cert):
                    error("Certificate error: \"%s\" is not in the certificate" % (self.check_name_cert), exit=False)
                    return False
            elif key_string != self.key:
                error("Key error: expected \"%s\", got \"%s\"" % (self.key, key_string), exit=False)
                return False
        return True

    def end(self):
        """Shut down the TLS session and close the connection."""
        self.session.shutdown()
        self.session.close()

    def send_data(self, data, dump=False):
        """Send *data* preceded by its length on two bytes (big-endian)."""
        if dump:
            dump_data(data, 'data sent')
        length = len(data)
        self.session.send(length.to_bytes(2, byteorder='big') + data)

    def _recv_exactly(self, size):
        """Read exactly *size* bytes from the TLS session.

        A single recv() may return fewer bytes than requested (e.g. when a
        message spans several TLS records), so keep reading until complete.
        Raises OpenSSL.SSL.ZeroReturnError if the peer closes the stream.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.session.recv(remaining)
            if not chunk:
                raise OpenSSL.SSL.ZeroReturnError()
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def receive_data(self, dump=False):
        """Read one length-prefixed DNS message.

        Returns (True, data, size), or (False, None, None) when
        OpenSSL.SSL.WantReadError is raised before the length could be
        read (read_result() reports this as a timeout).
        """
        try:
            buf = self.session.recv(2)
        except OpenSSL.SSL.WantReadError:
            return (False, None, None)
        if len(buf) < 2:
            buf += self._recv_exactly(2 - len(buf))
        size = int.from_bytes(buf, byteorder='big')
        data = self._recv_exactly(size)
        if dump:
            dump_data(data, 'data recv')
        return (True, data, size)

    def send_and_receive(self, request, dump=False):
        """Send *request* and store the (synchronous) reply in it."""
        self.send_data(request.data, dump=dump)
        rcode, data, size = self.receive_data(dump=dump)
        request.store_response(rcode, data, size)

    # this function might need to be move outside
    def do_test(self, request, synchronous=True):
        """Send *request*; when *synchronous*, also read and check the reply."""
        self.send_data(request.data)
        if synchronous:
            rcode, data, size = self.receive_data()
            request.store_response(rcode, data, size)
            request.check_response(self.debug)

    # should the pipelining methods be part of ConnectionDOT ?
    def pipelining_add_request(self, request):
        """Queue *request*, with no answer yet, at the end of all_requests."""
        self.all_requests.append({'request': request, 'response': None}) # No answer yet

    def pipelining_fill_pending(self, index):
        """Send all_requests[*index*] and record it as pending, if it exists."""
        if index < len(self.all_requests):
            request = self.all_requests[index]['request']
            query_id = request.message.id
            # TODO check there is no duplicate in IDs
            self.pending[query_id] = (False, index, request)
            self.do_test(request, synchronous=False)

    def pipelining_init_pending(self, max_in_flight):
        """Send the first requests, at most *max_in_flight* of them.

        NOTE(review): the return value is the loop index at exit. It is
        max_in_flight - 1 when there are at least max_in_flight requests,
        but len(all_requests) when there are fewer. It is unbound
        (UnboundLocalError) when max_in_flight is 0. It is kept unchanged
        because the callers are not visible here.
        """
        for i in range(0, max_in_flight):
            if i == len(self.all_requests):
                break
            self.pipelining_fill_pending(i)
        return i

    # this method might need to be moved somewhere else in order to avoid
    # calling dns.message.from_wire()
    def read_result(self, connection, requests, display_results=True):
        """Read one pipelined response and record it.

        *requests* maps query IDs to (done, rank, request). The matching
        entry is marked done and the response is stored in
        all_requests[rank]['response'] as (rcode, message, size). The
        *connection* argument is not used.

        Returns the query ID of the response, or None when nothing was read
        (printed as TIMEOUT if *display_results*). Raises Exception for a
        response whose ID is not in *requests*.
        """
        rcode, data, size = self.receive_data() # TODO can raise
                                                # OpenSSL.SSL.ZeroReturnError
                                                # if the connection was closed
        if not rcode:
            if display_results:
                print("TIMEOUT")
            return None
        # TODO remove call to dns.message (use abstraction instead)
        response = dns.message.from_wire(data)
        response_id = response.id
        if response_id not in requests:
            raise Exception("Received response for ID %s which is unexpected" % response_id)
        _done, rank, request = requests[response_id]
        self.all_requests[rank]['response'] = (rcode, response, size)
        requests[response_id] = (True, rank, request)
        if display_results:
            print()
            print(response)
        # TODO a timeout if some responses are lost?
        return response_id
670

671
672
def create_handle(connection):
    """Create a pycurl easy handle configured for DoH queries to *connection*.

    The handle requests HTTP/2 and sends and accepts
    application/dns-message.  It honours the connection's debug, insecure,
    forceIPv4, forceIPv6 and connect_to settings.  Two functions are
    attached to the returned handle:

    - handle.reset_opt_default(handle) resets the method/URL options.
    - handle.prepare(handle, connection, request) loads one DNS request
      (POST body or GET query string) and a fresh response buffer into it.
    """
    def reset_opt_default(handle):
        # Undo the options a previous prepare() may have set on a reused
        # handle.
        defaults = {
                pycurl.NOBODY: False,
                pycurl.POST: False,
                pycurl.POSTFIELDS: '',
                pycurl.URL: ''
               }
        for option, value in defaults.items():
            handle.setopt(option, value)

    def prepare(handle, connection, request):
        if not connection.multistreams:
            # Without multistreams the same handle is reused for each query.
            handle.reset_opt_default(handle)
        if request.post:
            handle.setopt(pycurl.POST, True)
            handle.setopt(pycurl.POSTFIELDS, request.data)
            handle.setopt(pycurl.URL, connection.server)
        else:
            handle.setopt(pycurl.HTTPGET, True) # automatically sets CURLOPT_NOBODY to 0
            if request.head:
                handle.setopt(pycurl.NOBODY, True)
            # RFC 8484: the query goes in the "dns" parameter, base64url
            # encoded without padding.  If the server URL already carries a
            # query string, the parameter must be appended with "&"; a second
            # "?" would merge it into the previous parameter's value.
            dns_req = base64.urlsafe_b64encode(request.data).decode('UTF8').rstrip('=')
            separator = '&' if '?' in connection.server else '?'
            handle.setopt(pycurl.URL, "%s%sdns=%s" % (connection.server, separator, dns_req))
        handle.buffer = io.BytesIO()
        handle.setopt(pycurl.WRITEDATA, handle.buffer)
        handle.request = request

    handle = pycurl.Curl()
    # Does not work if pycurl was not compiled with nghttp2 (recent Debian
    # packages are OK) https://github.com/pycurl/pycurl/issues/477
    handle.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
    if connection.debug:
        handle.setopt(pycurl.VERBOSE, True)
    if connection.insecure:
        handle.setopt(pycurl.SSL_VERIFYPEER, False)
        handle.setopt(pycurl.SSL_VERIFYHOST, False)
    if connection.forceIPv4:
        handle.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
    if connection.forceIPv6:
        handle.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V6)
    if connection.connect_to is not None:
        # Any host, any port: connect to the given address on port 443.
        handle.setopt(pycurl.CONNECT_TO, ["::[%s]:443" % connection.connect_to,])
    handle.setopt(pycurl.HTTPHEADER,
            ["Accept: application/dns-message", "Content-type: application/dns-message"])
    handle.reset_opt_default = reset_opt_default
    handle.prepare = prepare
    return handle


class ConnectionDOH(Connection):
    """DNS-over-HTTPS connection built on pycurl.

    Without multistreams:
        One reusable curl handle (self.curl_handle) is used, and each query
        is performed synchronously by do_test().

    With multistreams:
        do_test(..., synchronous=False) gives each query its own handle.
        The handle is attached to a pycurl.CurlMulti that is limited to one
        connection per host, and perform_multi() then runs the transfers.
    """

    def __init__(self, server, servername=None, connect_to=None,
                 forceIPv4=False, forceIPv6=False,
                 insecure=False, verbose=False, debug=False,
                 multistreams=False):

        Connection.__init__(self, server, servername=servername, connect_to=connect_to,
                forceIPv4=forceIPv4, forceIPv6=forceIPv6, insecure=insecure,
                verbose=verbose, debug=debug, dot=False)

        self.url = server
        self.multistreams = multistreams

        # temporary tweak to check that the ip family is coherent with
        # user choice on forced IP
        if self.connect_to:
            check_ip_address(self.connect_to, forceIPv4=self.forceIPv4, forceIPv6=self.forceIPv6)

        if self.multistreams:
            self.multi = self.create_multi()
            self.all_handles = []
            # Number of responses received, per HTTP status code
            self.finished = { 'http': {} }
        else:
            self.curl_handle = create_handle(self)

    def create_multi(self):
        """Return a new CurlMulti allowing at most one connection per host."""
        multi = pycurl.CurlMulti()
        multi.setopt(pycurl.M_MAX_HOST_CONNECTIONS, 1)
        return multi

    def init_multi(self):
        """Establish the multistreams connection with a single warm-up query.

        A first query is performed alone to establish the connection, and
        hence avoid starting the transfer of all the other queries
        simultaneously.  The root NS is queried because this should not
        impact the resolver cache.  Handles and counters are reset
        afterwards, so the warm-up query is neither kept nor counted.
        """
        if self.verbose:
            print("Establishing multistreams connection...")
        request = create_request('.', qtype='NS', dot=False)
        try:
            self.do_test(request, synchronous=False)
        except OpenSSL.SSL.Error as e:
            error(e)
        self.perform_multi(silent=True, display_results=False, show_time=False)
        self.all_handles = []
        self.finished = { 'http': {} }

    def end(self):
        """Release the curl resources held by this connection."""
        if self.multistreams:
            self.remove_handles()
            self.multi.close()
        else:
            self.curl_handle.close()

    def remove_handles(self):
        """Detach and close the handles that info_read() still reports as done.

        pycurl's info_read() returns successful transfers as bare handles,
        but failed ones as (handle, errno, errmsg) tuples.
        """
        n, handle_success, handle_fail = self.multi.info_read()
        handles = handle_success + [failed[0] for failed in handle_fail]
        for h in handles:
            # Detach from the multi before freeing the easy handle.
            self.multi.remove_handle(h)
            h.close()

    def _read_completed(self, silent, display_results, show_time):
        """Process every transfer the multi reports as finished.

        Successful transfers go through read_result_handle().  Failed ones
        are reported on stderr unless *silent* is set.
        """
        n, handle_pass, handle_fail = self.multi.info_read()
        for handle in handle_pass:
            self.read_result_handle(handle, silent=silent, display_results=display_results, show_time=show_time)
        if not silent:
            for handle, errno, errmsg in handle_fail:
                print("Transfer failed (curl error %s): %s" % (errno, errmsg), file=sys.stderr)

    def perform_multi(self, silent=False, display_results=True, show_time=False):
        """Run all the transfers attached to self.multi until none is active.

        Each response is handled as soon as its transfer finishes.
        """
        # Start the transfers.
        while True:
            ret, num_handles = self.multi.perform()
            if ret != pycurl.E_CALL_MULTI_PERFORM:
                break
        # Wait for activity and drive the transfers while some are running.
        while num_handles:
            ret = self.multi.select(1.0)
            if ret == -1:
                continue
            while True:
                ret, num_handles = self.multi.perform()
                self._read_completed(silent, display_results, show_time)
                if ret != pycurl.E_CALL_MULTI_PERFORM:
                    break
        # Collect what finished during the last perform() call.
        self._read_completed(silent, display_results, show_time)

    def send(self, handle):
        """Perform the transfer of *handle* into a fresh response buffer.

        A pycurl error is reported through error().
        """
        handle.buffer = io.BytesIO()
        handle.setopt(pycurl.WRITEDATA, handle.buffer)
        try:
            handle.perform()
        except pycurl.error as e:
            error(e.args[1])

    def receive(self, handle):
        """Copy the outcome of the transfer of *handle* into its request.

        Sets on handle.request:
            response       the body
            response_size  its length
            rcode          the HTTP status
            ctype          the Content-Type, or None

        Also sets handle.time and handle.pretime to the total and
        pre-transfer times, in seconds, then closes handle.buffer.
        """
        request = handle.request
        body = handle.buffer.getvalue()
        body_size = len(body)
        http_code = handle.getinfo(pycurl.RESPONSE_CODE)
        handle.time = handle.getinfo(pycurl.TOTAL_TIME)
        handle.pretime = handle.getinfo(pycurl.PRETRANSFER_TIME)
        try:
            content_type = handle.getinfo(pycurl.CONTENT_TYPE)
        except TypeError: # This is the exception we get if there is no Content-Type: (for instance in response to HEAD requests)
            content_type = None
        request.response = body
        request.response_size = body_size
        request.rcode = http_code
        request.ctype = content_type
        handle.buffer.close()

    def send_and_receive(self, handle, dump=False):
        """Perform the transfer of *handle* and store its outcome (*dump* is unused)."""
        self.send(handle)
        self.receive(handle)

    def read_result_handle(self, handle, silent=False, display_results=True, show_time=False):
        """Handle one finished multistreams transfer.

        The response is stored and checked, and its HTTP status is counted
        in self.finished['http'].  Unless *silent*, it is optionally
        displayed.  The handle is then detached from the multi and closed.
        """
        self.receive(handle)
        # Same debug setting as the synchronous path in do_test().
        handle.request.check_response(self.debug)
        if not silent and show_time:
            self.print_time(handle)
        http_counts = self.finished['http']
        rcode = handle.request.rcode
        http_counts[rcode] = http_counts.get(rcode, 0) + 1
        if not silent and display_results:
            print("Return code %s (%.2f ms):" % (handle.request.rcode,
                (handle.time - handle.pretime) * 1000))
            print(f"{handle.request.response}\n")
        # Detach from the multi before freeing the easy handle.
        self.multi.remove_handle(handle)
        handle.close()

    def read_results(self, display_results=True, show_time=False):
        """Run read_result_handle() on every handle created by asynchronous do_test() calls."""
        for handle in self.all_handles:
            self.read_result_handle(handle, display_results=display_results, show_time=show_time)

    def print_time(self, handle):
        """Print request number, HTTP status, pre-transfer, total and transfer times (ms)."""
        print(f'{handle.request.i:3d}', end='   ')
        print(f'({handle.request.rcode})', end='   ')
        print(f'{handle.pretime * 1000:8.3f} ms', end='  ')
        print(f'{handle.time * 1000:8.3f} ms', end='  ')
        print(f'{(handle.time - handle.pretime) * 1000:8.3f} ms')

    def do_test(self, request, synchronous=True):
        """Send *request* over DoH.

        Synchronous mode reuses self.curl_handle, performs the transfer at
        once and checks the response.  Otherwise a new handle is created,
        recorded in self.all_handles and added to self.multi.  The transfer
        only runs later, in perform_multi().
        """
        if synchronous:
            handle = self.curl_handle
        else:
            handle = create_handle(self)
            self.all_handles.append(handle)
        handle.prepare(handle, self, request)
        if synchronous:
            self.send_and_receive(handle)
            request.check_response(self.debug)
        else:
            self.multi.add_handle(handle)


def get_next_domain(input_file):
    """Read the next query from *input_file*.

    Each line holds a domain name, optionally followed by a query type,
    separated by whitespace (e.g. "example.com MX").  The query type
    defaults to AAAA.

    Returns a (name, rtype) tuple.  error() is called on end of file, on an
    empty line, or on a line with more than two fields.
    """
    line = input_file.readline()
    # split() copes with tabs, repeated or trailing whitespace, and a final
    # line that has no "\n".
    fields = line.split()
    if not fields:
        error("Not enough data in %s for the %i tests" % (opts.ifile, opts.tests))
        return '', 'AAAA'  # only reached if error() returns
    if len(fields) > 2:
        error("Invalid line in %s: \"%s\"" % (opts.ifile, line.strip()))
    name = fields[0]
    rtype = fields[1] if len(fields) > 1 else 'AAAA'
    return name, rtype


def print_result(connection, request, prefix=None, display_err=True):
Alexandre's avatar
Alexandre committed
886
    ok = request.ok
887
888
889
890
891
892
893
    dot = connection.dot
    server = connection.server
    rcode = request.rcode
    msg = request.response
    size = request.response_size
    if (dot and rcode) or (not dot and rcode == 200):
        if not monitoring:
Alexandre's avatar
Alexandre committed
894
            if not opts.dot and opts.show_time:
895
                connection.print_time(connection.curl_handle)
Alexandre's avatar
Alexandre committed
896
            if opts.display_results and (not opts.check or opts.verbose):
897
898
                print(msg)
        else:
Alexandre's avatar
Alexandre committed
899
            if opts.expect is not None and opts.expect not in str(request.response):
900
                ok = False
Alexandre's avatar
Alexandre committed
901
                print("%s Cannot find \"%s\" in response" % (server, opts.expect))
902
                sys.exit(STATE_CRITICAL)
903
            if ok and size is not None and size > 0:
904
                print("%s OK - %s" % (server, "No error for %s/%s, %i bytes received" % (name, opts.rtype, size)))
905
906
907
908
909
            else:
                print("%s OK - %s" % (server, "No error"))
            sys.exit(STATE_OK)
    else:
        if not monitoring:
910
            if display_err:
Alexandre's avatar
Alexandre committed
911
                if opts.check: