IMPALA-13123: Add option to run tests with Python 3

This introduces the IMPALA_USE_PYTHON3_TESTS environment variable
to select whether to run tests using the toolchain Python 3.
This is an experimental option, so it defaults to false,
continuing to run tests with Python 2.

This fixes a first batch of Python 2 vs 3 issues:
 - Deciding whether to open a file in bytes mode or text mode
 - Adapting to APIs that operate on bytes in Python 3 (e.g. codecs)
 - Eliminating 'basestring' and 'unicode' usages in tests/ by using
   the recommendations from python-future
   ( https://python-future.org/compatible_idioms.html#basestring and
     https://python-future.org/compatible_idioms.html#unicode )
 - Using impala-python3 for bin/start-impala-cluster.py

All fixes leave the Python 2 path working normally.
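
As a reference for the idiom, here is a minimal standalone sketch of the
pattern used for the basestring/unicode cleanup (the unicode_compat name
matches this change; everything else is illustrative):

    from builtins import str as unicode_compat  # future's str on Py2, builtin str on Py3

    key = "column_label"
    # Python 2's "isinstance(key, basestring)" becomes a plain str check. This
    # misses Py2 unicode keys, which is acceptable where keys are never unicode.
    assert isinstance(key, str)
    # Python 2's "unicode(row)" becomes unicode_compat(row), which behaves the
    # same on both major versions.
    assert unicode_compat(42) == u"42"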

Testing:
 - Ran an exhaustive test run with Python 2 to verify nothing broke
 - Verified that the new environment variable works and that
   it uses Python 3 from the toolchain when specified

Change-Id: I177d9b8eae9b99ba536ca5c598b07208c3887f8c
Reviewed-on: http://gerrit.cloudera.org:8080/21474
Reviewed-by: Michael Smith <michael.smith@cloudera.com>
Reviewed-by: Riza Suminto <riza.suminto@cloudera.com>
Tested-by: Joe McDonnell <joemcdonnell@cloudera.com>
Author: Joe McDonnell
Date:   2023-04-04 10:12:12 -07:00
Parent: 4ba6f9b5a5
Commit: 8d5adfd0ba
29 changed files with 167 additions and 82 deletions


@@ -325,6 +325,8 @@ export IMPALA_KERBERIZE=false
 unset IMPALA_TOOLCHAIN_KUDU_MAVEN_REPOSITORY
 unset IMPALA_TOOLCHAIN_KUDU_MAVEN_REPOSITORY_ENABLED
+export IMPALA_USE_PYTHON3_TESTS=${IMPALA_USE_PYTHON3_TESTS:-false}
+
 # Source the branch and local config override files here to override any
 # variables above or any variables below that allow overriding via environment
 # variable.

bin/impala-env-versioned-python (new executable file, 26 lines)

@@ -0,0 +1,26 @@
+#!/bin/bash
+#
+##############################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+##############################################################################
+
+if [[ "${IMPALA_USE_PYTHON3_TESTS}" == "true" ]]; then
+  exec impala-python3 "$@"
+else
+  exec impala-python "$@"
+fi


@@ -19,5 +19,9 @@
 # under the License.
 ##############################################################################
-source $(dirname "$0")/impala-python-common.sh
+if [[ "${IMPALA_USE_PYTHON3_TESTS}" == "true" ]]; then
+  source $(dirname "$0")/impala-python3-common.sh
+else
+  source $(dirname "$0")/impala-python-common.sh
+fi
 exec "$PY_ENV_DIR/bin/py.test" "$@"


@@ -1,4 +1,4 @@
-#!/usr/bin/env impala-python
+#!/usr/bin/env impala-python3
 #
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements. See the NOTICE file


@@ -37,7 +37,8 @@ def exec_local_command(cmd):
   Return:
     STDOUT
   """
-  proc = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+  proc = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+                          universal_newlines=True)
   output, error = proc.communicate()
   retcode = proc.poll()
   if retcode:
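
Background on universal_newlines=True: on Python 3, Popen pipes return bytes
by default, while Python 2 returned str. Passing universal_newlines=True makes
communicate() return text on both versions. A standalone sketch:

    import subprocess

    # Default: communicate() yields bytes on Python 3.
    proc = subprocess.Popen(["echo", "hello"], stdout=subprocess.PIPE)
    out, _ = proc.communicate()
    print(type(out))  # <class 'bytes'> on Python 3, str on Python 2

    # With universal_newlines=True the pipe is decoded to text on both
    # versions, so existing string handling keeps working unchanged.
    proc = subprocess.Popen(["echo", "hello"], stdout=subprocess.PIPE,
                            universal_newlines=True)
    out, _ = proc.communicate()
    print(type(out))  # <class 'str'> on both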


@@ -94,8 +94,9 @@
 # This should be used sparingly, because these commands are executed
 # serially.
 #
-from __future__ import absolute_import, division, print_function
+from __future__ import absolute_import, division, print_function, unicode_literals
 from builtins import object
+import io
 import json
 import os
 import re
@@ -719,7 +720,7 @@ class Statements(object):
     # If there is no content to write, skip
     if not self: return
     output = self.create + self.load_base + self.load
-    with open(filename, 'w') as f:
+    with io.open(filename, 'w', encoding='utf-8') as f:
       f.write('\n\n'.join(output))

   def __bool__(self):
@@ -734,7 +735,8 @@ def eval_section(section_str):
   cmd = section_str[1:]
   # Use bash explicitly instead of setting shell=True so we get more advanced shell
   # features (e.g. "for i in {1..n}")
-  p = subprocess.Popen(['/bin/bash', '-c', cmd], stdout=subprocess.PIPE)
+  p = subprocess.Popen(['/bin/bash', '-c', cmd], stdout=subprocess.PIPE,
+                       universal_newlines=True)
   stdout, stderr = p.communicate()
   if stderr: print(stderr)
   assert p.returncode == 0


@@ -32,6 +32,7 @@ import time
 import shlex
 import getpass
 import re
+import sys

 from beeswaxd import BeeswaxService
 from beeswaxd.BeeswaxService import QueryState
@@ -419,6 +420,10 @@ class ImpalaBeeswaxClient(object):
     return exec_result

   def __get_query_type(self, query_string):
+    # Python 2's shlex does not work if the query string contains Unicode characters.
+    # Convert to bytes.
+    if sys.version_info.major == 2:
+      query_string = query_string.encode('utf-8')
     # Set posix=True and add "'" to escaped quotes
     # to deal with escaped quotes in string literals
     lexer = shlex.shlex(query_string.lstrip(), posix=True)
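
The same guard can be exercised standalone; this is a sketch of the pattern,
not the Impala helper itself:

    import shlex
    import sys

    def tokenize_sql(query_string):
      # Python 2's shlex mishandles Unicode input, so feed it UTF-8 bytes
      # there; Python 3's shlex operates on str directly.
      if sys.version_info.major == 2:
        query_string = query_string.encode('utf-8')
      lexer = shlex.shlex(query_string.lstrip(), posix=True)
      lexer.escapedquotes += "'"  # treat ' as an escaped quote, per the comment above
      return list(lexer)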


@@ -54,14 +54,16 @@ MAX_SQL_LOGGING_LENGTH = 128 * 1024
 def log_sql_stmt(sql_stmt):
   """If the 'sql_stmt' is shorter than MAX_SQL_LOGGING_LENGTH, log it unchanged. If
   it is larger than MAX_SQL_LOGGING_LENGTH, truncate it and comment it out."""
+  # sql_stmt could contain Unicode characters, so explicitly use unicode literals
+  # so that Python 2 works.
   if (len(sql_stmt) <= MAX_SQL_LOGGING_LENGTH):
-    LOG.info("{0};\n".format(sql_stmt))
+    LOG.info(u"{0};\n".format(sql_stmt))
   else:
     # The logging output should be valid SQL, so the truncated SQL is commented out.
     LOG.info("-- Skip logging full SQL statement of length {0}".format(len(sql_stmt)))
     LOG.info("-- Logging a truncated version, commented out:")
     for line in sql_stmt[0:MAX_SQL_LOGGING_LENGTH].split("\n"):
-      LOG.info("-- {0}".format(line))
+      LOG.info(u"-- {0}".format(line))
     LOG.info("-- [...]")
@@ -398,8 +400,10 @@ class ImpylaHS2Connection(ImpalaConnection):
     """Return the string representation of the query id."""
     guid_bytes = \
         operation_handle.get_handle()._last_operation.handle.operationId.guid
-    return "{0}:{1}".format(codecs.encode(guid_bytes[7::-1], 'hex_codec'),
-                            codecs.encode(guid_bytes[16:7:-1], 'hex_codec'))
+    # hex_codec works on bytes, so this needs to a decode() to get back to a string
+    hi_str = codecs.encode(guid_bytes[7::-1], 'hex_codec').decode()
+    lo_str = codecs.encode(guid_bytes[16:7:-1], 'hex_codec').decode()
+    return "{0}:{1}".format(hi_str, lo_str)

   def get_state(self, operation_handle):
     LOG.info("-- getting state for operation: {0}".format(operation_handle))


@@ -1194,6 +1194,7 @@ class ImpalaTestSuite(BaseTestSuite):
       # read to avoid hanging, especially when running interactively
       # with py.test.
       stdin=open("/dev/null"),
+      universal_newlines=True,
       env=env)
     (stdout, stderr) = call.communicate()
     call.wait()
@@ -1456,8 +1457,14 @@ class ImpalaTestSuite(BaseTestSuite):
     found = 0
     log_file_path = self.__build_log_path(daemon, level)
     last_re_result = None
-    with open(log_file_path) as log_file:
+    with open(log_file_path, 'rb') as log_file:
       for line in log_file:
+        # The logs could contain invalid unicode (and end-to-end tests don't control
+        # the logs from other tests). Skip lines with invalid unicode.
+        try:
+          line = line.decode()
+        except UnicodeDecodeError:
+          continue
         re_result = pattern.search(line)
         if re_result:
           found += 1


@@ -83,7 +83,9 @@ class ResourcePoolConfig(object):
     # Make sure the change to the file is atomic. Write to a temp file and replace the
     # original with it.
     temp_path = file_name + "-temp"
-    file_handle = open(temp_path, "w")
+    # ElementTree.tostring produces a bytestring on Python 3, so open the file in
+    # binary mode.
+    file_handle = open(temp_path, "wb")
     file_handle.write(ET.tostring(xml_root))
     file_handle.flush()
     os.fsync(file_handle.fileno())
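
A minimal standalone sketch of why binary mode is required here (the /tmp path
is illustrative):

    import xml.etree.ElementTree as ET

    root = ET.Element('pool')
    data = ET.tostring(root)
    print(type(data))  # <class 'bytes'> on Python 3 (str on Python 2)

    # Writing bytes needs a binary-mode file on Python 3; mode "w" would raise
    # "TypeError: write() argument must be str, not bytes".
    with open('/tmp/pool.xml', 'wb') as f:
      f.write(data)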


@@ -324,7 +324,7 @@ def load_table_info_dimension(workload_name, exploration_strategy, file_formats=
   vector_values = []
-  with open(test_vector_file, 'rb') as vector_file:
+  with open(test_vector_file, 'r') as vector_file:
     for line in vector_file.readlines():
       if line.strip().startswith('#'):
         continue


@@ -19,9 +19,15 @@
 from __future__ import absolute_import, division, print_function
 from builtins import map, range
+# Python 3 doesn't have the "unicode" type, as its regular string is Unicode. This
+# replaces Python 2's unicode with future's str. On Python 3, it is the builtin string.
+# On Python 2, it uses future's str implementation, which is similar to Python 3's string
+# but subclasses "unicode". See https://python-future.org/compatible_idioms.html#unicode
+from builtins import str as unicode_compat
 import logging
 import math
 import re
+import sys
 from functools import wraps

 from tests.util.test_file_parser import (join_section_lines, remove_comments,
@@ -143,7 +149,10 @@ class ResultRow(object):
     """Allows accessing a column value using the column alias or the position of the
     column in the result set. All values are returned as strings and an exception is
     thrown if the column label or column position does not exist."""
-    if isinstance(key, basestring):
+    # Python 2's str type won't match unicode type. This is ok, because currently the
+    # key is never unicode. On Python 3, str is unicode, and this would not have that
+    # limitation.
+    if isinstance(key, str):
       for col in self.columns:
         if col.column_label == key.lower(): return col.value
       raise IndexError('No column with label: ' + key)
@@ -258,8 +267,8 @@ def verify_query_result_is_subset(expected_results, actual_results):
   """Check whether the results in expected_results are a subset of the results in
   actual_results. This uses set semantics, i.e. any duplicates are ignored."""
   expected_literals, expected_non_literals = expected_results.separate_rows()
-  expected_literal_strings = set([unicode(row) for row in expected_literals])
-  actual_literal_strings = set([unicode(row) for row in actual_results.rows])
+  expected_literal_strings = set([unicode_compat(row) for row in expected_literals])
+  actual_literal_strings = set([unicode_compat(row) for row in actual_results.rows])
   # Expected literal strings must all be present in the actual strings.
   assert expected_literal_strings <= actual_literal_strings
   # Expected patterns must be present in the actual strings.
@@ -270,17 +279,17 @@ def verify_query_result_is_subset(expected_results, actual_results):
         matched = True
         break
     assert matched, u"Could not find expected row {0} in actual rows:\n{1}".format(
-        unicode(expected_row), unicode(actual_results))
+        unicode_compat(expected_row), unicode_compat(actual_results))

 def verify_query_result_is_superset(expected_results, actual_results):
   """Check whether the results in expected_results are a superset of the results in
   actual_results. This uses set semantics, i.e. any duplicates are ignored."""
   expected_literals, expected_non_literals = expected_results.separate_rows()
-  expected_literal_strings = set([unicode(row) for row in expected_literals])
+  expected_literal_strings = set([unicode_compat(row) for row in expected_literals])
   # Check that all actual rows are present in either expected_literal_strings or
   # expected_non_literals.
   for actual_row in actual_results.rows:
-    if unicode(actual_row) in expected_literal_strings:
+    if unicode_compat(actual_row) in expected_literal_strings:
       # Matched to a literal string
       continue
     matched = False
@@ -289,7 +298,7 @@ def verify_query_result_is_superset(expected_results, actual_results):
         matched = True
         break
     assert matched, u"Could not find actual row {0} in expected rows:\n{1}".format(
-        unicode(actual_row), unicode(expected_results))
+        unicode_compat(actual_row), unicode_compat(expected_results))

 def verify_query_result_is_equal(expected_results, actual_results):
   assert_args_not_none(expected_results, actual_results)
@@ -301,8 +310,8 @@ def verify_query_result_is_not_in(banned_results, actual_results):
   banned_literals, banned_non_literals = banned_results.separate_rows()
   # Part 1: No intersection with the banned literals
-  banned_literals_set = set([unicode(row) for row in banned_literals])
-  actual_set = set(map(unicode, actual_results.rows))
+  banned_literals_set = set([unicode_compat(row) for row in banned_literals])
+  actual_set = set(map(unicode_compat, actual_results.rows))
   assert banned_literals_set.isdisjoint(actual_set)

   # Part 2: Walk through each banned non-literal / regex and make sure that no row
@@ -315,7 +324,7 @@ def verify_query_result_is_not_in(banned_results, actual_results):
         matched = True
         break
     assert not matched, u"Found banned row {0} in actual rows:\n{1}".format(
-        unicode(banned_row), unicode(actual_results))
+        unicode_compat(banned_row), unicode_compat(actual_results))

 # Global dictionary that maps the verification type to appropriate verifier.
 # The RESULTS section of a .test file is tagged with the verifier type. We may
@@ -391,7 +400,7 @@ def verify_raw_results(test_section, exec_result, file_format, result_section,
   expected_results = None
   if result_section in test_section:
     expected_results = remove_comments(test_section[result_section])
-    if isinstance(expected_results, str):
+    if sys.version_info.major == 2 and isinstance(expected_results, str):
       # Always convert 'str' to 'unicode' since pytest will fail to report assertion
       # failures when any 'str' values contain non-ascii bytes (IMPALA-10419).
       try:
@@ -539,7 +548,7 @@ def parse_result_rows(exec_result, escape_strings=True):
     for i in range(len(cols)):
       if col_types[i] in ['STRING', 'CHAR', 'VARCHAR', 'BINARY']:
         col = cols[i]
-        if isinstance(col, str):
+        if sys.version_info.major == 2 and isinstance(col, str):
           try:
             col = col.decode('utf-8')
           except UnicodeDecodeError as e:


@@ -18,7 +18,7 @@
 # py.test configuration module
 #
 from __future__ import absolute_import, division, print_function
-from builtins import map, range
+from builtins import map, range, zip
 from impala.dbapi import connect as impala_connect
 from kudu import connect as kudu_connect
 from random import choice, sample
@@ -196,7 +196,7 @@ def pytest_assertrepr_compare(op, left, right):
   if isinstance(left, QueryTestResult) and isinstance(right, QueryTestResult) and \
       op == "==":
     result = ['Comparing QueryTestResults (expected vs actual):']
-    for l, r in map(None, left.rows, right.rows):
+    for l, r in zip(left.rows, right.rows):
       result.append("%s == %s" % (l, r) if l == r else "%s != %s" % (l, r))
     if len(left.rows) != len(right.rows):
       result.append('Number of rows returned (expected vs actual): '
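
Background on the map(None, ...) removal: Python 2's map(None, a, b) zipped two
lists and padded the shorter one with None; Python 3 removed that form. zip()
truncates to the shorter input instead, and the explicit length check above
covers the rows a truncating zip would drop. itertools.zip_longest is the
direct replacement when the padding behavior is actually wanted; a standalone
sketch:

    from itertools import zip_longest

    left, right = [1, 2, 3], [1, 2]
    print(list(zip(left, right)))          # [(1, 1), (2, 2)] -- truncates
    print(list(zip_longest(left, right)))  # [(1, 1), (2, 2), (3, None)] -- pads
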
@@ -291,7 +291,7 @@ def testid_checksum(request):
# "'abort_on_error': 1, 'exec_single_node_rows_threshold': 0, 'batch_size': 0, " # "'abort_on_error': 1, 'exec_single_node_rows_threshold': 0, 'batch_size': 0, "
# "'num_nodes': 0} | query_type: SELECT | cancel_delay: 3 | action: WAIT | " # "'num_nodes': 0} | query_type: SELECT | cancel_delay: 3 | action: WAIT | "
# "query: select l_returnflag from lineitem]") # "query: select l_returnflag from lineitem]")
return '{0:x}'.format(crc32(request.node.nodeid) & 0xffffffff) return '{0:x}'.format(crc32(request.node.nodeid.encode('utf-8')) & 0xffffffff)
@pytest.fixture @pytest.fixture
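
zlib's crc32 only accepts bytes on Python 3, so the pytest node id (a str)
must be encoded first; the mask keeps the result consistent with Python 2,
where crc32 could return a negative int. A standalone sketch with an
illustrative node id:

    from zlib import crc32

    nodeid = "tests/query_test/test_foo.py::test_bar"  # hypothetical node id
    checksum = crc32(nodeid.encode('utf-8')) & 0xffffffff
    print('{0:x}'.format(checksum))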


@@ -96,7 +96,7 @@ class TestParquetMaxPageHeader(CustomClusterTestSuite):
     random_text2 = "".join([random.choice(string.ascii_letters)
                             for i in range(self.MAX_STRING_LENGTH)])
     put = subprocess.Popen(["hdfs", "dfs", "-put", "-d", "-f", "-", file_name],
-                           stdin=subprocess.PIPE, bufsize=-1)
+                           stdin=subprocess.PIPE, bufsize=-1, universal_newlines=True)
     put.stdin.write(random_text1 + "\n")
     put.stdin.write(random_text2)
     put.stdin.close()


@@ -921,7 +921,9 @@ class TestQueryLogTableAll(TestQueryLogTableBase):
sqls["select 1"] = True sqls["select 1"] = True
control_queries_count = 0 control_queries_count = 0
for sql, experiment_control in sqls.items(): # Note: This needs to iterate over a copy of sqls.items(), because it modifies
# sqls as it iterates.
for sql, experiment_control in list(sqls.items()):
results = client.execute(sql) results = client.execute(sql)
assert results.success, "could not execute query '{0}'".format(sql) assert results.success, "could not execute query '{0}'".format(sql)
sqls[sql] = results.query_id sqls[sql] = results.query_id
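
The list() copy matters because Python 3's dict.items() is a live view, while
Python 2 returned a snapshot list. Mutating the dict's key set while iterating
the view raises RuntimeError; iterating a copy is safe on both versions. A
standalone sketch:

    d = {'a': 1, 'b': 2}
    # On Python 3, "for k, v in d.items(): d[k + '_new'] = v" raises
    # "RuntimeError: dictionary changed size during iteration".
    for k, v in list(d.items()):
      d[k + '_new'] = v
    print(sorted(d))  # ['a', 'a_new', 'b', 'b_new']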


@@ -19,7 +19,7 @@
 #
 # Tests Impala properly handles errors when reading and writing data.
-from __future__ import absolute_import, division, print_function
+from __future__ import absolute_import, division, print_function, unicode_literals
 import pytest
 import subprocess


@@ -28,6 +28,7 @@ from thrift.protocol import TBinaryProtocol
 from tests.common.impala_test_suite import ImpalaTestSuite, IMPALAD_HS2_HOST_PORT
 from tests.common.test_result_verifier import error_msg_expected
 from time import sleep, time
+import sys

 def add_session_helper(self, protocol_version, conf_overlay, close_session, fn):
@@ -91,8 +92,11 @@ def needs_session_cluster_properties(protocol_version=
 def operation_id_to_query_id(operation_id):
   lo, hi = operation_id.guid[:8], operation_id.guid[8:]
-  lo = ''.join(['%0.2X' % ord(c) for c in lo[::-1]])
-  hi = ''.join(['%0.2X' % ord(c) for c in hi[::-1]])
+  if sys.version_info.major < 3:
+    lo = [ord(x) for x in lo]
+    hi = [ord(x) for x in hi]
+  lo = ''.join(['%0.2X' % c for c in lo[::-1]])
+  hi = ''.join(['%0.2X' % c for c in hi[::-1]])
   return "%s:%s" % (lo, hi)
@@ -100,13 +104,13 @@ def create_session_handle_without_secret(session_handle):
"""Create a HS2 session handle with the same session ID as 'session_handle' but a """Create a HS2 session handle with the same session ID as 'session_handle' but a
bogus secret of the right length, i.e. 16 bytes.""" bogus secret of the right length, i.e. 16 bytes."""
return TCLIService.TSessionHandle(TCLIService.THandleIdentifier( return TCLIService.TSessionHandle(TCLIService.THandleIdentifier(
session_handle.sessionId.guid, r"xxxxxxxxxxxxxxxx")) session_handle.sessionId.guid, b"xxxxxxxxxxxxxxxx"))
def create_op_handle_without_secret(op_handle): def create_op_handle_without_secret(op_handle):
"""Create a HS2 operation handle with same parameters as 'op_handle' but with a bogus """Create a HS2 operation handle with same parameters as 'op_handle' but with a bogus
secret of the right length, i.e. 16 bytes.""" secret of the right length, i.e. 16 bytes."""
op_id = TCLIService.THandleIdentifier(op_handle.operationId.guid, r"xxxxxxxxxxxxxxxx") op_id = TCLIService.THandleIdentifier(op_handle.operationId.guid, b"xxxxxxxxxxxxxxxx")
return TCLIService.TOperationHandle( return TCLIService.TOperationHandle(
op_id, op_handle.operationType, op_handle.hasResultSet) op_id, op_handle.operationType, op_handle.hasResultSet)
@@ -290,11 +294,16 @@ class HS2TestSuite(ImpalaTestSuite):
for col_type in HS2TestSuite.HS2_V6_COLUMN_TYPES: for col_type in HS2TestSuite.HS2_V6_COLUMN_TYPES:
typed_col = getattr(c, col_type) typed_col = getattr(c, col_type)
if typed_col != None: if typed_col != None:
indicator = ord(typed_col.nulls[i // 8]) indicator = typed_col.nulls[i // 8]
if sys.version_info.major < 3:
indicator = ord(indicator)
if indicator & (1 << (i % 8)): if indicator & (1 << (i % 8)):
row.append("NULL") row.append("NULL")
else: else:
row.append(str(typed_col.values[i])) if isinstance(typed_col.values[i], bytes):
row.append(typed_col.values[i].decode())
else:
row.append(str(typed_col.values[i]))
break break
formatted += (", ".join(row) + "\n") formatted += (", ".join(row) + "\n")
return (num_rows, formatted) return (num_rows, formatted)


@@ -465,8 +465,8 @@ class TestHS2(HS2TestSuite):
impalad)""" impalad)"""
operation_handle = TCLIService.TOperationHandle() operation_handle = TCLIService.TOperationHandle()
operation_handle.operationId = TCLIService.THandleIdentifier() operation_handle.operationId = TCLIService.THandleIdentifier()
operation_handle.operationId.guid = "short" operation_handle.operationId.guid = b"short"
operation_handle.operationId.secret = "short_secret" operation_handle.operationId.secret = b"short_secret"
assert len(operation_handle.operationId.guid) != 16 assert len(operation_handle.operationId.guid) != 16
assert len(operation_handle.operationId.secret) != 16 assert len(operation_handle.operationId.secret) != 16
operation_handle.operationType = TCLIService.TOperationType.EXECUTE_STATEMENT operation_handle.operationType = TCLIService.TOperationType.EXECUTE_STATEMENT
@@ -485,8 +485,8 @@ class TestHS2(HS2TestSuite):
def test_invalid_query_handle(self): def test_invalid_query_handle(self):
operation_handle = TCLIService.TOperationHandle() operation_handle = TCLIService.TOperationHandle()
operation_handle.operationId = TCLIService.THandleIdentifier() operation_handle.operationId = TCLIService.THandleIdentifier()
operation_handle.operationId.guid = "\x01\x23\x45\x67\x89\xab\xcd\xef76543210" operation_handle.operationId.guid = b"\x01\x23\x45\x67\x89\xab\xcd\xef76543210"
operation_handle.operationId.secret = "PasswordIsPencil" operation_handle.operationId.secret = b"PasswordIsPencil"
operation_handle.operationType = TCLIService.TOperationType.EXECUTE_STATEMENT operation_handle.operationType = TCLIService.TOperationType.EXECUTE_STATEMENT
operation_handle.hasResultSet = False operation_handle.hasResultSet = False


@@ -682,8 +682,8 @@ class TestHdfsParquetTableStatsWriter(ImpalaTestSuite):
         ColumnStats('bigint_col', 0, 90, 0),
         ColumnStats('float_col', 0, RoundFloat(9.9, 1), 0),
         ColumnStats('double_col', 0, RoundFloat(90.9, 1), 0),
-        ColumnStats('date_string_col', '01/01/09', '12/31/10', 0),
-        ColumnStats('string_col', '0', '9', 0),
+        ColumnStats('date_string_col', b'01/01/09', b'12/31/10', 0),
+        ColumnStats('string_col', b'0', b'9', 0),
         ColumnStats('timestamp_col', TimeStamp('2009-01-01 00:00:00.0'),
                     TimeStamp('2010-12-31 05:09:13.860000'), 0),
         ColumnStats('year', 2009, 2010, 0),
@@ -732,15 +732,15 @@ class TestHdfsParquetTableStatsWriter(ImpalaTestSuite):
     # Expected values for tpch_parquet.customer
     expected_min_max_values = [
         ColumnStats('c_custkey', 1, 150000, 0),
-        ColumnStats('c_name', 'Customer#000000001', 'Customer#000150000', 0),
-        ColumnStats('c_address', ' 2uZwVhQvwA', 'zzxGktzXTMKS1BxZlgQ9nqQ', 0),
+        ColumnStats('c_name', b'Customer#000000001', b'Customer#000150000', 0),
+        ColumnStats('c_address', b' 2uZwVhQvwA', b'zzxGktzXTMKS1BxZlgQ9nqQ', 0),
         ColumnStats('c_nationkey', 0, 24, 0),
-        ColumnStats('c_phone', '10-100-106-1617', '34-999-618-6881', 0),
+        ColumnStats('c_phone', b'10-100-106-1617', b'34-999-618-6881', 0),
         ColumnStats('c_acctbal', Decimal('-999.99'), Decimal('9999.99'), 0),
-        ColumnStats('c_mktsegment', 'AUTOMOBILE', 'MACHINERY', 0),
-        ColumnStats('c_comment', ' Tiresias according to the slyly blithe instructions '
-                    'detect quickly at the slyly express courts. express dinos wake ',
-                    'zzle. blithely regular instructions cajol', 0),
+        ColumnStats('c_mktsegment', b'AUTOMOBILE', b'MACHINERY', 0),
+        ColumnStats('c_comment', b' Tiresias according to the slyly blithe instructions '
+                    b'detect quickly at the slyly express courts. express dinos wake ',
+                    b'zzle. blithely regular instructions cajol', 0),
     ]
     self._ctas_table_and_verify_stats(vector, unique_database, tmpdir.strpath,
@@ -750,13 +750,13 @@ class TestHdfsParquetTableStatsWriter(ImpalaTestSuite):
     """Test that we don't write min/max statistics for null columns. Ensure null_count
     is set for columns with null values."""
     expected_min_max_values = [
-        ColumnStats('a', 'a', 'a', 0),
-        ColumnStats('b', '', '', 0),
+        ColumnStats('a', b'a', b'a', 0),
+        ColumnStats('b', b'', b'', 0),
         ColumnStats('c', None, None, 1),
         ColumnStats('d', None, None, 1),
         ColumnStats('e', None, None, 1),
-        ColumnStats('f', 'a\x00b', 'a\x00b', 0),
-        ColumnStats('g', '\x00', '\x00', 0)
+        ColumnStats('f', b'a\x00b', b'a\x00b', 0),
+        ColumnStats('g', b'\x00', b'\x00', 0)
     ]
     self._ctas_table_and_verify_stats(vector, unique_database, tmpdir.strpath,
@@ -778,9 +778,9 @@ class TestHdfsParquetTableStatsWriter(ImpalaTestSuite):
     """.format(qualified_table_name)
     self.execute_query(insert_stmt)
     expected_min_max_values = [
-        ColumnStats('c3', 'abc', 'xy', 0),
-        ColumnStats('vc', 'abc banana', 'ghj xyz', 0),
-        ColumnStats('st', 'abc xyz', 'lorem ipsum', 0)
+        ColumnStats('c3', b'abc', b'xy', 0),
+        ColumnStats('vc', b'abc banana', b'ghj xyz', 0),
+        ColumnStats('st', b'abc xyz', b'lorem ipsum', 0)
     ]
     self._ctas_table_and_verify_stats(vector, unique_database, tmpdir.strpath,
                                       qualified_table_name, expected_min_max_values)
@@ -875,10 +875,10 @@ class TestHdfsParquetTableStatsWriter(ImpalaTestSuite):
     # Expected values for tpch_parquet.customer
     expected_min_max_values = [
-        ColumnStats('id', '8600000US00601', '8600000US999XX', 0),
-        ColumnStats('zip', '00601', '999XX', 0),
-        ColumnStats('description1', '\"00601 5-Digit ZCTA', '\"999XX 5-Digit ZCTA', 0),
-        ColumnStats('description2', ' 006 3-Digit ZCTA\"', ' 999 3-Digit ZCTA\"', 0),
+        ColumnStats('id', b'8600000US00601', b'8600000US999XX', 0),
+        ColumnStats('zip', b'00601', b'999XX', 0),
+        ColumnStats('description1', b'\"00601 5-Digit ZCTA', b'\"999XX 5-Digit ZCTA', 0),
+        ColumnStats('description2', b' 006 3-Digit ZCTA\"', b' 999 3-Digit ZCTA\"', 0),
         ColumnStats('income', 0, 189570, 29),
     ]


@@ -916,9 +916,11 @@ class TestObservability(ImpalaTestSuite):
assert "Resizes:" in runtime_profile assert "Resizes:" in runtime_profile
nprobes = re.search('Probes:.*\((\d+)\)', runtime_profile) nprobes = re.search('Probes:.*\((\d+)\)', runtime_profile)
# Probes and travel can be 0. The number can be an integer or float with K. # Probes and travel can be 0. The number can be an integer or float with K.
assert nprobes and len(nprobes.groups()) == 1 and nprobes.group(1) >= 0 # The number extracted is the number inside parenthesis, which is always
# an integer.
assert nprobes and len(nprobes.groups()) == 1 and int(nprobes.group(1)) >= 0
ntravel = re.search('Travel:.*\((\d+)\)', runtime_profile) ntravel = re.search('Travel:.*\((\d+)\)', runtime_profile)
assert ntravel and len(ntravel.groups()) == 1 and ntravel.group(1) >= 0 assert ntravel and len(ntravel.groups()) == 1 and int(ntravel.group(1)) >= 0
def test_query_profle_hashtable(self): def test_query_profle_hashtable(self):
"""Test that the profile for join/aggregate contains hash table related """Test that the profile for join/aggregate contains hash table related


@@ -320,7 +320,7 @@ class TestParquetBloomFilter(ImpalaTestSuite):
     row_group = file_meta_data.row_groups[0]
     assert len(schemas) == len(row_group.columns)
     col_to_bloom_filter = dict()
-    with open(filename) as file_handle:
+    with open(filename, 'rb') as file_handle:
       for i, column in enumerate(row_group.columns):
         column_meta_data = column.meta_data
         if column_meta_data and column_meta_data.bloom_filter_offset:


@@ -73,7 +73,7 @@ class TestHdfsParquetTableIndexWriter(ImpalaTestSuite):
     row_group = file_meta_data.row_groups[0]
     assert len(schemas) == len(row_group.columns)
     row_group_index = []
-    with open(parquet_file) as file_handle:
+    with open(parquet_file, 'rb') as file_handle:
       for column, schema in zip(row_group.columns, schemas):
         column_index_offset = column.column_index_offset
         column_index_length = column.column_index_length
@@ -170,7 +170,7 @@ class TestHdfsParquetTableIndexWriter(ImpalaTestSuite):
       if not null_page:
         page_min_value = decode_stats_value(column_info.schema, page_min_str)
         # If type is str, page_min_value might have been truncated.
-        if isinstance(page_min_value, basestring):
+        if isinstance(page_min_value, bytes):
           assert page_min_value >= column_min_value[:len(page_min_value)]
         else:
           assert page_min_value >= column_min_value
@@ -180,9 +180,9 @@ class TestHdfsParquetTableIndexWriter(ImpalaTestSuite):
       if not null_page:
         page_max_value = decode_stats_value(column_info.schema, page_max_str)
         # If type is str, page_max_value might have been truncated and incremented.
-        if (isinstance(page_max_value, basestring) and
-            len(page_max_value) == PAGE_INDEX_MAX_STRING_LENGTH):
-          max_val_prefix = page_max_value.rstrip('\0')
+        if (isinstance(page_max_value, bytes)
+            and len(page_max_value) == PAGE_INDEX_MAX_STRING_LENGTH):
+          max_val_prefix = page_max_value.rstrip(b'\0')
           assert max_val_prefix[:-1] <= column_max_value
         else:
           assert page_max_value <= column_max_value
@@ -389,7 +389,7 @@ class TestHdfsParquetTableIndexWriter(ImpalaTestSuite):
     column = row_group_indexes[0][0]
     assert len(column.column_index.max_values) == 1
     max_value = column.column_index.max_values[0]
-    assert max_value == 'aab'
+    assert max_value == b'aab'

   def test_row_count_limit(self, vector, unique_database, tmpdir):
     """Tests that we can set the page row count limit via a query option.


@@ -1543,7 +1543,8 @@ class TestTextSplitDelimiters(ImpalaTestSuite):
query = "create table %s (s string) location '%s'" % (qualified_table_name, location) query = "create table %s (s string) location '%s'" % (qualified_table_name, location)
self.client.execute(query) self.client.execute(query)
with tempfile.NamedTemporaryFile() as f: # Passing "w+" to NamedTemporaryFile prevents it from opening the file in bytes mode
with tempfile.NamedTemporaryFile(mode="w+") as f:
f.write(data) f.write(data)
f.flush() f.flush()
self.filesystem_client.copy_from_local(f.name, location) self.filesystem_client.copy_from_local(f.name, location)


@@ -1,4 +1,4 @@
-#!/usr/bin/env impala-python
+#!/usr/bin/env impala-env-versioned-python
 #
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements. See the NOTICE file
@@ -48,7 +48,7 @@ VALID_TEST_DIRS = ['failure', 'query_test', 'stress', 'unittests', 'aux_query_te
 TEST_HELPER_DIRS = ['aux_parquet_data_load', 'comparison', 'benchmark',
                     'custom_cluster', 'util', 'experiments', 'verifiers', 'common',
                     'performance', 'beeswax', 'aux_custom_cluster_tests',
-                    'authorization', 'test-hive-udfs']
+                    'authorization', 'test-hive-udfs', '__pycache__']
 TEST_DIR = os.path.join(os.environ['IMPALA_HOME'], 'tests')
 RESULT_DIR = os.path.join(os.environ['IMPALA_EE_TEST_LOGS_DIR'], 'results')


@@ -321,7 +321,7 @@ class ImpalaShell(object):
     stdout_arg = stdout_file if stdout_file is not None else PIPE
     stderr_arg = stderr_file if stderr_file is not None else PIPE
     return Popen(cmd, shell=False, stdout=stdout_arg, stdin=PIPE, stderr=stderr_arg,
-                 env=build_shell_env(env))
+                 universal_newlines=True, env=build_shell_env(env))

 def get_unused_port():


@@ -19,6 +19,7 @@ from __future__ import absolute_import, division, print_function
 from builtins import map
 import os
 import struct
+import sys
 from datetime import date, datetime, time, timedelta
 from decimal import Decimal
@@ -28,7 +29,7 @@ from subprocess import check_call
 from thrift.protocol import TCompactProtocol
 from thrift.transport import TTransport

-PARQUET_VERSION_NUMBER = 'PAR1'
+PARQUET_VERSION_NUMBER = b'PAR1'

 def create_protocol(serialized_object_buffer):
@@ -99,10 +100,14 @@ def decode_decimal(schema, value):
   assert schema.type_length == len(value)
   assert schema.type == Type.FIXED_LEN_BYTE_ARRAY

-  numeric = Decimal(reduce(lambda x, y: x * 256 + y, list(map(ord, value))))
+  if sys.version_info.major < 3:
+    byte_values = list(map(ord, value))
+  else:
+    byte_values = list(value)
+  numeric = Decimal(reduce(lambda x, y: x * 256 + y, byte_values))

   # Compute two's complement for negative values.
-  if (ord(value[0]) > 127):
+  if (byte_values[0] > 127):
     bit_width = 8 * len(value)
     numeric = numeric - (2 ** bit_width)
@@ -154,7 +159,7 @@ def get_parquet_metadata(filename):
   file path.
   """
   file_size = os.path.getsize(filename)
-  with open(filename) as f:
+  with open(filename, 'rb') as f:
     # Check file starts and ends with magic bytes
     start_magic = f.read(len(PARQUET_VERSION_NUMBER))
     assert start_magic == PARQUET_VERSION_NUMBER
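
A worked standalone sketch of the version-portable decimal decoding above,
with made-up input bytes:

    import sys
    from decimal import Decimal
    from functools import reduce

    value = b'\xfe\xd4'  # big-endian two's-complement bytes for -300
    if sys.version_info.major < 3:
      byte_values = list(map(ord, value))  # Python 2: bytes iterate as str
    else:
      byte_values = list(value)  # Python 3: bytes iterate as ints

    numeric = Decimal(reduce(lambda x, y: x * 256 + y, byte_values))  # 65236
    if byte_values[0] > 127:  # high bit set means negative
      numeric -= 2 ** (8 * len(value))  # 65236 - 65536 = -300
    print(numeric)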


@@ -259,7 +259,7 @@ class HadoopFsCommandLineClient(BaseFilesystem):
     Overwrites files by default to avoid S3 consistency issues. Specifes the '-d' option
     by default, which 'Skip[s] creation of temporary file with the suffix ._COPYING_.' to
     avoid extraneous copies on S3. 'src' must be either a string or a list of strings."""
-    assert isinstance(src, list) or isinstance(src, basestring)
+    assert isinstance(src, list) or isinstance(src, str)
     src_list = src if isinstance(src, list) else [src]
     (status, stdout, stderr) = self._hadoop_fs_shell(['-copyFromLocal', '-d', '-f'] +
                                                      src_list + [dst])


@@ -52,7 +52,8 @@ def exec_process_async(cmd):
   LOG.debug('Executing: %s' % (cmd,))
   # Popen needs a list as its first parameter. The first element is the command,
   # with the rest being arguments.
-  return Popen(shlex.split(cmd), shell=False, stdout=PIPE, stderr=PIPE)
+  return Popen(shlex.split(cmd), shell=False, stdout=PIPE, stderr=PIPE,
+               universal_newlines=True)

 def shell(cmd, cmd_prepend="set -euo pipefail\n", stdout=PIPE, stderr=STDOUT,
           timeout_secs=None, **popen_kwargs):


@@ -115,7 +115,7 @@ def parse_table_constraints(constraints_file):
   if not os.path.isfile(constraints_file):
     LOG.info('No schema constraints file file found')
   else:
-    with open(constraints_file, 'rb') as constraints_file:
+    with open(constraints_file, 'r') as constraints_file:
       for line in constraints_file.readlines():
         line = line.strip()
         if not line or line.startswith('#'):
@@ -164,7 +164,10 @@ def parse_test_file(test_file_name, valid_section_names, skip_unknown_sections=T
   """
   with open(test_file_name, 'rb') as test_file:
     file_data = test_file.read()
-    if encoding: file_data = file_data.decode(encoding)
+    if encoding:
+      file_data = file_data.decode(encoding)
+    else:
+      file_data = file_data.decode('utf-8')
     if os.environ["USE_APACHE_HIVE"] == "true":
       # Remove Hive 4.0 feature for tpcds_schema_template.sql
       if "tpcds_schema_template" in test_file_name: