homer.py 33.9 KB
Newer Older
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
1
2
#!/usr/bin/env python3

3
4
5
6
7
# Homer is a DoH (DNS-over-HTTPS) and DoT (DNS-over-TLS) client. Its
# main purpose is to test DoH and DoT resolvers. Reference site is
# <https://framagit.org/bortzmeyer/homer/> See author, documentation,
# etc, there, or in the README.md included with the distribution.

Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
8
9
10
11
12
13
# http://pycurl.io/docs/latest
import pycurl

# http://www.dnspython.org/
import dns.message

14
15
16
# https://github.com/drkjam/netaddr/
import netaddr

17
18
19
20
21
# Octobre 2019: the Python GnuTLS bindings don't work with Python 3. So we use OpenSSL.
# https://www.pyopenssl.org/
# https://pyopenssl.readthedocs.io/
import OpenSSL

Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
22
23
24
25
26
27
import io
import sys
import base64
import getopt
import urllib.parse
import time
28
29
import socket
import ctypes
30
import re
31
import os.path
32
33
import hashlib
import base64
34
import signal
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
35

36
# Values that can be changed from the command line
dot = False  # DoH by default
verbose = False
insecure = False
post = False
head = False
dnssec = False
edns = True
rtype = 'AAAA'
vhostname = None
tests = 1  # Number of repeated tests
ifile = None  # Input file
delay = None
forceIPv4 = False
forceIPv6 = False
connectTo = None
check = False
mandatory_level = None
check_additional = True

# Monitoring plugin only:
host = None
path = None
expect = None

# Do not change these
# A (canonicalized) host name, or an IPv4 / IPv6 address literal. The
# alternatives are grouped so that the anchors apply to all of them and
# the WHOLE string must match (before, only the first alternative was
# anchored at the start, so anything beginning with a letter or a digit
# was accepted). \Z is used because $ also matches before a final newline.
re_host = re.compile(r'^(?:[0-9a-z][0-9a-z.-]*|[0-9a-f:.]+(?:%[0-9a-z]+)?)\Z')

# For the monitoring plugin (exit statuses)
STATE_OK = 0
STATE_WARNING = 1
STATE_CRITICAL = 2
STATE_UNKNOWN = 3
STATE_DEPENDENT = 4

# For the check option
DOH_GET = 0
DOH_POST = 1
DOH_HEAD = 2

# Is the test mandatory? (higher means more mandatory)
mandatory_levels = {"legal": 30, "necessary": 20, "nicetohave": 10}

# Seconds allowed for a DoT connection attempt (given to signal.alarm)
TIMEOUT_CONN = 2

def error(msg=None, exit=True):
    """Report an error and, by default, terminate the program.

    In monitoring-plugin mode the message goes to standard output,
    prefixed by the URL being checked, and the exit status is
    STATE_CRITICAL. Otherwise it goes to standard error (followed by
    'KO' on standard output when --check is active) and the exit
    status is 1.

    msg: what to print (any object); "Unknown error" when None.
    exit: when False, only print the message, do not exit.
    """
    if msg is None:
        msg = "Unknown error"
    if monitoring:
        print("%s: %s" % (url, msg))
        if exit:
            sys.exit(STATE_CRITICAL)
    else:
        print(msg, file=sys.stderr)
        if check:
            print('KO')
        if exit:
            sys.exit(1)

def usage(msg=None):
    """Print the command-line synopsis on standard error.

    msg: optional message printed first (e.g. why the arguments were
    rejected).
    """
    if msg:
        print(msg, file=sys.stderr)
    print("Usage: %s [--dot] url-or-servername domain-name [DNS type]" % sys.argv[0], file=sys.stderr)
    print("See the README.md for more details.", file=sys.stderr)

def is_valid_hostname(name):
    """Tell whether *name* looks like a host name or an IP address literal.

    The name is canonicalized first (see canonicalize), then matched
    against re_host. Returns the match object (truthy) on success and
    None otherwise. A name that cannot even be canonicalized (an empty
    label such as in "a..b", for instance) is reported as invalid
    instead of raising.
    """
    try:
        name = canonicalize(name)
    except (UnicodeError, IndexError):
        return None
    return re_host.search(name)

def canonicalize(hostname):
    """Return *hostname* in the canonical form used for comparisons.

    The name is lower-cased, converted to its IDNA (ASCII, "xn--") form
    and stripped of a trailing dot. An empty string stays empty (it used
    to raise IndexError).

    Raises UnicodeError when the name cannot be IDNA-encoded (two
    consecutive dots, for instance).
    """
    result = hostname.lower()
    # TODO handle properly the case where it fails with UnicodeError
    # (two consecutive dots for instance) to get a custom exception
    result = result.encode('idna').decode()
    if result.endswith('.'):
        result = result[:-1]
    return result

def is_valid_ip_address(addr):
    """Tell whether *addr* is an IP address literal.

    Returns a tuple (is_address, version): (True, 4) or (True, 6) when
    netaddr can parse *addr*, (False, None) otherwise.
    """
    try:
        parsed = netaddr.IPAddress(addr)
    except netaddr.core.AddrFormatError:
        return (False, None)
    return (True, parsed.version)

def is_valid_url(url):
    """Return True when *url* looks like an HTTPS URL with a network location.

    This is a very poor validation: many errors (whitespace, IPv6
    address literals without brackets...) go unnoticed. A URL that
    urllib cannot parse at all (an unclosed IPv6 bracket, for instance)
    is rejected.
    """
    try:
        parts = urllib.parse.urlparse(url)
    except ValueError:
        return False
    if parts.scheme != "https":
        return False
    return parts.netloc != ""

