connection.py 19.3 KB
Newer Older
Alexandre's avatar
Alexandre committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import base64
import hashlib
import io
import signal
import socket
import sys

try:
    # http://pycurl.io/docs/latest
    import pycurl

    # Octobre 2019: the Python GnuTLS bindings don't work with Python 3. So we use OpenSSL.
    # https://www.pyopenssl.org/
    # https://pyopenssl.readthedocs.io/
    import OpenSSL

    # http://www.dnspython.org/
    import dns.message
except ImportError as e:
    print("Error: missing module")
    print(e)
    sys.exit(1)

import homer.utils
import homer.exceptions

class Connection:
    """State shared by the DoT and DoH connection classes.

    The constructor validates how the server is designated (a host name or
    IP address for DoT, an HTTPS URL for DoH) and the IP-version flags. It
    then records the options the concrete subclasses read.
    """

    def __init__(self, server, servername=None, connect_to=None,
                 forceIPv4=False, forceIPv6=False, insecure=False,
                 verbose=False, debug=False, dot=False):
        """Validate and store the connection options.

        Raises:
            homer.ConnectionDOTException: DoT with an invalid host name.
            homer.ConnectionDOHException: DoH with an invalid URL.
            homer.ConnectionException: both IPv4 and IPv6 forced.
        """
        if dot:
            if not homer.is_valid_hostname(server):
                raise homer.ConnectionDOTException("DoT requires a host name or IP address, not \"%s\"" % server)
        elif not homer.is_valid_url(server):
            raise homer.ConnectionDOHException("DoH requires a valid HTTPS URL, not \"%s\"" % server)

        if forceIPv4 and forceIPv6:
            raise homer.ConnectionException("Force IPv4 *or* IPv6 but not both")

        self.dot = dot
        self.server = server
        self.servername = servername
        # Name checked against the server certificate: the explicit
        # servername when provided, the server itself otherwise.
        self.check_name_cert = server if servername is None else servername
        self.verbose = verbose
        self.debug = debug
        self.insecure = insecure
        self.forceIPv4 = forceIPv4
        self.forceIPv6 = forceIPv6
        self.connect_to = connect_to

    def __str__(self):
        """Return the server as given by the user."""
        return self.server

    def do_test(self, request):
        """Run one test; subclasses override this. Returns nothing."""
        pass


