homer.py 28.6 KB
Newer Older
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
1
2
#!/usr/bin/env python3

3
4
5
6
7
# Homer is a DoH (DNS-over-HTTPS) and DoT (DNS-over-TLS) client. Its
# main purpose is to test DoH and DoT resolvers. Reference site is
# <https://framagit.org/bortzmeyer/homer/> See author, documentation,
# etc, there, or in the README.md included with the distribution.

Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
8
9
10
11
12
13
# http://pycurl.io/docs/latest
import pycurl

# http://www.dnspython.org/
import dns.message

14
15
16
# https://github.com/drkjam/netaddr/
import netaddr

17
18
19
20
21
# October 2019: the Python GnuTLS bindings don't work with Python 3, so we use OpenSSL.
# https://www.pyopenssl.org/
# https://pyopenssl.readthedocs.io/
import OpenSSL

Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
22
23
24
25
26
27
import io
import sys
import base64
import getopt
import urllib.parse
import time
28
29
import socket
import ctypes
30
import re
31
import os.path
32
33
import hashlib
import base64
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
34

35
# Settings that the command line (or the monitoring plugin options) may
# override. The values below are the defaults.

# Transport and HTTP method
dot = False  # DoH unless DoT is requested
post = False
head = False
insecure = False
verbose = False

# Content of the query
rtype = 'AAAA'
dnssec = False
edns = True
vhostname = None

# Repetition and batch input
tests = 1  # How many times the test is repeated
delay = None
ifile = None  # File with one domain name (and optional DNS type) per line

# Address family and connection target
forceIPv4 = False
forceIPv6 = False
connectTo = None

# Run the conformance checks instead of plain queries
check = False

# Monitoring plugin only:
host = None
path = None
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
55

56
57
58
# Do not change these
# Syntactic check of a DoT server name, applied after canonicalize() (so
# lower-case and in ASCII/IDNA form): either a host name, or an IP address
# literal (IPv6 possibly with an embedded IPv4 part, or plain IPv4). The
# alternation is grouped so that ^ and $ apply to every alternative: with
# "^a|b|c$", only the first one was anchored at the start and only the last
# one at the end, so any string merely starting with a letter or a digit
# (e.g. "dot.example/path") was accepted.
re_host = re.compile(r'^(?:[0-9a-z][0-9a-z.-]*|[0-9a-f:.]+)$')

59
60
61
62
63
64
65
# For the monitoring plugin: exit statuses of the process
(STATE_OK,
 STATE_WARNING,
 STATE_CRITICAL,
 STATE_UNKNOWN,
 STATE_DEPENDENT) = range(5)

Alexandre's avatar
Alexandre committed
66
67
68
69
70
# For the check option: HTTP method used by each DoH test request
DOH_GET, DOH_POST, DOH_HEAD = range(3)

Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
71
72
73
def error(msg=None):
    """Report a fatal problem and terminate the program.

    In monitoring-plugin mode, the message goes to standard output, prefixed
    by the URL being tested, and the exit status is STATE_CRITICAL.
    Otherwise it goes to standard error (followed by "KO" on standard output
    when --check is used) and the exit status is 1.
    """
    text = "Unknown error" if msg is None else msg
    if not monitoring:
        print(text, file=sys.stderr)
        if check:
            print('KO')
        sys.exit(1)
    print("%s: %s" % (url, text))
    sys.exit(STATE_CRITICAL)
Alexandre's avatar
Alexandre committed
82

Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
83
84
85
def usage(msg=None):
    """Print a short usage summary on standard error, after *msg* if any."""
    out = sys.stderr
    if msg:
        print(msg, file=out)
    lines = (
        "Usage: %s [--dot] url-or-servername domain-name [DNS type]" % sys.argv[0],
        "See the README.md for more details.",
    )
    for line in lines:
        print(line, file=out)
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
88

89
def is_valid_hostname(name):
    """Check that *name*, once canonicalized, looks like a host name or an
    IP address literal.

    Returns the (truthy) match object, or None. May raise UnicodeError (from
    canonicalize()) for names that cannot be IDNA-encoded.
    """
    return re_host.search(canonicalize(name))
92

93
94
95
96
97
98
99
100
101
def canonicalize(hostname):
    """Return *hostname* in the form used for comparisons.

    The name is lower-cased, converted to its IDNA (ASCII, "xn--") form and
    stripped of one trailing dot, if present.

    Raises UnicodeError (a subclass of ValueError) when the name cannot be
    IDNA-encoded, for instance two consecutive dots.
    """
    # TODO handle properly the case where it fails with UnicodeError to get
    # a custom exception
    result = hostname.lower().encode('idna').decode()
    # endswith() rather than result[len(result)-1]: an empty name must not
    # raise IndexError.
    if result.endswith('.'):
        result = result[:-1]
    return result