def get_certificate_san(x509cert):
    """Return the subjectAltName extension of a certificate, as text.

    x509cert: a pyOpenSSL X509 object (anything providing
    get_extension_count() and get_extension(i) works).

    Returns the extension rendered by str() (entries such as
    "DNS:example.net" or "IP Address:192.0.2.1", separated by ", "),
    or "" when the certificate has no such extension. If there are
    several, the last one wins.
    """
    san = ""
    for index in range(x509cert.get_extension_count()):
        ext = x509cert.get_extension(index)
        short_name = ext.get_short_name()
        # The short name comes as bytes from pyOpenSSL: compare its value
        # exactly instead of searching in the repr of the bytes object.
        if isinstance(short_name, bytes):
            short_name = short_name.decode("ascii", errors="replace")
        if short_name == "subjectAltName":
            san = str(ext)
    return san

# Try one possible name. Names must be already canonicalized.
def match_hostname(hostname, possibleMatch):
    """Tell whether the certificate name *possibleMatch* covers *hostname*.

    Both names must already be canonicalized. A wildcard name
    ("*.example.com") stands for exactly one extra left-most label:
    "foo.example.com" matches it, but neither "example.com" nor
    "a.b.example.com" does (RFC 6125, section 6.4.3). Any other name
    must be equal to *hostname*.
    """
    if possibleMatch.startswith("*."):  # Wildcard
        base = possibleMatch[2:]  # Skip the star and the dot
        # RFC 6125 says that we MAY accept left-most labels with
        # wildcards included (foo*bar). We don't do it here.
        first, sep, rest = hostname.partition(".")
        if not sep or not first:
            # One-label name: there is no label for the wildcard to
            # replace (formerly "*.com" wrongly matched "com").
            return False
        return rest == base
    return hostname == possibleMatch


# Try all the names in the certificate
def validate_hostname(hostname, cert):
    """Tell whether the certificate *cert* is valid for *hostname*.

    hostname: the name, or IP address literal, the client wanted to reach.
    cert: a pyOpenSSL X509 object.

    For a name, the "DNS:" entries of the Subject Alternative Name are
    tried (wildcards allowed, see match_hostname); for an IP address,
    the "IP Address:" entries are compared numerically. Other entry
    types are ignored. The subject Common Name is tried last. Returns
    True or False.
    """
    # Complete specification is in RFC 6125. It is long and
    # complicated and I'm not sure we do it perfectly.
    (is_addr, _) = is_valid_ip_address(hostname)
    hostname = canonicalize(hostname)
    host_i = netaddr.IPAddress(hostname) if is_addr else None
    for alt_name in get_certificate_san(cert).split(", "):
        alt_name = alt_name.rstrip("\n")
        if alt_name.startswith("DNS:") and not is_addr:
            try:
                base = canonicalize(alt_name[len("DNS:"):])
            except UnicodeError:
                continue  # Unusable name in the certificate
            if match_hostname(hostname, base):
                return True
        elif alt_name.startswith("IP Address:") and is_addr:
            base = alt_name[len("IP Address:"):]
            try:
                base_i = netaddr.IPAddress(base)
            except netaddr.core.AddrFormatError:
                continue  # Ignore broken IP addresses in certificates. Are we too liberal?
            if host_i == base_i:
                return True
        else:
            pass  # Ignore unknown alternative name types. Maybe
                  # accept URI alternative names for DoH?
    # According to RFC 6125, we MUST NOT try the Common Name before the
    # Subject Alternative Names.
    common_name = cert.get_subject().commonName
    if common_name is None:  # No CN in the subject
        return False
    try:
        common_name = canonicalize(common_name)
    except UnicodeError:
        return False
    return match_hostname(hostname, common_name)

def timeout_connection(signum, frame):
    """SIGALRM handler: abort the connection attempt in progress.

    It is installed with signal.signal(signal.SIGALRM, ...) before a DoT
    connection; the exception it raises surfaces in the code that was
    running when the alarm fired.
    """
    raise TimeoutConnectionError('Connection timeout')


class TimeoutConnectionError(Exception):
    """Raised by timeout_connection when the connection alarm fires."""

class CustomException(Exception):
    """Error raised by this program for problems such as contradictory
    options or unusable addresses; the check functions catch it and
    report it with error()."""

class Request:
    """A DNS query to send to a resolver, and later its response.

    Set here: message (the dns.message query) and ok (True until a
    check fails); to_wire() / trunc_data() set data, the bytes to send.
    The connection then fills response, response_size and rcode.
    """

    def __init__(self, qname, qtype=rtype, use_edns=edns, want_dnssec=dnssec):
        # NOTE: the default values are evaluated once, when the class is
        # defined, so they are the initial module values, not the ones
        # set from the command line.
        self.message = dns.message.make_query(qname, dns.rdatatype.from_text(qtype),
                                              use_edns=use_edns, want_dnssec=want_dnssec)
        self.message.flags |= dns.flags.AD  # Ask for validation
        self.ok = True

    def trunc_data(self):
        """Set data to the first half of the wire-format query, to test
        how the server handles a broken message."""
        self.data = self.message.to_wire()
        half = round(len(self.data) / 2)
        self.data = self.data[:half]

    def to_wire(self):
        """Set data to the complete wire-format query."""
        self.data = self.message.to_wire()


class RequestDoT(Request):
    """A Request sent over DNS-over-TLS."""

    def check_response(self):
        """Check that the response answers this query.

        Raises Exception when the ID of the response differs from the
        ID of the query.
        """
        if self.response.id != self.message.id:
            raise Exception("The ID in the answer does not match the one in the query")