class ConnectionDOT(Connection):
    """DNS-over-TLS (RFC 7858) connection to one server.

    The TLS session is opened by the constructor. With *pipelining*,
    requests are queued with pipelining_add_request() and sent without
    waiting for each answer. Responses are matched back to their request
    through the DNS query ID.
    """

    def __init__(self, server, servername=None, connect_to=None,
                 forceIPv4=False, forceIPv6=False, insecure=False,
                 verbose=False, debug=False,
                 sni=True, key=None, pipelining=False):
        """Validate the options and immediately connect to *server*.

        Args:
            sni: send the TLS Server Name Indication extension.
            key: expected base64 SHA-256 pin of the server public key. When
                set, it is checked instead of the certificate host name.
            pipelining: send several queries before reading the answers.

        Raises:
            homer.ConnectionDOTException: invalid host name, or no address
                of the server accepted a session.
            homer.ConnectionException: both IPv4 and IPv6 forced.
        """
        super().__init__(server, servername=servername, connect_to=connect_to,
                forceIPv4=forceIPv4, forceIPv6=forceIPv6, insecure=insecure,
                verbose=verbose, debug=debug, dot=True)

        self.sni = sni
        self.key = key
        self.pipelining = pipelining
        if self.pipelining:
            # Currently, we load everything in memory since we want to keep
            # everything anyway. Maybe in the future, if we don't want to
            # keep individual results, we'll use an iterator to fill a
            # smaller table. all_requests is indexed by its rank in the
            # input file.
            self.all_requests = []
            # pending is indexed by the query ID, and its maximum size is
            # max_in_flight.
            self.pending = {}

        # establish the connection
        self.connect()

    def connect(self):
        """Establish the TLS session, trying every address of the server.

        When ``connect_to`` is set, it is the server's IP address and the
        only one tried. Otherwise the server name is resolved, and each
        distinct address returned by getaddrinfo is tried in turn (in the
        resolver's order) until a session is established.

        Raises:
            homer.ConnectionDOTException: if no address accepted a session.
        """
        self.success = False

        if self.connect_to is not None: # the server's IP address is known
            addr = self.connect_to
        else:
            addr = self.server # otherwise keep the server name

        family = homer.get_addrfamily(addr, forceIPv4=self.forceIPv4, forceIPv6=self.forceIPv6)
        addrinfo_list = socket.getaddrinfo(addr, homer.PORT_DOT, family)
        # getaddrinfo can return several entries per address (one per socket
        # type/protocol). Keep each (sockaddr, family) pair once, preserving
        # the resolver's preference order; a set would try them in arbitrary
        # order.
        candidates = dict.fromkeys((info[4], info[0]) for info in addrinfo_list)

        # Handler for the alarm armed by establish_session(); presumably it
        # raises homer.exceptions.TimeoutConnectionError.
        signal.signal(signal.SIGALRM, homer.exceptions.timeout_connection)

        for sockaddr, sock_family in candidates:
            if self.establish_session(sockaddr, sock_family):
                self.success = True
                break
            if self.verbose and self.connect_to is None:
                print("Could not connect to %s" % sockaddr[0])
                print("Trying another IP address")

        # we could not establish a connection
        if not self.success:
            # we tried all the resolved IPs
            if self.verbose and self.connect_to is None:
                print("No other IP address")

            error = "Could not connect to \"%s\"" % self.server
            if self.connect_to is not None:
                error += " on %s" % self.connect_to
            raise homer.ConnectionDOTException(error)

    def establish_session(self, addr, sock_family):
        """Return True if a TLS session is established.

        *addr* is a socket address from getaddrinfo and *sock_family* its
        address family. The SIGALRM timer (homer.TIMEOUT_CONN seconds)
        covers the connection, the handshake and the certificate checks. It
        is always cancelled before returning. A failed attempt closes its
        socket and returns False.
        """
        self.hasher = hashlib.sha256()
        self.sock = socket.socket(sock_family, socket.SOCK_STREAM)
        established = False

        # start the timer
        signal.alarm(homer.TIMEOUT_CONN)
        try:
            established = self._open_session(addr) and self._verify_peer()
        finally:
            # Stop the timer whatever happened: an alarm left pending after a
            # failure would fire later, in unrelated code. A failed attempt
            # must also release its socket.
            signal.alarm(0)
            if not established:
                self.sock.close()

        # and start a new timer when pipelining requests
        if established and self.pipelining:
            self.sock.settimeout(homer.TIMEOUT_READ)
        return established

    def _open_session(self, addr):
        """Connect self.sock to *addr* and run the TLS handshake; return True on success."""
        if self.verbose:
            print("Connecting to %s ..." % addr[0])

        # With typical DoT servers, we *must* use TLS 1.2 (otherwise,
        # do_handshake fails with "OpenSSL.SSL.SysCallError: (-1, 'Unexpected
        # EOF')" Typical HTTP servers are more lax.
        self.context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_2_METHOD)
        if self.insecure:
            self.context.set_verify(OpenSSL.SSL.VERIFY_NONE, lambda *x: True)
        else:
            self.context.set_default_verify_paths()
            self.context.set_verify_depth(4) # Seems ignored
            self.context.set_verify(OpenSSL.SSL.VERIFY_PEER | OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT |
                                    OpenSSL.SSL.VERIFY_CLIENT_ONCE,
                                    lambda conn, cert, errno, depth, preverify_ok: preverify_ok)
        self.session = OpenSSL.SSL.Connection(self.context, self.sock)
        if self.sni:
            self.session.set_tlsext_host_name(homer.canonicalize(self.check_name_cert).encode())

        try:
            self.session.connect(addr)
            self.session.do_handshake()
        except homer.exceptions.TimeoutConnectionError:
            if self.verbose:
                print("Timeout")
            return False
        except OSError:
            if self.verbose:
                print("Cannot connect")
            return False
        except OpenSSL.SSL.SysCallError as e:
            if self.verbose:
                print("OpenSSL error: %s" % e.args[1], file=sys.stderr)
            return False
        except OpenSSL.SSL.ZeroReturnError:
            # see #18
            if self.verbose:
                print("Error: The SSL connection has been closed (try with --nosni to avoid sending SNI ?)", file=sys.stderr)
            return False
        except OpenSSL.SSL.Error as e:
            if self.verbose:
                print("OpenSSL error: %s" % ', '.join(err[0][2] for err in e.args), file=sys.stderr)
            return False
        return True

    def _verify_peer(self):
        """Check the server certificate name, or the pinned key; return True if acceptable."""
        # RFC 7858, section 4.2 and appendix A
        self.cert = self.session.get_peer_certificate()
        self.publickey = self.cert.get_pubkey()
        key_string = None
        if self.debug or self.key is not None:
            self.hasher.update(OpenSSL.crypto.dump_publickey(OpenSSL.crypto.FILETYPE_ASN1,
                                                  self.publickey))
            self.digest = self.hasher.digest()
            key_string = base64.standard_b64encode(self.digest).decode()
        if self.debug:
            print("Certificate #%x for \"%s\", delivered by \"%s\"" % \
                  (self.cert.get_serial_number(),
                   self.cert.get_subject().commonName,
                   self.cert.get_issuer().commonName))
            print("Public key is pin-sha256=\"%s\"" % key_string)
        if self.insecure:
            return True
        if self.key is None:
            if not homer.validate_hostname(self.check_name_cert, self.cert):
                # Reported even when not verbose: this is a security failure.
                print("Certificate error: \"%s\" is not in the certificate" % (self.check_name_cert),
                      file=sys.stderr)
                return False
        elif key_string != self.key:
            print("Key error: expected \"%s\", got \"%s\"" % (self.key, key_string), file=sys.stderr)
            return False
        return True

    def end(self):
        """Shut down the TLS session and close the connection."""
        self.session.shutdown()
        self.session.close()

    def send_data(self, data, dump=False):
        """Send one DNS message with its two-byte length prefix (RFC 7858, section 3.3)."""
        if dump:
            homer.dump_data(data, 'data sent')
        length = len(data)
        # sendall: pyOpenSSL's send() may write only part of the buffer
        self.session.sendall(length.to_bytes(2, byteorder='big') + data)

    def _recv_exactly(self, size):
        """Read up to *size* bytes, looping because one recv() may return less."""
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self.session.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def receive_data(self, dump=False):
        """Read one length-prefixed DNS message.

        Returns:
            (True, data, size) when a message was read, or (False, None, None)
            if OpenSSL reports that no data is available yet (WantReadError).
            read_result() treats the latter as a timeout.
        """
        try:
            header = self.session.recv(2)
        except OpenSSL.SSL.WantReadError:
            return (False, None, None)
        if len(header) == 1:
            # the two length bytes may arrive separately
            header += self._recv_exactly(1)
        size = int.from_bytes(header, byteorder='big')
        # the message itself may span several TLS records
        data = self._recv_exactly(size)
        if dump:
            homer.dump_data(data, 'data recv')
        return (True, data, size)

    def send_and_receive(self, request, dump=False):
        """Send *request* and store the response read back into it."""
        self.send_data(request.data, dump=dump)
        rcode, data, size = self.receive_data(dump=dump)
        request.store_response(rcode, data, size)

    # this function might need to be move outside
    def do_test(self, request, synchronous=True):
        """Send *request*; when *synchronous*, also read, store and check the answer."""
        self.send_data(request.data)
        if synchronous:
            rcode, data, size = self.receive_data()
            request.store_response(rcode, data, size)
            request.check_response(self.debug)

    # should the pipelining methods be part of ConnectionDOT ?
    def pipelining_add_request(self, request):
        """Queue *request* for pipelining, with no answer yet."""
        self.all_requests.append({'request': request, 'response': None}) # No answer yet

    def pipelining_fill_pending(self, index):
        """Send the queued request at rank *index* (if any) and mark it pending."""
        if index < len(self.all_requests):
            request = self.all_requests[index]['request']
            query_id = request.message.id
            # TODO check there is no duplicate in IDs
            self.pending[query_id] = (False, index, request)
            self.do_test(request, synchronous = False)

    def pipelining_init_pending(self, max_in_flight):
        """Send the first *max_in_flight* queued requests.

        NOTE(review): returns the loop index, which is the number of queued
        requests when fewer than max_in_flight exist, but max_in_flight - 1
        otherwise. A max_in_flight <= 0 raises UnboundLocalError. Kept
        unchanged because callers elsewhere may rely on it.
        """
        for i in range(0, max_in_flight):
            if i == len(self.all_requests):
                break
            self.pipelining_fill_pending(i)
        return i

    # this method might need to be moved somewhere else in order to avoid
    # calling dns.message.from_wire()
    def read_result(self, connection, requests, display_results=True):
        """Read one pipelined response and record it against its request.

        *requests* maps query IDs to (over, rank, request) tuples. The
        *connection* argument is unused.

        Returns:
            The query ID of the response, or None when nothing could be read.

        Raises:
            homer.PipeliningException: if the response ID is not in *requests*.
        """
        # TODO can raise OpenSSL.SSL.ZeroReturnError if the connection was closed
        rcode, data, size = self.receive_data()
        if not rcode:
            if display_results:
                print("TIMEOUT")
            return None
        # TODO remove call to dns.message (use abstraction instead)
        response = dns.message.from_wire(data)
        query_id = response.id
        if query_id not in requests:
            raise homer.PipeliningException("Received response for ID %s which is unexpected" % query_id)
        _, rank, request = requests[query_id]
        self.all_requests[rank]['response'] = (rcode, response, size)
        requests[query_id] = (True, rank, request)
        if display_results:
            print()
            print(response)
        # TODO a timeout if some responses are lost?
        return query_id