102
103
104
105
def is_valid_ip_address(addr):
    """Tell whether *addr* is an IP address literal.

    Returns a (is_address, version) tuple: (True, 4) or (True, 6) when
    netaddr can parse *addr*, (False, None) otherwise.
    """
    try:
        version = netaddr.IPAddress(addr).version
    except netaddr.core.AddrFormatError:
        return (False, None)
    else:
        return (True, version)
108

109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def is_valid_url(url):
    """Return True if *url* is an https: URL with a non-empty host part.

    This is a very poor validation: many errors (for instance whitespace, or
    IPv6 address literals without brackets) are not detected.
    """
    try:
        parsed = urllib.parse.urlparse(url)
    except ValueError:
        return False
    return parsed.netloc != "" and parsed.scheme == "https"

def get_certificate_san(x509cert):
    """Return the text form of the Subject Alternative Name extension of
    *x509cert* (a pyOpenSSL X509 object), like "DNS:a.example, IP
    Address:192.0.2.1", or "" when the certificate has no such extension.
    """
    san = ""
    for index in range(x509cert.get_extension_count()):
        ext = x509cert.get_extension(index)
        short_name = ext.get_short_name()
        # pyOpenSSL returns bytes: decode and compare exactly, instead of
        # searching in the repr of the bytes object ("b'subjectAltName'").
        if isinstance(short_name, bytes):
            short_name = short_name.decode('ascii', 'replace')
        if short_name == "subjectAltName":
            san = str(ext)
    return san

127
128
129
130
131
132
133
134
135
136
137
# Try one possible name. Names must be already canonicalized.
def match_hostname(hostname, possibleMatch):
    """Return True if *hostname* matches the certificate name *possibleMatch*.

    Both names must already be canonicalized. A wildcard "*.base" matches
    exactly one extra, non-empty, left-most label. So "foo.base" matches,
    but neither "base" itself nor "bar.foo.base" do (RFC 6125, section
    6.4.3). Partial wildcards such as "foo*bar", which RFC 6125 says we MAY
    accept, are not supported.
    """
    if not possibleMatch.startswith("*."):
        return hostname == possibleMatch
    base = possibleMatch[2:]  # Skip "*."
    (first, dot, rest) = hostname.partition(".")
    # A one-label host name (no dot) can never match a wildcard
    return dot != "" and first != "" and rest == base

# Try all the names in the certificate
def validate_hostname(hostname, cert):
    """Return True if the pyOpenSSL certificate *cert* is valid for *hostname*.

    A host name is matched against the "DNS:" Subject Alternative Names, an
    IP address literal against the "IP Address:" ones. Only when none of
    them matches is the subject Common Name tried. Alternative names that
    cannot be parsed are ignored.
    """
    # Complete specification is in RFC 6125. It is long and
    # complicated and I'm not sure we do it perfectly.
    (is_addr, _) = is_valid_ip_address(hostname)
    hostname = canonicalize(hostname)
    host_i = netaddr.IPAddress(hostname) if is_addr else None
    for alt_name in get_certificate_san(cert).split(", "):
        alt_name = alt_name.strip()  # The last one may end with a newline
        if alt_name.startswith("DNS:") and not is_addr:
            try:
                base = canonicalize(alt_name[len("DNS:"):])
            except (UnicodeError, IndexError):
                continue  # Ignore names that cannot be canonicalized
            if match_hostname(hostname, base):
                return True
        elif alt_name.startswith("IP Address:") and is_addr:
            try:
                base_i = netaddr.IPAddress(alt_name[len("IP Address:"):])
            except netaddr.core.AddrFormatError:
                continue  # Ignore broken IP addresses in certificates. Are we too liberal?
            if host_i == base_i:
                return True
        # Other alternative name types are ignored. May be accept URI
        # alternative names for DoH?
    # According to RFC 6125, we MUST NOT try the Common Name before the
    # Subject Alternative Names.
    common_name = cert.get_subject().commonName
    if common_name is None:  # A certificate may have no Common Name at all
        return False
    try:
        cn = canonicalize(common_name)
    except (UnicodeError, IndexError):
        return False
    return match_hostname(hostname, cn)

179

Alexandre's avatar
Alexandre committed
180
181
182
183
class CustomException(Exception):
    """Problem detected by Homer itself (inconsistent options, invalid
    addresses...). The main program reports it through error()."""


184
def create_request(dot=dot, trunc=False, **req_args):
    """Build a DoT or DoH request from *req_args* (qname, qtype, use_edns,
    want_dnssec) and serialize it into request.data.

    With *trunc*, only the first half of the wire format is kept (used by
    the "truncated data" check).
    """
    request_class = RequestDoT if dot else RequestDoH
    request = request_class(**req_args)
    serialize = request.trunc_data if trunc else request.to_wire
    serialize()
    return request