class RequestDoH(Request):
    """A Request sent over DNS-over-HTTPS (RFC 8484).

    post / head select the HTTP method (GET when both are False).
    """

    def __init__(self, qname, qtype=rtype, use_edns=edns, want_dnssec=dnssec):
        # Forward the caller's arguments (they used to be replaced by the
        # module-level globals, silently ignoring qtype, use_edns and
        # want_dnssec).
        Request.__init__(self, qname, qtype=qtype, use_edns=use_edns, want_dnssec=want_dnssec)
        # RFC 8484, section 4.1: DoH clients SHOULD use a DNS ID of 0
        # (HTTP cache friendliness).
        self.message.id = 0
        self.post = False
        self.head = False

    def check_response(self):
        """Validate the HTTP answer and decode its DNS payload.

        Reads rcode (HTTP status), ctype (Content-Type), response (body)
        and response_size, as set by the connection. On success, response
        becomes the parsed DNS message (or "HEAD successful"); on failure
        it becomes an error string, or stays the body sent by the server
        for a non-200 status. Sets and returns self.ok.
        """
        ok = self.ok
        if self.rcode == 200:
            if self.ctype != "application/dns-message":
                self.response = "Content type of the response (\"%s\") invalid" % self.ctype
                ok = False
            elif not self.head:
                try:
                    self.response = dns.message.from_wire(self.response)
                except dns.message.TrailingJunk:  # Not DNS. Should
                    # not happen for a content type
                    # application/dns-message but who knows?
                    self.response = "ERROR Not proper DNS data, trailing junk \"%s\"" % self.response
                    ok = False
                except dns.name.BadLabelType:  # Not DNS.
                    self.response = "ERROR Not proper DNS data (wrong path in the URL?) \"%s\"" % self.response[:100]
                    ok = False
                except dns.exception.DNSException:  # Any other parse failure (short header...)
                    self.response = "ERROR Not proper DNS data \"%s\"" % self.response[:100]
                    ok = False
            elif self.response_size == 0:
                self.response = "HEAD successful"
            else:
                self.response = "ERROR Body length is not null \"%s\"" % self.response[:100]
                ok = False
        else:
            ok = False
            if self.response_size == 0:
                self.response = "[No details]"
            # Otherwise keep the body: it may explain the HTTP error.
        self.ok = ok
        return ok

class Connection:
    """Common part of the DoT and DoH connections.

    server: host name or IP address (DoT), or HTTPS URL (DoH).
    servername: name expected in the server certificate, when it is
    not server itself.
    connect: address to connect to instead of the server's own
    addresses (used by the subclasses).
    forceIPv4, forceIPv6: restrict the address family (not both).
    dot, verbose, insecure: transport, verbosity, no certificate check.
    """

    def __init__(self, server, servername=None, connect=None, forceIPv4=False, forceIPv6=False,
                 dot=dot, verbose=verbose, insecure=insecure):
        if dot and not is_valid_hostname(server):
            error("DoT requires a host name or IP address, not \"%s\"" % server)
        if not dot and not is_valid_url(server):
            error("DoH requires a valid HTTPS URL, not \"%s\"" % server)
        if forceIPv4 and forceIPv6:
            raise CustomException("Force IPv4 *or* IPv6 but not both")
        self.dot = dot
        self.server = server
        self.servername = servername
        # Name that must be found in the server certificate
        if self.servername is not None:
            self.check = self.servername
        else:
            self.check = self.server
        self.verbose = verbose
        self.insecure = insecure
        # Kept for check_ip_address (these arguments used to be ignored)
        self.forceIPv4 = forceIPv4
        self.forceIPv6 = forceIPv6

    def __str__(self):
        return self.server

    def check_ip_address(self, addr):
        """Choose the address family to use for *addr*.

        Sets self.family to socket.AF_INET, socket.AF_INET6, or 0 when
        nothing constrains it. For IPv4/IPv6 (literal or forced) it also
        sets self.repraddress: addr itself, bracketed for IPv6.
        Raises CustomException when addr is not an IP address (DoH
        only) or contradicts the forced family.
        """
        # Honour both the flags given to this connection and the
        # module-level (command-line) ones.
        force4 = self.forceIPv4 or forceIPv4
        force6 = self.forceIPv6 or forceIPv6
        (is_addr, version) = is_valid_ip_address(addr)
        if not is_addr and not self.dot:
            raise CustomException("%s is not IPv4 and not IPv6" % addr)
        if force4 and version == 6:
            raise CustomException("You cannot force IPv4 with a litteral IPv6 address (%s)" % addr)
        elif force6 and version == 4:
            raise CustomException("You cannot force IPv6 with a litteral IPv4 address (%s)" % addr)
        if force4 or version == 4:
            self.family = socket.AF_INET
            self.repraddress = addr
        elif force6 or version == 6:
            self.family = socket.AF_INET6
            self.repraddress = f'[{addr}]'
        else:
            self.family = 0

    def do_test(self, qname, qtype=rtype):
        """Run one query. The subclasses implement it and return the
        Request object; this base version does nothing."""
        pass