def create_handle(connection):
    """Build a pycurl easy handle configured for DoH queries on *connection*.

    The handle gets two extra attributes used by ConnectionDOH:
    ``prepare(handle, connection, request)`` loads a request into it, and
    ``reset_opt_default(handle)`` clears per-request options before reuse.
    """
    def reset_opt_default(handle):
        """Reset the options that prepare() sets to neutral values."""
        opts = {
                pycurl.NOBODY: False,
                pycurl.POST: False,
                pycurl.POSTFIELDS: '',
                pycurl.URL: ''
               }
        for opt, value in opts.items():
            handle.setopt(opt, value)

    def prepare(handle, connection, request):
        """Configure *handle* to send *request* (POST, GET or HEAD) into a fresh buffer."""
        # Without multistreams a single easy handle is reused, so the
        # previous request's options must be cleared first.
        if not connection.multistreams:
            handle.reset_opt_default(handle)
        if request.post:
            handle.setopt(pycurl.POST, True)
            handle.setopt(pycurl.POSTFIELDS, request.data)
            handle.setopt(pycurl.URL, connection.server)
        else:
            handle.setopt(pycurl.HTTPGET, True) # automatically sets CURLOPT_NOBODY to 0
            if request.head:
                handle.setopt(pycurl.NOBODY, True)
            # RFC 8484, section 4.1: base64url without padding
            dns_req = base64.urlsafe_b64encode(request.data).decode('UTF8').rstrip('=')
            # The URL may already carry a query string: append the dns
            # parameter to it instead of starting a second one.
            separator = '&' if '?' in connection.server else '?'
            handle.setopt(pycurl.URL, connection.server + ("%sdns=%s" % (separator, dns_req)))
        handle.buffer = io.BytesIO()
        handle.setopt(pycurl.WRITEDATA, handle.buffer)
        handle.request = request

    handle = pycurl.Curl()
    # Does not work if pycurl was not compiled with nghttp2 (recent Debian
    # packages are OK) https://github.com/pycurl/pycurl/issues/477
    handle.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
    if connection.debug:
        handle.setopt(pycurl.VERBOSE, True)
    if connection.insecure:
        handle.setopt(pycurl.SSL_VERIFYPEER, False)
        handle.setopt(pycurl.SSL_VERIFYHOST, False)
    if connection.forceIPv4:
        handle.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V4)
    if connection.forceIPv6:
        handle.setopt(pycurl.IPRESOLVE, pycurl.IPRESOLVE_V6)
    if connection.connect_to is not None:
        # "HOST:PORT:CONNECT-TO-HOST:CONNECT-TO-PORT" with empty HOST and
        # PORT: every request goes to connect_to on the DoH port.
        handle.setopt(pycurl.CONNECT_TO, ["::[%s]:%d" % (connection.connect_to, homer.PORT_DOH),])
    handle.setopt(pycurl.HTTPHEADER,
            ["Accept: application/dns-message", "Content-type: application/dns-message"])
    handle.reset_opt_default = reset_opt_default
    handle.prepare = prepare
    return handle