Alexandre's avatar
Alexandre committed
195
def create_requests_list(dot=dot, **req_args):
    """Return the requests run by --check.

    Each entry is a tuple (label, request) for DoT. For DoH it is
    (label, request, method), where method is DOH_GET, DOH_POST or
    DOH_HEAD.
    """
    if dot:
        return [
            ('Test 1', create_request(dot=dot, **req_args)),
            ('Test 2', create_request(dot=dot, **req_args)),
            ('Test truncated data', create_request(dot=dot, trunc=True, **req_args)),
        ]
    return [
        ('Test GET', create_request(**req_args), DOH_GET),
        ('Test POST', create_request(**req_args), DOH_POST),
        ('Test HEAD', create_request(**req_args), DOH_HEAD),
        ('Test truncated data', create_request(trunc=True, **req_args), DOH_POST),
    ]

Alexandre's avatar
Alexandre committed
208

Alexandre's avatar
Alexandre committed
209
class Request:
    """A DNS query. The connection later stores the reply in it (response,
    response_size, rcode...)."""

    def __init__(self, qname, qtype=rtype, use_edns=edns, want_dnssec=dnssec):
        rdtype = dns.rdatatype.from_text(qtype)
        self.message = dns.message.make_query(qname, rdtype, use_edns=use_edns,
                                              want_dnssec=want_dnssec)
        # Set the AD bit: ask for validation
        self.message.flags |= dns.flags.AD
        self.ok = True

    def to_wire(self):
        """Serialize the whole query into self.data."""
        self.data = self.message.to_wire()

    def trunc_data(self):
        """Serialize the query into self.data, keeping only its first half."""
        wire = self.message.to_wire()
        self.data = wire[:round(len(wire) / 2)]

Alexandre's avatar
Alexandre committed
223
224

class RequestDoT(Request):
    """A DNS query sent over TLS (DoT)."""

    def check_response(self):
        """Check that the reply matches the query, and return self.ok.

        Raises CustomException if the DNS ID of the reply differs from the
        one of the query. The main loop and the checks catch
        CustomException, but not a bare Exception.
        """
        if self.response.id != self.message.id:
            raise CustomException("The ID in the answer does not match the one in the query")
        return self.ok


Alexandre's avatar
Alexandre committed
230
class RequestDoH(Request):
    """A DNS query sent over HTTPS (DoH), with GET, POST or HEAD."""

    def __init__(self, qname, qtype=rtype, use_edns=edns, want_dnssec=dnssec):
        # Pass our own arguments on. The previous code passed the module
        # globals rtype, edns and dnssec, ignoring the caller's values.
        Request.__init__(self, qname, qtype=qtype, use_edns=use_edns, want_dnssec=want_dnssec)
        # RFC 8484, section 4.1: DoH clients SHOULD use a DNS ID of 0, for
        # better HTTP cache friendliness
        self.message.id = 0
        self.post = False
        self.head = False

    def check_response(self):
        """Check the HTTP reply stored by the connection, and parse it.

        Sets self.response to the parsed DNS message or to a description of
        the problem, then updates and returns self.ok.
        """
        ok = self.ok
        if self.rcode == 200:
            # Media types are case-insensitive and may carry parameters
            # (RFC 9110, section 8.3.1); ctype is None without the header.
            media_type = (self.ctype or "").split(";")[0].strip().lower()
            if media_type != "application/dns-message":
                self.response = "Content type of the response (\"%s\") invalid" % self.ctype
                ok = False
            elif not self.head:
                try:
                    self.response = dns.message.from_wire(self.body)
                except dns.message.TrailingJunk: # Not DNS. Should
                    # not happen for a content type
                    # application/dns-message but who knows?
                    self.response = "ERROR Not proper DNS data, trailing junk \"%s\"" % self.body
                    ok = False
                except dns.name.BadLabelType: # Not DNS.
                    self.response = "ERROR Not proper DNS data (wrong path in the URL?) \"%s\"" % self.body[:100]
                    ok = False
                except dns.exception.DNSException as e: # Other parsing errors (short header...)
                    self.response = "ERROR Not proper DNS data (%s) \"%s\"" % (e, self.body[:100])
                    ok = False
            elif self.response_size == 0:
                self.response = "HEAD successful"
            else:
                self.response = "ERROR Body length is not null \"%s\"" % self.body[:100]
                ok = False
        else:
            ok = False
            if self.response_size == 0:
                self.response = "[No details]"
            else:
                self.response = self.body
        self.ok = ok
        return ok
269

270

271
class Connection:
    """Common part of DoT and DoH connections.

    *server* is a host name or IP address for DoT, an https: URL for DoH.
    *servername*, when given, is the name the certificate must match instead
    of *server*. *connect* is not used here (see ConnectionDoH).
    """

    def __init__(self, server, servername=None, connect=None, forceIPv4=False, forceIPv6=False,
                 dot=False, verbose=verbose, insecure=insecure, post=post, head=head):
        if dot and not is_valid_hostname(server):
            error("DoT requires a host name or IP address, not \"%s\"" % server)
        # Check the server we were given, not the module-level "url"
        if not dot and not is_valid_url(server):
            error("DoH requires a valid HTTPS URL, not \"%s\"" % server)
        if forceIPv4 and forceIPv6:
            raise CustomException("Force IPv4 *or* IPv6 but not both")
        self.server = server
        self.servername = servername
        # Name that the certificate must match
        self.check = servername if servername is not None else server
        self.dot = dot
        self.verbose = verbose
        self.insecure = insecure
        # Kept so that check_ip_address() honours the constructor arguments
        # rather than the module globals
        self.forceIPv4 = forceIPv4
        self.forceIPv6 = forceIPv6

    def __str__(self):
        return self.server

    def check_ip_address(self, addr):
        """Set self.family (and, for an address literal, self.repraddress).

        self.family ends as socket.AF_INET, socket.AF_INET6 or 0 (any).
        self.repraddress is *addr*, bracketed when it is IPv6. Raises
        CustomException if, for DoH, *addr* is not an IP address, or if it
        contradicts a forced address family.
        """
        (is_addr, version) = is_valid_ip_address(addr)
        if not is_addr and not self.dot:
            raise CustomException("%s is not IPv4 and not IPv6" % addr)
        if self.forceIPv4 and version == 6:
            raise CustomException("You cannot force IPv4 with a literal IPv6 address (%s)" % addr)
        elif self.forceIPv6 and version == 4:
            raise CustomException("You cannot force IPv6 with a literal IPv4 address (%s)" % addr)
        if self.forceIPv4 or version == 4:
            self.family = socket.AF_INET
            self.repraddress = addr
        elif self.forceIPv6 or version == 6:
            self.family = socket.AF_INET6
            self.repraddress = f'[{addr}]'
        else:
            self.family = 0

    def do_test(self, qname, qtype=rtype):
        # Routine doing one actual test. Returns a tuple, first member is a
        # result (boolean indicating success for DoT, HTTP status code for
        # DoH), second member is a DNS message (or a string if there is an
        # error), third member is the size of the DNS message (or None if no
        # proper response).
        pass

318
319
320
321
322
323
324

class ConnectionDoT(Connection):
    """A DNS-over-TLS connection to port 853 of *server*.

    The TCP connection and the TLS handshake are done by the constructor.
    Unless *insecure*, the certificate must be valid for the checked name.
    """

    def __init__(self, server, servername=None, connect=None, forceIPv4=False, forceIPv6=False,
                 dot=False, verbose=verbose, insecure=insecure, post=post, head=head):
        Connection.__init__(self, server, servername=servername, connect=connect,
                forceIPv4=forceIPv4, forceIPv6=forceIPv6, dot=dot,
                verbose=verbose, insecure=insecure, post=post, head=head)
        self.check_ip_address(self.server)
        self.hasher = hashlib.sha256()
        addrinfo = socket.getaddrinfo(server, 853, self.family)
        # May be loop over the results of getaddrinfo, to test all
        # the IP addresses? See #13.
        self.sock = socket.socket(addrinfo[0][0], socket.SOCK_STREAM)
        self.addr = addrinfo[0][4]
        if self.verbose:
            print("Connecting to %s ..." % str(self.addr))
        # With typical DoT servers, we *must* use TLS 1.2 (otherwise,
        # do_handshake fails with "OpenSSL.SSL.SysCallError: (-1, 'Unexpected
        # EOF')" Typical HTTP servers are more lax.
        self.context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_2_METHOD)
        if self.insecure:
            self.context.set_verify(OpenSSL.SSL.VERIFY_NONE, lambda *x: True)
        else:
            self.context.set_default_verify_paths()
            self.context.set_verify_depth(4) # Seems ignored
            self.context.set_verify(OpenSSL.SSL.VERIFY_PEER | OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT | \
                                    OpenSSL.SSL.VERIFY_CLIENT_ONCE,
                                    lambda conn, cert, errno, depth, preverify_ok: preverify_ok)
        self.session = OpenSSL.SSL.Connection(self.context, self.sock)
        self.session.set_tlsext_host_name(canonicalize(self.check).encode()) # Server Name Indication (SNI)
        self.session.connect(self.addr)
        # TODO We may here have exceptions such as OpenSSL.SSL.ZeroReturnError
        self.session.do_handshake()
        self.cert = self.session.get_peer_certificate()
        if self.cert is None:
            raise CustomException("No certificate received from %s" % self.server)
        # RFC 7858, section 4.2 and appendix A
        self.publickey = self.cert.get_pubkey()
        # Use the constructor arguments, not the module globals
        if self.verbose:
            print("Certificate #%x for \"%s\", delivered by \"%s\"" % \
                  (self.cert.get_serial_number(),
                   self.cert.get_subject().commonName,
                   self.cert.get_issuer().commonName))
            self.hasher.update(OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1,
                                                  self.publickey))
            self.digest = self.hasher.digest()
            print("Public key is pin-sha256=\"%s\"" % \
                  base64.standard_b64encode(self.digest).decode())
        if not self.insecure:
            valid = validate_hostname(self.check, self.cert)
            if not valid:
                error("Certificate error: \"%s\" is not in the certificate" % (self.check))

    def end(self):
        """Close the TLS session."""
        self.session.shutdown()
        self.session.close()

    def send_data(self, data):
        """Send *data* preceded by its two-byte, big-endian length."""
        length = len(data)
        # sendall(): send() may write only part of the buffer
        self.session.sendall(length.to_bytes(2, byteorder='big') + data)

    def _recv_exactly(self, count):
        """Read exactly *count* bytes from the TLS session.

        recv() may return less than asked (a message can span several TLS
        records), so loop until everything has arrived. Raises
        CustomException if the server closes the connection before.
        """
        chunks = []
        remaining = count
        while remaining > 0:
            chunk = self.session.recv(remaining)
            if not chunk:
                raise CustomException("Connection closed by the server (%i bytes missing)" % remaining)
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def receive_data(self, request):
        """Read one length-prefixed DNS message and store it in *request*."""
        request.response_size = int.from_bytes(self._recv_exactly(2), byteorder='big')
        buf = self._recv_exactly(request.response_size)
        request.response = dns.message.from_wire(buf)
        request.rcode = True

    def send_and_receive(self, request):
        self.send_data(request.data)
        self.receive_data(request)

    def do_test(self, qname, qtype=rtype):
        """Send one query for *qname*/*qtype* and return the checked request."""
        request = RequestDoT(qname, qtype, want_dnssec=dnssec, use_edns=edns)
        request.to_wire()
        self.send_and_receive(request)
        request.check_response()
        return request
