mirror of
https://github.com/apache/impala.git
synced 2025-12-23 11:55:25 -05:00
IMPALA_THRIFT_PY_VERSION is also bumped to 0.16.0p3. As 0.16.0p3 Thrift does not contain Python related patches and Impyla 0.18.0 depends on Thrift 0.16.0, now we are consistently using Thrift 0.16.0 in all Python code. This also bumps the Thrift in the shell's ext-py directory to 0.16.0 (based on the Thrift 0.16.0 pypi tarball with the egg directory removed). Testing: - Ran a GVO job Change-Id: I7265558b0e07959c606cba73cd251c3edfcb3ed5 Reviewed-on: http://gerrit.cloudera.org:8080/18456 Reviewed-by: Michael Smith <michael.smith@cloudera.com> Reviewed-by: Wenzhe Zhou <wzhou@cloudera.com> Tested-by: Impala Public Jenkins <impala-public-jenkins@cloudera.com> Reviewed-by: Joe McDonnell <joemcdonnell@cloudera.com>
354 lines
15 KiB
Python
354 lines
15 KiB
Python
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
#
|
|
|
|
import inspect
|
|
import logging
|
|
import os
|
|
import platform
|
|
import ssl
|
|
import sys
|
|
import tempfile
|
|
import threading
|
|
import unittest
|
|
import warnings
|
|
from contextlib import contextmanager
|
|
|
|
import _import_local_thrift # noqa
|
|
|
|
# Directory containing this test file, with symlinks resolved.
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
# Three directory levels above SCRIPT_DIR; the test key material lives below it.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))

# Every certificate/key file used by these tests lives in ROOT_DIR/test/keys.
_KEYS_DIR = os.path.join(ROOT_DIR, 'test', 'keys')

# Server-side material.
SERVER_PEM = os.path.join(_KEYS_DIR, 'server.pem')
SERVER_CERT = os.path.join(_KEYS_DIR, 'server.crt')
SERVER_KEY = os.path.join(_KEYS_DIR, 'server.key')

# Client certificate/key pair that, judging by the _NO_IP suffix, presumably
# lacks an IP address entry -- confirm against the key files.
CLIENT_CERT_NO_IP = os.path.join(_KEYS_DIR, 'client.crt')
CLIENT_KEY_NO_IP = os.path.join(_KEYS_DIR, 'client.key')

# Client certificate/key pair used for successful client-auth connections.
CLIENT_CERT = os.path.join(_KEYS_DIR, 'client_v3.crt')
CLIENT_KEY = os.path.join(_KEYS_DIR, 'client_v3.key')

# CA bundle the server trusts when verifying client certificates.
CLIENT_CA = os.path.join(_KEYS_DIR, 'CA.pem')

# Colon-separated OpenSSL cipher list used by the cipher/deprecation tests.
TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256'
|
|
|
|
|
|
class ServerAcceptor(threading.Thread):
    """Background thread that serves exactly one client for an SSL test.

    ``run`` makes *server* listen, publishes its TCP port (if any), accepts
    one client, reads 5 bytes from it and answers with ``b"there"``.  The test
    thread synchronizes with it through ``await_listening()``, ``port`` and
    ``client``.

    Args:
        server: server transport exposing ``listen()``, ``handle`` (with
            ``getsockname()``), ``accept()`` and ``close()``.
        expect_failure: when true, an error while accepting or talking to the
            client is logged but not re-raised, because the calling test
            expects the connection to fail.
    """

    def __init__(self, server, expect_failure=False):
        super(ServerAcceptor, self).__init__()
        self.daemon = True
        self._server = server
        self._listening = threading.Event()
        self._port = None
        self._port_bound = threading.Event()
        self._client = None
        self._client_accepted = threading.Event()
        self._expect_failure = expect_failure
        # Name the thread after the function two frames up the call stack so
        # server-side log messages identify which caller started it.
        frame = inspect.stack(3)[2]
        self.name = frame[3]
        # Drop the frame record promptly; holding frame references can
        # create reference cycles (see the inspect module docs).
        del frame

    def run(self):
        """Listen, publish the port, then serve a single hello/there exchange."""
        try:
            self._server.listen()
            self._listening.set()

            address = self._server.handle.getsockname()
            # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
            # 4-tuples (host, port, ...); in both the port is the second slot.
            # AF_UNIX addresses are path strings with no port, and indexing
            # one would yield a single character of the path.
            if isinstance(address, tuple) and len(address) > 1:
                self._port = address[1]
            self._port_bound.set()

            try:
                self._client = self._server.accept()
                if self._client:
                    self._client.read(5)  # hello
                    self._client.write(b"there")
            except Exception:
                logging.exception('error on server side (%s):', self.name)
                if not self._expect_failure:
                    raise
        finally:
            # Release every waiter even if listen()/getsockname()/accept()
            # failed; otherwise the test thread would block forever in
            # await_listening(), port or client instead of failing.
            self._listening.set()
            self._port_bound.set()
            self._client_accepted.set()

    def await_listening(self):
        """Block until the server has started listening (or failed to)."""
        self._listening.wait()

    @property
    def port(self):
        """TCP port the server is bound to; None for non-IP (e.g. unix) sockets."""
        self._port_bound.wait()
        return self._port

    @property
    def client(self):
        """Accepted client transport, or None; blocks until accept finished."""
        self._client_accepted.wait()
        return self._client

    def close(self):
        """Close the accepted client (if any) and the server socket."""
        if self._client:
            self._client.close()
        self._server.close()