class ConnectionDoT(Connection):
    """A DNS-over-TLS (RFC 7858) connection to a resolver.

    The TLS session is established by the constructor: every address
    getaddrinfo returns for port 853 is tried in turn until one works
    (self.success tells whether one did).
    """

    def __init__(self, server, servername=None, connect=None, forceIPv4=False, forceIPv6=False,
                 verbose=verbose, insecure=insecure):
        Connection.__init__(self, server, servername=servername, connect=connect,
                            forceIPv4=forceIPv4, forceIPv6=forceIPv6, dot=True,
                            verbose=verbose, insecure=insecure)
        # An explicit address to connect to overrides the server name
        addr = connect if connect is not None else self.server
        self.check_ip_address(addr)
        addrinfo_list = socket.getaddrinfo(addr, 853, self.family)
        # getaddrinfo returns one entry per socket type / protocol: keep
        # each (socket address, family) pair only once.
        addrinfo_set = {(addrinfo[4], addrinfo[0]) for addrinfo in addrinfo_list}
        signal.signal(signal.SIGALRM, timeout_connection)
        self.success = False
        for sockaddr, family in addrinfo_set:
            self.hasher = hashlib.sha256()
            if self.connect(sockaddr, family):
                self.success = True
                break
            if self.verbose and connect is None:
                print("Trying another IP address")
        if not self.success:
            if self.verbose and connect is None:
                print("No other IP address")
            if connect is None:
                error(f'Could not connect to "{server}"')
            else:
                print(f'Could not connect to "{server}" on {connect}')

    def connect(self, addr, sock_family):
        """Try to establish the TLS session with one server address.

        addr: socket address, as returned by getaddrinfo.
        sock_family: its address family.
        Returns True on success, False when this address cannot be used
        (no handshake within TIMEOUT_CONN seconds, or a socket error such
        as a refused connection) so that the caller can try the next
        one. A certificate not matching the expected name is fatal.
        """
        self.addr = addr
        self.sock = socket.socket(sock_family, socket.SOCK_STREAM)
        if self.verbose:
            print("Connecting to %s ..." % str(self.addr))
        # With typical DoT servers, we *must* use TLS 1.2 (otherwise,
        # do_handshake fails with "OpenSSL.SSL.SysCallError: (-1, 'Unexpected
        # EOF')" Typical HTTP servers are more lax.
        self.context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_2_METHOD)
        if self.insecure:
            self.context.set_verify(OpenSSL.SSL.VERIFY_NONE, lambda *x: True)
        else:
            self.context.set_default_verify_paths()
            self.context.set_verify_depth(4)  # Seems ignored
            self.context.set_verify(OpenSSL.SSL.VERIFY_PEER | OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT |
                                    OpenSSL.SSL.VERIFY_CLIENT_ONCE,
                                    lambda conn, cert, errno, depth, preverify_ok: preverify_ok)
        self.session = OpenSSL.SSL.Connection(self.context, self.sock)
        self.session.set_tlsext_host_name(canonicalize(self.check).encode())  # Server Name Indication (SNI)
        signal.alarm(TIMEOUT_CONN)
        try:
            try:
                self.session.connect(self.addr)
                # TODO We may here have exceptions such as OpenSSL.SSL.ZeroReturnError
                self.session.do_handshake()
            except TimeoutConnectionError:
                if self.verbose:
                    print("Timeout")
                self.sock.close()
                return False
            except OSError as e:  # Refused, unreachable...: try the next address
                if self.verbose:
                    print("Cannot connect: %s" % e)
                self.sock.close()
                return False
            self._check_certificate()
            return True
        finally:
            # Always cancel the alarm: if left pending it would fire
            # later, in unrelated code.
            signal.alarm(0)

    def _check_certificate(self):
        """Get the server certificate, display it when verbose and, unless
        insecure, check that it matches the expected name (self.check)."""
        self.cert = self.session.get_peer_certificate()
        # RFC 7858, section 4.2 and appendix A
        self.publickey = self.cert.get_pubkey()
        if self.verbose:
            print("Certificate #%x for \"%s\", delivered by \"%s\"" %
                  (self.cert.get_serial_number(),
                   self.cert.get_subject().commonName,
                   self.cert.get_issuer().commonName))
            self.hasher.update(OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1,
                                                             self.publickey))
            self.digest = self.hasher.digest()
            print("Public key is pin-sha256=\"%s\"" %
                  base64.standard_b64encode(self.digest).decode())
        if not self.insecure and not validate_hostname(self.check, self.cert):
            error("Certificate error: \"%s\" is not in the certificate" % (self.check))

    def end(self):
        """Close the TLS session."""
        self.session.shutdown()
        self.session.close()

    def send_data(self, data):
        """Send one DNS message, preceded by its two-byte length."""
        length = len(data)
        # sendall, since send may transmit only part of the buffer
        self.session.sendall(length.to_bytes(2, byteorder='big') + data)

    def _recv_exactly(self, size):
        """Read exactly *size* bytes from the TLS session.

        recv() may return less than asked for (a DNS message can span
        several TLS records), so read until everything has arrived.
        Raises OpenSSL.SSL.ZeroReturnError if the peer closes first.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.session.recv(remaining)
            if not chunk:
                raise OpenSSL.SSL.ZeroReturnError()
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def receive_data(self, request):
        """Read one length-prefixed DNS message into *request*."""
        buf = self._recv_exactly(2)
        request.response_size = int.from_bytes(buf, byteorder='big')
        buf = self._recv_exactly(request.response_size)
        request.response = dns.message.from_wire(buf)
        request.rcode = True

    def send_and_receive(self, request):
        self.send_data(request.data)
        self.receive_data(request)

    def do_test(self, qname, qtype=rtype):
        """Send one query for qname/qtype and return the checked RequestDoT."""
        request = RequestDoT(qname, qtype, want_dnssec=dnssec, use_edns=edns)
        request.to_wire()
        self.send_and_receive(request)
        request.check_response()
        return request

class ConnectionDoH(Connection):
    """A DNS-over-HTTPS (RFC 8484) connection, using pycurl.

    The curl handle is created lazily, by the first prepare().
    """

    def __init__(self, server, servername=None, connect=None, forceIPv4=False, forceIPv6=False,
                 verbose=verbose, insecure=insecure):
        Connection.__init__(self, server, servername=servername, connect=connect,
                            forceIPv4=forceIPv4, forceIPv6=forceIPv6, dot=False,
                            verbose=verbose, insecure=insecure)
        self.url = server
        self.connect = connect
        # Address family constraints used by create_handle (these
        # arguments used to be ignored)
        self.forceIPv4 = forceIPv4
        self.forceIPv6 = forceIPv6

    def create_handle(self):
        """Create and configure the curl handle (HTTP/2, TLS checks,
        address family, optional forced address, DoH headers)."""
        self.curl = pycurl.Curl()
        # Does not work if pycurl was not compiled with nghttp2 (recent Debian
        # packages are OK) https://github.com/pycurl/pycurl/issues/477
        self.curl.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
        if self.verbose:
            self.curl.setopt(pycurl.VERBOSE, True)
        if self.insecure:
            self.curl.setopt(pycurl.SSL_VERIFYPEER, False)
            self.curl.setopt(pycurl.SSL_VERIFYHOST, False)
        # Honour both the per-connection and the module-level flags
        if self.forceIPv4 or forceIPv4:
            self.curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
        if self.forceIPv6 or forceIPv6:
            self.curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V6)
        if self.connect is not None:
            self.check_ip_address(self.connect)
            self.curl.setopt(pycurl.CONNECT_TO, [f'::{self.repraddress}:443',])
        self.curl.setopt(pycurl.HTTPHEADER, ["Accept: application/dns-message", "Content-type: application/dns-message"])

    def end(self):
        """Close the curl handle."""
        self.curl.close()

    def set_opt(self, opt, value):
        self.curl.setopt(opt, value)

    def reset_opt_default(self):
        """Undo the options a previous GET / POST / HEAD request may have set."""
        opts = {
                pycurl.NOBODY: False,
                pycurl.POST: False,
                pycurl.POSTFIELDS: '',
                pycurl.URL: ''
               }
        for opt, value in opts.items():
            self.set_opt(opt, value)

    def prepare(self, request):
        """Configure the handle for *request* (POST, HEAD or GET), creating
        the handle on first use."""
        if getattr(self, 'curl', None) is None:
            self.create_handle()
        else:
            self.reset_opt_default()
        if request.post:
            self.prepare_post(request)
        elif request.head:
            self.prepare_head(request)
        else:
            self.prepare_get(request)

    def prepare_get(self, request):
        """GET with the query base64url-encoded (no padding) in the dns
        parameter (RFC 8484, section 4.1)."""
        self.set_opt(pycurl.HTTPGET, True)
        dns_req = base64.urlsafe_b64encode(request.data).decode('UTF8').rstrip('=')
        # Append to the query string if the URL already has one
        separator = '&' if '?' in self.server else '?'
        self.set_opt(pycurl.URL, self.server + ("%sdns=%s" % (separator, dns_req)))

    def prepare_post(self, request):
        """POST with the wire-format query as the body."""
        request.post = True
        self.set_opt(pycurl.POST, True)
        self.set_opt(pycurl.POSTFIELDS, request.data)
        self.set_opt(pycurl.URL, self.server)

    def prepare_head(self, request):
        """Like GET, but asking for no body (HEAD)."""
        request.head = True
        self.prepare_get(request)
        self.set_opt(pycurl.NOBODY, True)

    def perform(self):
        """Run the prepared request, collecting the body in self.buffer.
        A curl failure is fatal (error())."""
        self.buffer = io.BytesIO()
        self.set_opt(pycurl.WRITEDATA, self.buffer)
        try:
            self.curl.perform()
        except pycurl.error as e:
            error(e.args[1])

    def receive(self, request):
        """Copy the body, its size, the HTTP status and the Content-Type
        into *request*."""
        body = self.buffer.getvalue()
        body_size = len(body)
        http_code = self.curl.getinfo(pycurl.RESPONSE_CODE)
        try:
            content_type = self.curl.getinfo(pycurl.CONTENT_TYPE)
        except TypeError:  # This is the exception we get if there is no Content-Type: (for instance in response to HEAD requests)
            content_type = None
        request.response = body
        request.response_size = body_size
        request.rcode = http_code
        request.ctype = content_type
        self.buffer.close()

    def send_and_receive(self, request):
        self.prepare(request)
        self.perform()
        self.receive(request)

    def do_test(self, qname, qtype=rtype):
        """Send one query for qname/qtype, using the method selected by
        the head / post globals, and return the checked RequestDoH."""
        request = RequestDoH(qname, qtype, want_dnssec=dnssec, use_edns=edns)
        request.head = head
        request.post = post
        request.to_wire()
        self.send_and_receive(request)
        request.check_response()
        return request

def get_next_domain(input_file):
    """Read the next query from *input_file*.

    Each line holds a domain name, optionally followed by a DNS type
    (AAAA by default), separated by whitespace. Returns (name, type).
    An exhausted file, an empty line or a line with more than two
    fields is reported through error().
    """
    # strip() also copes with a last line lacking its newline (its final
    # character used to be cut off) and with tabs.
    line = input_file.readline().strip()
    if line == "":
        error("Not enough data in %s for the %i tests" % (ifile, tests))
    fields = line.split()
    if len(fields) == 1:
        return fields[0], 'AAAA'
    if len(fields) == 2:
        return fields[0], fields[1]
    error("Invalid line in %s: \"%s\"" % (ifile, line))

def print_result(connection, request, prefix=None, display_err=True):
    """Report the outcome of one request.

    The request is considered delivered when, for DoT, a response was
    received (rcode true) and, for DoH, the HTTP status is 200. Then the
    response is printed (unless in --check mode without --verbose) and
    request.ok is returned. Otherwise the error is printed on standard
    error when display_err is true (preceded by prefix, if any) and
    False is returned.

    In monitoring-plugin mode the function does not return: it prints
    one status line and exits with STATE_OK or STATE_CRITICAL.
    """
    ok = request.ok
    dot = connection.dot
    server = connection.server
    rcode = request.rcode
    msg = request.response
    size = request.response_size
    if (dot and rcode) or (not dot and rcode == 200):
        if not monitoring:
            if not check or verbose:
                print(msg)
        else:
            if expect is not None and expect not in str(request.response):
                print("%s Cannot find \"%s\" in response" % (server, expect))
                sys.exit(STATE_CRITICAL)
            if size is not None and size > 0:
                print("%s OK - %s" % (server, "No error for %s/%s, %i bytes received" % (name, rtype, size)))
            else:
                print("%s OK - %s" % (server, "No error"))
            sys.exit(STATE_OK)
    else:
        if not monitoring:
            if display_err:
                if prefix:
                    print(prefix, end=': ', file=sys.stderr)
                if dot:
                    print("Error: %s" % msg, file=sys.stderr)
                else:
                    try:
                        msg = msg.decode()
                    except (UnicodeDecodeError, AttributeError):
                        pass  # Sometimes, msg can be binary, or Latin-1
                    print("HTTP error %i: %s" % (rcode, msg), file=sys.stderr)
        else:
            if not dot:
                print("%s HTTP error - %i: %s" % (server, rcode, msg))
            else:
                print("%s Error - %i: %s" % (server, rcode, msg))
            sys.exit(STATE_CRITICAL)
        ok = False
    return ok

def create_request(dot=dot, trunc=False, **req_args):
    """Build a DoT or DoH request and serialize it.

    dot: build a RequestDoT when true, a RequestDoH otherwise.
    trunc: keep only the first half of the wire data (a broken query).
    req_args: passed unchanged to the request constructor.
    Returns the request, its data attribute ready to send.
    """
    request_class = RequestDoT if dot else RequestDoH
    request = request_class(**req_args)
    serialize = request.trunc_data if trunc else request.to_wire
    serialize()
    return request

def create_requests_list(dot=dot, **req_args):
    """Build the requests of the basic --check tests.

    Returns a list of tuples (test name, request, mandatory level) for
    DoT, and (test name, request, DOH_* method, mandatory level) for DoH.
    """
    if dot:
        return [
            ('Test 1', create_request(dot=dot, **req_args), mandatory_levels["legal"]),
            # RFC 7858, section 3.3, SHOULD accept several requests on one
            # connection.
            # TODO we miss the tests of pipelining and out-of-order.
            ('Test 2', create_request(dot=dot, **req_args), mandatory_levels["necessary"]),
        ]
    return [
        # RFC 8484, section 4.1: servers MUST implement GET and POST.
        ('Test GET', create_request(dot=False, **req_args), DOH_GET, mandatory_levels["legal"]),
        ('Test POST', create_request(dot=False, **req_args), DOH_POST, mandatory_levels["legal"]),
        # HEAD method is not mentioned in RFC 8484 (see section 4.1), so
        # just "nice to have".
        ('Test HEAD', create_request(dot=False, **req_args), DOH_HEAD, mandatory_levels["nicetohave"]),
    ]

def run_check_default(connection):
    """Run the basic --check tests (see create_requests_list).

    Stops at the first failed test. Returns False when that test's
    mandatory level is at least mandatory_level (its error is then
    printed), True otherwise.
    """
    ok = True
    req_args = {'qname': name, 'qtype': rtype, 'use_edns': edns, 'want_dnssec': dnssec}
    for request_pack in create_requests_list(dot=dot, **req_args):
        if dot:
            test_name, request, mandatory = request_pack
        else:
            test_name, request, method, mandatory = request_pack
        if verbose:
            print(test_name)
        if not dot:
            if method == DOH_POST:
                request.post = True
            elif method == DOH_HEAD:
                request.head = True
        try:
            connection.send_and_receive(request)
        except CustomException as e:
            ok = False
            error(e)
            break
        request.check_response()
        if not print_result(connection, request, prefix=test_name, display_err=False):
            # NOTE(review): mandatory_level must have been set by then; its
            # initial value None cannot be compared with an int.
            if mandatory >= mandatory_level:
                print_result(connection, request, prefix=test_name, display_err=True)
                ok = False
            break
    return ok

def run_check_mime(connection, accept="application/dns-message", content_type="application/dns-message"):
    """DoH only: send one query with custom Accept / Content-type headers.

    The default headers are restored afterwards, even when the request
    fails. Returns True for DoT (nothing to test), and for DoH False
    when the request failed.
    """
    if dot:
        return True
    ok = True
    header = [f"Accept: {accept}", f"Content-type: {content_type}"]
    req_args = {'qname': name, 'qtype': rtype, 'use_edns': edns, 'want_dnssec': dnssec}
    request = create_request(**req_args)
    connection.curl.setopt(pycurl.HTTPHEADER, header)
    try:
        try:
            connection.send_and_receive(request)
        except CustomException as e:
            ok = False
            error(e)
        request.check_response()
        if not print_result(connection, request, prefix=f"Test Header {', '.join(header)}"):
            ok = False
    finally:
        default = "application/dns-message"
        default_header = [f"Accept: {default}", f"Content-type: {default}"]
        connection.curl.setopt(pycurl.HTTPHEADER, default_header)
    return ok

def run_check_trunc(connection):
    """Send a truncated query (half of its wire data) to the server.

    Acceptable reactions: rejecting it (an HTTP error for DoH, closing
    the connection for DoT) or answering with a DNS FORMERR. A normal
    DNS answer is a failure. Returns True when the server behaved well.
    """
    ok = True
    test_name = 'Test truncated data'
    if verbose:
        print(test_name)
    req_args = {'qname': name, 'qtype': rtype, 'use_edns': edns, 'want_dnssec': dnssec}
    if dot:
        request = create_request(dot=dot, trunc=True, **req_args)
    else:
        request = create_request(trunc=True, **req_args)
        request.post = True
    try:
        # 8.8.8.8 replies FORMERR but most DoT servers violently shut down the connection (which is legal)
        connection.send_and_receive(request)
    except CustomException as e:
        ok = False
        error(e)
    except (OpenSSL.SSL.ZeroReturnError, OpenSSL.SSL.SysCallError):  # Connection closed: acceptable
        return ok
    request.check_response()
    if print_result(connection, request, prefix=test_name, display_err=False):
        # The query was answered: the answer must be FORMERR. Look at the
        # DNS rcode itself, request.rcode being the HTTP status (DoH) or
        # True (DoT), which made the old comparison meaningless.
        response = request.response
        ok = isinstance(response, dns.message.Message) and response.rcode() == dns.rcode.FORMERR
    return ok

def run_check_additionals(connection):
    """Run the extra --check tests (truncated query, unusual media types).

    Only the truncated-query test decides the result; the media-type
    tests just display what the server does.
    """
    if not run_check_trunc(connection):
        return False
    # The DoH server is right to reject these (Example: 'HTTP
    # error 415: only Content-Type: application/dns-message is
    # supported')
    run_check_mime(connection, accept="text/html")
    run_check_mime(connection, content_type="text/html")
    return True

def run_check(connection):
    """Run all the --check tests: the basic ones, then (if
    check_additional) the extra ones. Returns True when all passed."""
    if not run_check_default(connection):
        return False
    if check_additional and not run_check_additionals(connection):
        return False
    return True

def resolved_ips(host, port, family, dot=dot):
    try:
        addr_list = socket.getaddrinfo(host, port, family)
    except socket.gaierror:
        error(f'Could not resolve "{url}"')
    ip_set = { addr[4][0] for addr in addr_list }
    return ip_set

# Main program
me = os.path.basename(sys.argv[0])
# Invoked as check_doh / check_dot, the script acts as a monitoring plugin:
# it has its own option set and exits with STATE_* codes.
monitoring = (me == "check_doh" or me == "check_dot")
if not monitoring:
    name = None
    message = None
    try:
        optlist, args = getopt.getopt (sys.argv[1:], "hvPkeV:r:f:d:t46",
                                       ["help", "verbose", "dot", "head",
                                        "insecure", "POST", "vhost=",
                                        "dnssec", "noedns", "repeat=", "file=",
                                        "delay=", "v4only", "v6only",
                                        "check", "mandatory-level="])
        for option, value in optlist:
            if option == "--help" or option == "-h":
                usage()
                sys.exit(0)
            elif option == "--dot" or option == "-t":
                dot = True
            elif option == "--verbose" or option == "-v":
                verbose = True
            elif option == "--HEAD" or option == "--head" or option == "-e":
                head = True
            elif option == "--POST" or option == "--post" or option == "-P":
                post = True
            elif option == "--vhost" or option == "-V":
                vhostname = value
            elif option == "--insecure" or option == "-k":
                insecure = True
            elif option == "--dnssec":
                dnssec = True
            elif option == "--noedns":
                edns = False
            elif option == "--repeat" or option == "-r":
                tests = int(value)
                if tests <= 1:
                    error("--repeat needs a value > 1")
            elif option == "--delay" or option == "-d":
                delay = float(value)
                if delay <= 0:
                    error("--delay needs a value > 0")
            elif option == "--file" or option == "-f":
                ifile = value
            # getopt reports long options with their leading "--", so the
            # long forms must be compared as "--v4only" / "--v6only".
            elif option == "-4" or option == "--v4only":
                forceIPv4 = True
            elif option == "-6" or option == "--v6only":
                forceIPv6 = True
            elif option == "--check":
                check = True
            elif option == "--mandatory-level":
                mandatory_level = value
            else:
                error("Unknown option %s" % option)
    except getopt.error as reason:
        usage(reason)
        sys.exit(1)
    # Cross-option consistency checks
    if tests <= 1 and delay is not None:
        error("--delay makes no sense if there is no repetition")
    if post and head:
        usage("POST or HEAD but not both")
        sys.exit(1)
    if dot and (post or head):
        usage("POST or HEAD makes no sense for DoT")
        sys.exit(1)
    if mandatory_level is not None and \
       mandatory_level not in mandatory_levels.keys():
        usage("Unknown mandatory level \"%s\"" % mandatory_level)
        sys.exit(1)
    if mandatory_level is not None and not check:
        usage("--mandatory-level only makes sense with --check")
        sys.exit(1)
    if mandatory_level is None:
        mandatory_level = "necessary"
    # From here on, mandatory_level holds the value mapped in mandatory_levels
    mandatory_level = mandatory_levels[mandatory_level]
    # Positional arguments: URL (or DoT host), then domain name and optional
    # record type, unless the names come from --file.
    if ifile is None and (len(args) != 2 and len(args) != 3):
        usage("Wrong number of arguments")
        sys.exit(1)
    if ifile is not None and len(args) != 1:
        usage("Wrong number of arguments (if --file is used, do not indicate the domain name)")
        sys.exit(1)
    url = args[0]
    if ifile is None:
        name = args[1]
        if len(args) == 3:
            rtype = args[2]
else: # Monitoring plugin
    dot = (me == "check_dot")
    name = None
    try:
        optlist, args = getopt.getopt (sys.argv[1:], "H:n:p:V:t:e:Pih46")
        for option, value in optlist:
            if option == "-H":
                host = value
            elif option == "-V":
                vhostname = value
            elif option == "-n":
                name = value
            elif option == "-t":
                rtype = value
            elif option == "-e":
                expect = value
            elif option == "-p":
                path = value
            elif option == "-P":
                post = True
            elif option == "-h":
                head = True
            elif option == "-i":
                insecure = True
            elif option == "-4":
                forceIPv4 = True
            elif option == "-6":
                forceIPv6 = True
            else:
                # Should never occur, it is trapped by getopt
                print("Unknown option %s" % option)
                sys.exit(STATE_UNKNOWN)
    except getopt.error as reason:
        print("Option parsing problem %s" % reason)
        sys.exit(STATE_UNKNOWN)
    if len(args) > 0:
        print("Too many arguments (\"%s\")" % args)
        sys.exit(STATE_UNKNOWN)
    if host is None or name is None:
        print("Host (-H) and name to lookup (-n) are necessary")
        sys.exit(STATE_UNKNOWN)
    if post and head:
        print("POST or HEAD but not both")
        sys.exit(STATE_UNKNOWN)
    if dot and (post or head):
        print("POST or HEAD makes no sense for DoT")
        sys.exit(STATE_UNKNOWN)
    if dot and path:
        print("URL path makes no sense for DoT")
        sys.exit(STATE_UNKNOWN)
    if dot:
        url = host
    else:
        # With a virtual host different from -H, the URL uses the virtual
        # host name and the connection is made to -H (connectTo).
        if vhostname is None or vhostname == host:
            connectTo = None
            url = "https://%s/" % host
        else:
            connectTo = host
            url = "https://%s/" % vhostname
        if path is not None:
            if path.startswith("/"):
                path = path[1:]
            url += path
856

857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
# retrieve all ips when using --check
# not necessary if connectTo is already defined
# as it is the case with --monitoring
if not check or connectTo is not None:
    ip_set = {connectTo, }
else:
    if dot:
        port = 853
        if not is_valid_hostname(url):
            error("DoT requires a host name or IP address, not \"%s\"" % url)
        netloc = url
    else:
        port = 443
        if not is_valid_url(url):
            error("DoH requires a valid HTTPS URL, not \"%s\"" % url)
Alexandre's avatar
Alexandre committed
872
873
874
875
876
877
        try:
            url_parts = urllib.parse.urlparse(url) # A very poor validation, many
            # errors (for instance whitespaces, IPv6 address litterals without
            # brackets...) are ignored.
        except ValueError:
            error(f'The provided url "{url}" could not be parsed')
878
879
880
881
882
883
884
885
886
        netloc = url_parts.netloc
    if forceIPv4:
        family = socket.AF_INET
    elif forceIPv6:
        family = socket.AF_INET6
    else:
        family = 0
    ip_set = resolved_ips(netloc, port, family, dot)

ok = True # Overall verdict, accumulated over every tested address
for connectTo in ip_set:
    start = time.time()
    # Time taken after the first request; only set when repeating tests
    # and the first request succeeded.
    start2 = None
    # For DoT, the --vhost name (if any) is passed as servername
    if dot and vhostname is not None:
        extracheck = vhostname
    else:
        extracheck = None
    if verbose and check and connectTo:
        print(f'Checking "{url}" on {connectTo} ...')
    try:
        if dot:
            conn = ConnectionDoT(url, servername=extracheck, connect=connectTo, verbose=verbose,
                                 forceIPv4=forceIPv4, forceIPv6=forceIPv6,
                                 insecure=insecure)
        else:
            conn = ConnectionDoH(url, servername=extracheck, connect=connectTo, verbose=verbose,
                                 forceIPv4=forceIPv4, forceIPv6=forceIPv6,
                                 insecure=insecure)
    except TimeoutError:
        error("timeout")
    except ConnectionRefusedError:
        error("Connection to server refused")
    except ValueError:
        error(f'"{url}" not a name or an IP address')
    except socket.gaierror:
        error(f'Could not resolve "{url}"')
    except CustomException as e:
        error(e)
    if conn.dot and not conn.success:
        ok = False
        continue
    if ifile is not None:
        input = open(ifile)
    if not check:
        for i in range (0, tests):
            if tests > 1:
                print("\nTest %i" % i)
            if ifile is not None:
                name, rtype = get_next_domain(input)
            try:
                request = conn.do_test(name, rtype)
            except (OpenSSL.SSL.Error, CustomException) as e:
                ok = False
                error(e)
                break
            if not print_result(conn, request):
                ok = False
            if tests > 1 and i == 0:
                start2 = time.time()
            if delay is not None:
                time.sleep(delay)
    else:
        # Accumulate rather than assign, so that a later address passing
        # does not hide a failure on an earlier one.
        if not run_check(conn):
            ok = False
    stop = time.time()
    # start2 is unset in --check mode or when the first request failed;
    # the "ignoring the first request" figure is then not available.
    if tests > 1 and start2 is not None:
        extra = ", %.2f ms/request if we ignore the first one" % ((stop-start2)*1000/(tests-1))
    else:
        extra = ""
    if not monitoring and (not check or verbose):
        print("\nTotal elapsed time: %.2f seconds (%.2f ms/request %s)" % (stop-start, (stop-start)*1000/tests, extra))
    if ifile is not None:
        input.close()
    conn.end()
# Report the overall verdict. As a monitoring plugin (check_doh/check_dot)
# the exit status is STATE_OK / STATE_CRITICAL, otherwise 0 / 1.
print('OK' if ok else 'KO')
if monitoring:
    sys.exit(STATE_OK if ok else STATE_CRITICAL)
sys.exit(0 if ok else 1)