Alexandre's avatar
Alexandre committed
394

395
396
397
398
399
400
401
402
403
404

class ConnectionDoH(Connection):
    """A DNS-over-HTTPS connection to the URL *server*, through pycurl.

    The curl handle is created on the first request and reused afterwards.
    When *connect* is given, curl connects to that IP address whatever the
    host in the URL.
    """

    def __init__(self, server, servername=None, connect=None, forceIPv4=False, forceIPv6=False,
                 dot=False, verbose=verbose, insecure=insecure, post=post, head=head):
        Connection.__init__(self, server, servername=servername, connect=connect,
                forceIPv4=forceIPv4, forceIPv6=forceIPv6, dot=dot,
                verbose=verbose, insecure=insecure, post=post, head=head)
        self.post = post
        self.head = head
        self.url = server
        self.connect = connect
        # Kept so that create_handle() uses the constructor arguments rather
        # than the module globals
        self.forceIPv4 = forceIPv4
        self.forceIPv6 = forceIPv6
        self.curl = None  # Created by create_handle()

    def create_handle(self):
        """Create and configure the pycurl handle used by the requests."""
        self.curl = pycurl.Curl()
        # Does not work if pycurl was not compiled with nghttp2 (recent Debian
        # packages are OK) https://github.com/pycurl/pycurl/issues/477
        self.curl.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
        if self.verbose:
            self.curl.setopt(pycurl.VERBOSE, True)
        if self.insecure:
            self.curl.setopt(pycurl.SSL_VERIFYPEER, False)
            self.curl.setopt(pycurl.SSL_VERIFYHOST, False)
        if self.forceIPv4:
            self.curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
        if self.forceIPv6:
            self.curl.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V6)
        if self.connect is not None:
            self.check_ip_address(self.connect)
            self.curl.setopt(pycurl.CONNECT_TO, [f'::{self.repraddress}:443',])
        self.curl.setopt(pycurl.HTTPHEADER, ["Accept: application/dns-message", "Content-type: application/dns-message"])

    def end(self):
        """Release the curl handle, if one was ever created."""
        if getattr(self, 'curl', None) is not None:
            self.curl.close()
            self.curl = None

    def set_opt(self, opt, value):
        self.curl.setopt(opt, value)

    def reset_opt_default(self):
        """Undo the per-request options set by a previous request."""
        opts = {
                pycurl.NOBODY: False,
                pycurl.POST: False,
                pycurl.POSTFIELDS: '',
                pycurl.URL: ''
               }
        for opt, value in opts.items():
            self.set_opt(opt, value)

    def prepare(self, request):
        """Configure the handle for *request*: POST, HEAD or GET."""
        if getattr(self, 'curl', None) is None:
            self.create_handle()
        else:
            self.reset_opt_default()
        if self.post or request.post:
            self.prepare_post(request)
        elif self.head or request.head:
            self.prepare_head(request)
            request.head = True
        else:
            self.prepare_get(request)

    def prepare_get(self, request):
        """GET, with the query base64url-encoded (no padding) in the "dns"
        parameter of the URL."""
        self.set_opt(pycurl.HTTPGET, True)
        dns_req = base64.urlsafe_b64encode(request.data).decode('UTF8').rstrip('=')
        # Extend an existing query string instead of starting a second one
        separator = "&" if urllib.parse.urlsplit(self.server).query else "?"
        self.set_opt(pycurl.URL, self.server + ("%sdns=%s" % (separator, dns_req)))

    def prepare_post(self, request):
        """POST, with the query in the body."""
        request.post = True
        self.set_opt(pycurl.POST, True)
        self.set_opt(pycurl.POSTFIELDS, request.data)
        self.set_opt(pycurl.URL, self.server)

    def prepare_head(self, request):
        """HEAD: a GET request without response body."""
        request.head = True
        self.prepare_get(request)
        self.set_opt(pycurl.NOBODY, True)

    def perform(self):
        """Run the request, collecting the body in self.buffer."""
        self.buffer = io.BytesIO()
        self.set_opt(pycurl.WRITEDATA, self.buffer)
        try:
            self.curl.perform()
        except pycurl.error as e:
            error(e.args[1])

    def receive(self, request):
        """Store the body, its size, HTTP status and content type in *request*."""
        body = self.buffer.getvalue()
        body_size = len(body)
        http_code = self.curl.getinfo(pycurl.RESPONSE_CODE)
        content_type = self.curl.getinfo(pycurl.CONTENT_TYPE)
        request.body = body
        request.response_size = body_size
        request.rcode = http_code
        request.ctype = content_type
        self.buffer.close()

    def send_and_receive(self, request):
        self.prepare(request)
        self.perform()
        self.receive(request)

    def do_test(self, qname, qtype=rtype):
        """Send one query for *qname*/*qtype* and return the checked request."""
        request = RequestDoH(qname, qtype, want_dnssec=dnssec, use_edns=edns)
        request.to_wire()
        self.send_and_receive(request)
        request.check_response()
        return request