class ConnectionDOH(Connection):

    def __init__(self, server, servername=None, connect_to=None,
                 forceIPv4=False, forceIPv6=False,
                 insecure=False, verbose=False, debug=False,
                 multistreams=False):
        """Prepare a DoH connection to the HTTPS URL *server*.

        With *multistreams*, requests are run concurrently through a pycurl
        multi handle. Otherwise a single reusable easy handle is created
        right away.
        """
        super().__init__(server, servername=servername, connect_to=connect_to,
                         forceIPv4=forceIPv4, forceIPv6=forceIPv6, insecure=insecure,
                         verbose=verbose, debug=debug, dot=False)

        self.url = server
        self.multistreams = multistreams

        # Temporary tweak: make sure the family of connect_to agrees with
        # the IP version the user may have forced.
        if self.connect_to:
            homer.check_ip_address(self.connect_to, forceIPv4=self.forceIPv4, forceIPv6=self.forceIPv6)

        if not self.multistreams:
            self.curl_handle = create_handle(self)
            return

        self.multi = self.create_multi()
        self.all_handles = []
        self.finished = {'http': {}}

    def create_multi(self):
        """Return a new pycurl multi handle allowing one connection per host."""
        multi_handle = pycurl.CurlMulti()
        # A single connection per host, presumably so that all the requests
        # become streams of the same (HTTP/2) connection.
        multi_handle.setopt(pycurl.M_MAX_HOST_CONNECTIONS, 1)
        return multi_handle

    def init_multi(self):
        """Open the multistreams connection with one lone warm-up query.

        Establishing the connection first avoids starting the transfer of all
        the other queries simultaneously. The root NS is queried because this
        should not impact the resolver cache. The warm-up handle and its
        counters are then discarded.
        """
        if self.verbose:
            print("Establishing multistreams connection...")
        warmup = homer.create_request('.', qtype='NS', dot=False)
        self.do_test(warmup, synchronous=False)
        self.perform_multi(silent=True, display_results=False, show_time=False)
        # forget everything about the warm-up query
        self.all_handles = []
        self.finished = {'http': {}}

    def end(self):
        """Release the curl resources: the multi handle and its easy handles, or the single handle."""
        if self.multistreams:
            self.remove_handles()
            self.multi.close()
        else:
            self.curl_handle.close()

    def remove_handles(self):
        """Close and detach the easy handles that the multi handle reports as done.

        pycurl's info_read() lists successful transfers as bare handles but
        failed ones as (handle, errno, errmsg) tuples. The handle must be
        extracted from the latter before it can be closed.
        """
        _, handle_success, handle_fail = self.multi.info_read()
        handles = handle_success + [failed[0] for failed in handle_fail]
        for h in handles:
            # NOTE(review): libcurl documents removing a handle before
            # cleaning it up; the original close-then-remove order is kept.
            h.close()
            self.multi.remove_handle(h)

    def perform_multi(self, silent=False, display_results=True, show_time=False):
        """Drive all the transfers added to self.multi until none is running.

        Each successfully finished handle is passed to read_result_handle()
        with the given *silent*, *display_results* and *show_time* flags.

        NOTE(review): failed transfers (third element of info_read()) are
        ignored here: neither reported nor closed/removed. Confirm whether
        that is intended.
        """
        # Kick off the transfers; libcurl asks to be called again while it
        # returns E_CALL_MULTI_PERFORM.
        while 1:
            ret, num_handles = self.multi.perform()
            if ret != pycurl.E_CALL_MULTI_PERFORM:
                break
        # num_handles is the number of transfers still running.
        while num_handles:
            # wait (up to 1 s) for activity on the transfers' sockets
            ret = self.multi.select(1.0)
            if ret == -1:
                continue
            while 1:
                ret, num_handles = self.multi.perform()
                # collect the transfers completed so far
                n, handle_pass, handle_fail = self.multi.info_read()
                for handle in handle_pass:
                    self.read_result_handle(handle, silent=silent, display_results=display_results, show_time=show_time)
                if ret != pycurl.E_CALL_MULTI_PERFORM:
                    break
        # transfers that completed during the last perform() call
        n, handle_pass, handle_fail = self.multi.info_read()
        for handle in handle_pass:
            self.read_result_handle(handle, silent=silent, display_results=display_results, show_time=show_time)

    def send(self, handle):
        """Run the transfer configured on *handle*, writing the body into a fresh buffer.

        Raises:
            homer.DOHException: if pycurl reports an error; the message is
                pycurl's error text and the original error is chained.
        """
        handle.buffer = io.BytesIO()
        handle.setopt(pycurl.WRITEDATA, handle.buffer)
        try:
            handle.perform()
        except pycurl.error as e:
            # keep the pycurl error (code and message) as the cause
            raise homer.DOHException(e.args[1]) from e

    def receive(self, handle):
        """Copy the outcome of *handle*'s transfer into its request.

        Sets request.response (body), response_size, rcode (HTTP status) and
        ctype (Content-Type, None when absent). Records the total and
        pre-transfer times on the handle, then closes the body buffer.
        """
        req = handle.request
        payload = handle.buffer.getvalue()
        status = handle.getinfo(pycurl.RESPONSE_CODE)
        handle.time = handle.getinfo(pycurl.TOTAL_TIME)
        handle.pretime = handle.getinfo(pycurl.PRETRANSFER_TIME)
        try:
            ctype = handle.getinfo(pycurl.CONTENT_TYPE)
        except TypeError:
            # raised when the response has no Content-Type: (for instance
            # in response to HEAD requests)
            ctype = None
        req.response = payload
        req.response_size = len(payload)
        req.rcode = status
        req.ctype = ctype
        handle.buffer.close()

    def send_and_receive(self, handle, dump=False):
        """Perform the transfer on *handle* and store its result in handle.request.

        *dump* is accepted for interface parity with ConnectionDOT and unused.
        """
        self.send(handle)
        self.receive(handle)

    def read_result_handle(self, handle, silent=False, display_results=True, show_time=False):
        """Collect, check and optionally print the answer of a finished handle.

        The HTTP status is counted in self.finished['http']. The handle is
        then closed and detached from the multi handle.
        """
        self.receive(handle)
        req = handle.request
        req.check_response()
        if show_time and not silent:
            self.print_time(handle)
        # one counter per HTTP status code
        per_status = self.finished['http']
        per_status[req.rcode] = per_status.get(req.rcode, 0) + 1
        if display_results and not silent:
            elapsed_ms = (handle.time - handle.pretime) * 1000
            print("Return code %s (%.2f ms):" % (req.rcode, elapsed_ms))
            print(f"{req.response}\n")
        handle.close()
        self.multi.remove_handle(handle)

    def read_results(self, display_results=True, show_time=False):
        """Call read_result_handle() on every handle created by do_test()."""
        for curl in self.all_handles:
            self.read_result_handle(curl, display_results=display_results,
                                    show_time=show_time)

    def print_time(self, handle):
        """Print one timing line for *handle*.

        The line shows the request index, the HTTP status, the pre-transfer
        time, the total time and their difference, all in milliseconds.
        """
        req = handle.request
        pre_ms = handle.pretime * 1000
        total_ms = handle.time * 1000
        transfer_ms = (handle.time - handle.pretime) * 1000
        print(f'{req.i:3d}   ({req.rcode})   '
              f'{pre_ms:8.3f} ms  {total_ms:8.3f} ms  {transfer_ms:8.3f} ms')

    def do_test(self, request, synchronous=True):
        """Send *request* over DoH.

        When *synchronous*, the shared easy handle is used: the transfer is
        performed immediately and the response checked. Otherwise a new easy
        handle is created, remembered in self.all_handles and added to the
        multi handle; it runs on the next perform_multi().
        """
        if synchronous:
            handle = self.curl_handle
        else:
            handle = create_handle(self)
            self.all_handles.append(handle)
        handle.prepare(handle, self, request)
        if synchronous:
            self.send_and_receive(handle)
            request.check_response(self.debug)
        else:
            self.multi.add_handle(handle)