|
|
|
|
|
|
# Python 2.6 compat
|
|
class AssertRaises(object):
    """Minimal context-manager stand-in for ``TestCase.assertRaises``.

    Used on interpreters older than Python 2.7, where ``assertRaises`` cannot
    be used as a context manager.  Swallows an exception of the expected type
    (or a subclass); otherwise raises ``Exception`` describing what went wrong.

    Args:
        expected: exception class, or tuple of classes as accepted by
            ``issubclass``.
    """

    def __init__(self, expected):
        self._expected = expected

    def __enter__(self):
        # Returning self lets callers write ``with AssertRaises(E) as cm:``.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            raise Exception('fail: expected %r to be raised, but nothing was raised'
                            % (self._expected,))
        if not issubclass(exc_type, self._expected):
            raise Exception('fail: expected %r to be raised, got %s: %s'
                            % (self._expected, exc_type.__name__, exc_value))
        # Suppress the expected exception.
        return True
|
|
|
|
|
|
class TSSLSocketTest(unittest.TestCase):
    """End-to-end tests for TSSLSocket / TSSLServerSocket.

    NOTE: TSSLSocket, TSSLServerSocket, TTransportException and
    _match_has_ipaddress are bound as module globals only by this file's
    ``__main__`` block, so these tests are meant to be run by executing the
    module directly.
    """

    def _server_socket(self, **kwargs):
        """Create a TSSLServerSocket on an ephemeral port (port=0)."""
        return TSSLServerSocket(port=0, **kwargs)

    @contextmanager
    def _connectable_client(self, server, expect_failure=False, path=None, **client_kwargs):
        """Start a ServerAcceptor on *server* and yield ``(acceptor, client)``.

        The client is an unopened TSSLSocket aimed at the acceptor's TCP port,
        or at the unix socket *path* when one is given.  The acceptor (and
        with it the server socket) is closed on exit.
        """
        acc = ServerAcceptor(server, expect_failure)
        try:
            acc.start()
            acc.await_listening()

            host, port = ('localhost', acc.port) if path is None else (None, None)
            client = TSSLSocket(host, port, unix_socket=path, **client_kwargs)
            yield acc, client
        finally:
            acc.close()

    def _assert_connection_failure(self, server, path=None, **client_args):
        """Assert that a hello/there exchange with *server* raises TTransportException."""
        # Server-side errors are expected here; keep them out of the output.
        logging.disable(logging.CRITICAL)
        try:
            with self._connectable_client(server, True, path=path, **client_args) as (acc, client):
                # We need to wait for a connection failure, but not too long. 20ms is a tunable
                # compromise between test speed and stability
                client.setTimeout(20)
                with self._assert_raises(TTransportException):
                    client.open()
                    client.write(b"hello")
                    client.read(5)  # b"there"
        finally:
            logging.disable(logging.NOTSET)

    def _assert_raises(self, exc):
        """Return an assertRaises-style context manager usable on Python 2.6."""
        if sys.hexversion >= 0x020700F0:
            return self.assertRaises(exc)
        else:
            return AssertRaises(exc)

    def _assert_connection_success(self, server, path=None, **client_args):
        """Assert that a full hello/there exchange with *server* succeeds."""
        with self._connectable_client(server, path=path, **client_args) as (acc, client):
            try:
                client.open()
                client.write(b"hello")
                self.assertEqual(client.read(5), b"there")
                self.assertTrue(acc.client is not None)
            finally:
                client.close()

    # deprecated feature
    def test_deprecation(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(len(w), 1)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
            TSSLSocket('localhost', 0, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
            self.assertEqual(len(w), 7)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
            TSSLServerSocket(None, 0, SERVER_PEM, None, TEST_CIPHERS)
            self.assertEqual(len(w), 3)

    # deprecated feature
    def test_set_cert_reqs_by_validate(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)

            c1 = TSSLSocket('localhost', 0, validate=False)
            self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)

            self.assertEqual(len(w), 2)

    # deprecated feature
    def test_set_validate_by_cert_reqs(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_NONE)
            self.assertFalse(c1.validate)

            c2 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
            self.assertTrue(c2.validate)

            c3 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
            self.assertTrue(c3.validate)

            self.assertEqual(len(w), 3)

    def test_unix_domain_socket(self):
        if platform.system() == 'Windows':
            print('skipping test_unix_domain_socket')
            return
        # mkstemp is used only to reserve a unique name; the file is removed
        # so the server can create its socket at that path.
        fd, path = tempfile.mkstemp()
        os.close(fd)
        os.unlink(path)
        try:
            server = self._server_socket(unix_socket=path, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_success(server, path=path, cert_reqs=ssl.CERT_NONE)
        finally:
            # The socket file may not exist if the server failed before
            # creating it; an unconditional unlink would raise OSError here
            # and hide the real test failure.
            if os.path.exists(path):
                os.unlink(path)

    def test_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        # server cert not in ca_certs
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE)

    def test_set_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=CLIENT_CERT)
        with self._assert_raises(Exception):
            server.certfile = 'foo'
        with self._assert_raises(Exception):
            server.certfile = None
        server.certfile = SERVER_CERT
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

    def test_client_cert(self):
        if not _match_has_ipaddress:
            print('skipping test_client_cert')
            return
        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

    def test_ciphers(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_success(server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)

        if not TSSLSocket._has_ciphers:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ciphers')
            return
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

    def test_ssl2_and_ssl3_disabled(self):
        if not hasattr(ssl, 'PROTOCOL_SSLv3'):
            print('PROTOCOL_SSLv3 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

        if not hasattr(ssl, 'PROTOCOL_SSLv2'):
            print('PROTOCOL_SSLv2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

    def test_newer_tls(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_newer_tls')
            return
        if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
            print('PROTOCOL_TLSv1_1 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

    def test_ssl_context(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ssl_context')
            return
        # Server context: require and verify a client certificate signed by CLIENT_CA.
        server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
        server_context.load_verify_locations(CLIENT_CA)
        server_context.verify_mode = ssl.CERT_REQUIRED
        server = self._server_socket(ssl_context=server_context)

        # Client context: present CLIENT_CERT and trust only SERVER_CERT.
        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
        client_context.load_verify_locations(SERVER_CERT)
        client_context.verify_mode = ssl.CERT_REQUIRED

        self._assert_connection_success(server, ssl_context=client_context)
|
|
|
|
|
|
if __name__ == '__main__':
    logging.basicConfig(level=logging.WARN)
    # These imports intentionally live here instead of at the top of the file.
    # They bind the module-level globals TSSLSocket, TSSLServerSocket,
    # _match_has_ipaddress and TTransportException, which the test code above
    # looks up at call time.  Presumably they are deferred so that the
    # `_import_local_thrift` import at the top of the file can first point the
    # import path at the local Thrift build -- TODO confirm.
    # Consequence: the tests depend on this file being executed directly,
    # since nothing else visible here defines those names.
    from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket, _match_has_ipaddress
    from thrift.transport.TTransport import TTransportException

    unittest.main()
|