Alexandre's avatar
Alexandre committed
501

Alexandre's avatar
Alexandre committed
502
503
504
505
506
507
508
509
510
511
512
513
514

def get_next_domain(input_file):
    """Read the next query from *input_file* (the --file argument).

    Each line holds a domain name, optionally followed by a DNS type
    (default AAAA), separated by whitespace. Returns (name, rtype). Calls
    error() when the file is exhausted, on an empty line, or on a line with
    more than two fields.
    """
    # split() copes with a last line lacking "\n": the former line[:-1]
    # chopped off the last character of the name.
    fields = input_file.readline().split()
    if not fields:
        error("Not enough data in %s for the %i tests" % (ifile, tests))
    if len(fields) > 2:
        error("Invalid line in %s: \"%s\"" % (ifile, " ".join(fields)))
    name = fields[0]
    rtype = fields[1] if len(fields) == 2 else 'AAAA'
    return name, rtype

515
def print_result(connection, request, prefix=None):
    """Report the outcome of *request*, sent over *connection*.

    Returns True on success, False otherwise. In monitoring-plugin mode it
    does not return: it prints one status line and exits with STATE_OK or
    STATE_CRITICAL. *prefix*, if given, labels the error messages (--check
    uses it for the test name).
    """
    ok = request.ok
    dot = connection.dot
    server = connection.server
    rcode = request.rcode
    msg = request.response
    size = request.response_size
    # The transport worked: a DoT reply was read, or HTTP returned 200
    transport_ok = (dot and rcode) or (not dot and rcode == 200)
    if transport_ok and ok:
        if not monitoring:
            if not check or verbose:
                print(msg)
        else:
            if size is not None and size > 0:
                print("%s OK - %s" % (server, "No error for %s/%s, %i bytes received" % (name, rtype, size)))
            else:
                print("%s OK - %s" % (server, "No error"))
            sys.exit(STATE_OK)
    elif transport_ok:
        # A reply arrived but check_response() rejected it (bad content
        # type, not DNS, body in a HEAD reply...). The monitoring plugin
        # used to report this as OK.
        if not monitoring:
            if prefix:
                print(prefix, end=': ', file=sys.stderr)
            print("Error: %s" % msg, file=sys.stderr)
        else:
            print("%s Error - %s" % (server, msg))
            sys.exit(STATE_CRITICAL)
        ok = False
    else:
        if not monitoring:
            if prefix:
                print(prefix, end=': ', file=sys.stderr)
            if dot:
                print("Error: %s" % msg, file=sys.stderr)
            else:
                try:
                    msg = msg.decode()
                except (UnicodeDecodeError, AttributeError):
                    pass # Sometimes, msg can be binary, or Latin-1
                print("HTTP error %i: %s" % (rcode, msg), file=sys.stderr)
        else:
            if not dot:
                print("%s HTTP error - %i: %s" % (server, rcode, msg))
            else:
                print("%s Error - %i: %s" % (server, rcode, msg))
            sys.exit(STATE_CRITICAL)
        ok = False
    return ok


554
def run_check_default(connection):
    """Run the standard --check tests (see create_requests_list()) on
    *connection*. Return True if all passed; stop at the first failure."""
    req_args = {'qname': name, 'qtype': rtype, 'use_edns': edns, 'want_dnssec': dnssec}
    for request_pack in create_requests_list(dot=dot, **req_args):
        test_name, request = request_pack[0], request_pack[1]
        if verbose:
            print(test_name)
        if not dot:
            method = request_pack[2]
            if method == DOH_POST:
                request.post = True
            elif method == DOH_HEAD:
                request.head = True
        try:
            connection.send_and_receive(request)
            request.check_response()
        except (OpenSSL.SSL.Error, CustomException) as e:
            # Same exceptions as the main loop; error() exits the program
            error(e)
            return False
        if not print_result(connection, request, prefix=test_name):
            return False
    return True
581

582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
def run_check_mime(connection, accept="application/dns-message", content_type="application/dns-message"):
    """Send one DoH query with the given Accept and Content-type headers and
    report the result. Return True on success (always True for DoT).

    The default headers are restored afterwards.
    """
    if dot:
        return True
    default = "application/dns-message"
    default_header = [f"Accept: {default}", f"Content-type: {default}"]
    header = [f"Accept: {accept}", f"Content-type: {content_type}"]
    req_args = {'qname': name, 'qtype': rtype, 'use_edns': edns, 'want_dnssec': dnssec}
    request = create_request(**req_args)
    # The curl handle is created by the first request. Create it now,
    # otherwise connection.curl would not exist yet (or the creation would
    # later overwrite our headers with the default ones).
    if getattr(connection, 'curl', None) is None:
        connection.create_handle()
    connection.curl.setopt(pycurl.HTTPHEADER, header)
    try:
        try:
            connection.send_and_receive(request)
        except CustomException as e:
            error(e)
        request.check_response()
        ok = print_result(connection, request, prefix=f"Test Header {', '.join(header)}")
    finally:
        connection.curl.setopt(pycurl.HTTPHEADER, default_header)
    return ok


def run_check_additionals(connection):
    """Test the server with an unusual Accept header, then with an unusual
    Content-type. Return True only if both pass; stop at the first failure."""
    for unusual in ({'accept': "text/html"}, {'content_type': "text/html"}):
        if not run_check_mime(connection, **unusual):
            return False
    return True

def run_check(connection):
    """Run every --check test on *connection*: the default suite, then the
    header tests. Return True only if all of them passed."""
    for suite in (run_check_default, run_check_additionals):
        if not suite(connection):
            return False
    return True

618
# Main program
me = os.path.basename(sys.argv[0])
# Installed as check_doh or check_dot, Homer is a monitoring plugin, with
# its own options and exit statuses.
monitoring = (me == "check_doh" or me == "check_dot")
if not monitoring:
    name = None
    message = None
    try:
        optlist, args = getopt.getopt (sys.argv[1:], "hvPkeV:r:f:d:t46",
                                       ["help", "verbose", "dot", "head",
                                        "insecure", "POST", "vhost=",
                                        "dnssec", "noedns","repeat=", "file=", "delay=", "v4only", "v6only", "check"])
        for option, value in optlist:
            if option == "--help" or option == "-h":
                usage()
                sys.exit(0)
            elif option == "--dot" or option == "-t":
                dot = True
            elif option == "--verbose" or option == "-v":
                verbose = True
            elif option == "--HEAD" or option == "--head" or option == "-e":
                head = True
            elif option == "--POST" or option == "--post" or option == "-P":
                post = True
            elif option == "--vhost" or option == "-V":
                vhostname = value
            elif option == "--insecure" or option == "-k":
                insecure = True
            elif option == "--dnssec":
                dnssec = True
            elif option == "--noedns":
                edns = False
            elif option == "--repeat" or option == "-r":
                try:
                    tests = int(value)
                except ValueError:
                    error("--repeat needs an integer, not \"%s\"" % value)
                if tests <= 1:
                    error("--repeat needs a value > 1")
            elif option == "--delay" or option == "-d":
                try:
                    delay = float(value)
                except ValueError:
                    error("--delay needs a number, not \"%s\"" % value)
                if delay <= 0:
                    error("--delay needs a value > 0")
            elif option == "--file" or option == "-f":
                ifile = value
            # getopt returns long options with their "--" prefix: comparing
            # with "v4only"/"v6only" never matched.
            elif option == "-4" or option == "--v4only":
                forceIPv4 = True
            elif option == "-6" or option == "--v6only":
                forceIPv6 = True
            elif option == "--check":
                check = True
            else:
                error("Unknown option %s" % option)
    except getopt.error as reason:
        usage(reason)
        sys.exit(1)
    if tests <= 1 and delay is not None:
        error("--delay makes no sense if there is no repetition")
    if post and head:
        usage("POST or HEAD but not both")
        sys.exit(1)
    if dot and (post or head):
        usage("POST or HEAD makes no sense for DoT")
        sys.exit(1)
    if ifile is None and (len(args) != 2 and len(args) != 3):
        usage("Wrong number of arguments")
        sys.exit(1)
    if ifile is not None and len(args) != 1:
        usage("Wrong number of arguments (if --file is used, do not indicate the domain name)")
        sys.exit(1)
    url = args[0]
    if ifile is None:
        name = args[1]
        if len(args) == 3:
            rtype = args[2]
else: # Monitoring plugin
    dot = (me == "check_dot")
    name = None
    try:
        optlist, args = getopt.getopt (sys.argv[1:], "H:n:p:V:t:Pih46")
        for option, value in optlist:
            if option == "-H":
                host = value
            elif option == "-V":
                vhostname = value
            elif option == "-n":
                name = value
            elif option == "-t":
                rtype = value
            elif option == "-p":
                path = value
            elif option == "-P":
                post = True
            elif option == "-h":
                head = True
            elif option == "-i":
                insecure = True
            elif option == "-4":
                forceIPv4 = True
            elif option == "-6":
                forceIPv6 = True
            else:
                # Should never occur, it is trapped by getopt
                print("Unknown option %s" % option)
                sys.exit(STATE_UNKNOWN)
    except getopt.error as reason:
        print("Option parsing problem %s" % reason)
        sys.exit(STATE_UNKNOWN)
    if len(args) > 0:
        print("Too many arguments (\"%s\")" % args)
        sys.exit(STATE_UNKNOWN)
    if host is None or name is None:
        print("Host (-H) and name to lookup (-n) are necessary")
        sys.exit(STATE_UNKNOWN)
    if post and head:
        print("POST or HEAD but not both")
        sys.exit(STATE_UNKNOWN)
    if dot and (post or head):
        print("POST or HEAD makes no sense for DoT")
        sys.exit(STATE_UNKNOWN)
    if dot and path:
        print("URL path makes no sense for DoT")
        sys.exit(STATE_UNKNOWN)
    if dot:
        url = host
    else:
        if vhostname is None or vhostname == host:
            connectTo = None
            # An IPv6 address literal must be bracketed in a URL (RFC 3986)
            (host_is_addr, host_version) = is_valid_ip_address(host)
            url_host = "[%s]" % host if host_version == 6 else host
            url = "https://%s/" % url_host
        else:
            connectTo = host
            url = "https://%s/" % vhostname
        if path is not None:
            if path.startswith("/"):
                path = path[1:]
            url += path
Stephane Bortzmeyer's avatar
Stephane Bortzmeyer committed
750
ok = True
start = time.time()
try:
    # For DoT, a virtual host name replaces the server name for the
    # certificate check
    if dot and vhostname is not None:
        extracheck = vhostname
    else:
        extracheck = None
    if dot:
        conn = ConnectionDoT(url, dot=dot, servername=extracheck, connect=connectTo, verbose=verbose,
                          forceIPv4=forceIPv4, forceIPv6=forceIPv6,
                          insecure=insecure, post=post, head=head)
    else:
        conn = ConnectionDoH(url, dot=dot, servername=extracheck, connect=connectTo, verbose=verbose,
                          forceIPv4=forceIPv4, forceIPv6=forceIPv6,
                          insecure=insecure, post=post, head=head)
except TimeoutError:
    error("timeout")
except ConnectionRefusedError:
    error("Connection to server refused")
except socket.gaierror as e:
    # Before the generic OSError handler: gaierror is a subclass of it
    error("Cannot resolve \"%s\": %s" % (url, e))
except ValueError:
    error(f'"{url}" not a name or an IP address')
except CustomException as e:
    error(e)
except OpenSSL.SSL.Error as e:
    error("TLS error with \"%s\": %s" % (url, e))
except OSError as e:
    error("Cannot connect to \"%s\": %s" % (url, e))
773
774
# With --file, the names to query are read from this file
input_file = open(ifile) if ifile is not None else None
# Time at the end of the first test, for a mean ignoring it. It stays None
# in --check mode, which does not repeat tests (using it there used to
# raise NameError with --check --repeat).
start2 = None
try:
    if not check:
        for i in range(tests):
            if tests > 1:
                print("\nTest %i" % i)
            if input_file is not None:
                name, rtype = get_next_domain(input_file)
            try:
                request = conn.do_test(name, rtype)
            except (OpenSSL.SSL.Error, CustomException) as e:
                ok = False
                error(e)
                break
            if not print_result(conn, request):
                ok = False
            if tests > 1 and i == 0:
                start2 = time.time()
            if delay is not None:
                time.sleep(delay)
    else:
        ok = run_check(conn)
finally:
    if input_file is not None:
        input_file.close()
stop = time.time()
if tests > 1 and start2 is not None:
    extra = ", %.2f ms/request if we ignore the first one" % ((stop-start2)*1000/(tests-1))
else:
    extra = ""
if not monitoring and (not check or verbose):
    print("\nTotal elapsed time: %.2f seconds (%.2f ms/request %s)" % (stop-start, (stop-start)*1000/tests, extra))
conn.end()
if ok:
    print('OK')
    sys.exit(STATE_OK if monitoring else 0)
else:
    print('KO')
    sys.exit(STATE_CRITICAL if monitoring else 1)