Files
impala/tests/comparison/db_connection.py
Taras Bobrovytsky 524ac88589 Make it possible to disable analytic functions in Qgen
If BASIC cannot be assigned as select item category because data type
is Float and an AGG is present, assign AGG instead of ANALYTIC. This
should make it possible to disable analytic functions from the query
profile.

Change-Id: Ic025fc44fc9e10f501afd25bf11023c8b6068ec9
Reviewed-on: http://gerrit.cloudera.org:8080/2234
Reviewed-by: Taras Bobrovytsky <tbobrovytsky@cloudera.com>
Tested-by: Internal Jenkins
2016-02-19 00:03:23 -08:00

1109 lines
34 KiB
Python

# Copyright (c) 2015 Cloudera, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''This module is intended to standardize workflows when working with various databases
such as Impala, Postgresql, etc. Even with pep-249 (DB API 2), workflows differ
slightly. For example Postgresql does not allow changing databases from within a
connection.
'''
import hashlib
import impala.dbapi
import shelve
from abc import ABCMeta, abstractmethod
from contextlib import closing
from decimal import Decimal as PyDecimal
from itertools import combinations, ifilter, izip
from logging import getLogger
from os import symlink, unlink
from pyparsing import (
alphanums,
delimitedList,
Forward,
Group,
Literal,
nums,
Suppress,
Word)
from re import compile
from tempfile import gettempdir
from threading import Lock
from time import time
from common import (
ArrayColumn,
Column,
MapColumn,
StructColumn,
Table,
TableExprList)
from db_types import (
Char,
Decimal,
Double,
EXACT_TYPES,
Float,
get_char_class,
get_decimal_class,
get_varchar_class,
Int,
String,
Timestamp,
TinyInt,
VarChar)
LOG = getLogger(__name__)

# Database type tags; connection classes expose one of these through their db_type.
HIVE, IMPALA, MYSQL, ORACLE, POSTGRESQL = "HIVE", "IMPALA", "MYSQL", "ORACLE", "POSTGRESQL"
class DbCursor(object):
    '''Wraps a DB API 2 cursor to provide access to the related conn. This class
    implements the DB API 2 interface by delegation: attributes not defined here are read
    from the wrapped cursor, and attributes other than _conn/_cursor are also written
    through to it.
    '''

    @staticmethod
    def describe_common_tables(cursors):
        '''Find and return a TableExprList containing Table objects that the given cursors
        have in common.

        A table is kept only if every cursor lists it, it has at least one column, and its
        columns agree pairwise (same count, same order, same name and type) across all
        cursors. The Table described by the first cursor is the one returned.
        '''
        common_table_names = None
        for cursor in cursors:
            table_names = set(cursor.list_table_names())
            if common_table_names is None:
                common_table_names = table_names
            else:
                common_table_names &= table_names
        tables = TableExprList()
        # "or ()" tolerates an empty list of cursors instead of failing in sorted(None).
        for table_name in sorted(common_table_names or ()):
            common_table = None
            mismatch = False
            for cursor in cursors:
                table = cursor.describe_table(table_name)
                if not table.cols:
                    LOG.debug('%s has no remaining columns', table_name)
                    mismatch = True
                    break
                if common_table is None:
                    common_table = table
                    continue
                if len(common_table.cols) != len(table.cols):
                    LOG.debug('Ignoring table %s.'
                              ' It has a different number of columns across databases.',
                              table_name)
                    mismatch = True
                    break
                for left, right in izip(common_table.cols, table.cols):
                    # Parenthesised so that a difference in either the name or the type
                    # counts as a mismatch.
                    if not (left.name == right.name and left.type == right.type):
                        LOG.debug('Ignoring table %s. It has different columns %s vs %s.',
                                  table_name, left, right)
                        mismatch = True
                        break
                if mismatch:
                    break
            if not mismatch:
                tables.append(common_table)
        return tables

    SQL_TYPE_PATTERN = compile(r'([^()]+)(\((\d+,? ?)*\))?')
    # Maps a (possibly database-specific) upper-case type name to the canonical type name
    # used as a key of TYPES_BY_NAME. Subclasses extend this with their own spellings.
    TYPE_NAME_ALIASES = \
        dict((type_.name().upper(), type_.name().upper()) for type_ in EXACT_TYPES)
    TYPES_BY_NAME = dict((type_.name().upper(), type_) for type_ in EXACT_TYPES)
    EXACT_TYPES_TO_SQL = dict((type_, type_.name().upper()) for type_ in EXACT_TYPES)
    # Shelve file used to persist/load unique column metadata between runs.
    _UNIQUE_COLS_SHELVE_PATH = '/tmp/query_generator.shelve'

    @classmethod
    def make_insert_sql_from_data(cls, table, rows):
        '''Return one INSERT statement that inserts every row of ``rows`` into ``table``.

        Each row is indexed by column position and must supply one value per column of
        ``table``.

        Raises:
            Exception: if ``rows`` is empty or ``table`` has no columns.
        '''
        if not rows:
            raise Exception('At least one row is required')
        if not table.cols:
            raise Exception('At least one col is required')
        values_sql = ', '.join(
            '(%s)' % ', '.join(cls._make_sql_literal(col, row[col_idx])
                               for col_idx, col in enumerate(table.cols))
            for row in rows)
        return 'INSERT INTO %s VALUES %s' % (table.name, values_sql)

    @staticmethod
    def _make_sql_literal(col, val):
        '''Render one value of column ``col`` as a SQL literal for an INSERT statement.'''
        if val is None:
            return 'NULL'
        if issubclass(col.type, Timestamp):
            return "TIMESTAMP '%s'" % val
        if issubclass(col.type, Char):
            # Embedded single quotes are escaped by doubling them.
            return "'%s'" % val.replace("'", "''")
        return str(val)

    def __init__(self, conn, cursor):
        self._conn = conn
        self._cursor = cursor

    def __getattr__(self, attr):
        # Only called when normal lookup fails. Guard _cursor so that a partially
        # initialised instance raises AttributeError instead of recursing forever.
        if attr == "_cursor":
            raise AttributeError(attr)
        return getattr(self._cursor, attr)

    def __setattr__(self, attr, value):
        # Transfer unknown attributes to the underlying cursor (a copy is also kept on
        # this wrapper).
        if attr not in ["_conn", "_cursor"]:
            setattr(self._cursor, attr, value)
        object.__setattr__(self, attr, value)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close(quiet=True)

    @property
    def db_type(self):
        return self._conn.db_type

    @property
    def conn(self):
        return self._conn

    @property
    def db_name(self):
        return self._conn.db_name

    def execute(self, sql, *args, **kwargs):
        '''Log ``sql`` (and append it to the connection's SQL log, if any), then execute.'''
        LOG.debug('%s: %s', self.db_type, sql)
        if self.conn.sql_log:
            self.conn.sql_log.write('\nQuery: %s' % sql)
        return self._cursor.execute(sql, *args, **kwargs)

    def execute_and_fetchall(self, sql, *args, **kwargs):
        self.execute(sql, *args, **kwargs)
        return self.fetchall()

    def close(self, quiet=False):
        '''Close the wrapped cursor. With ``quiet`` errors are logged instead of raised.'''
        try:
            self._cursor.close()
        except Exception as e:
            if not quiet:
                raise  # bare raise keeps the original traceback
            LOG.debug('Error closing cursor: %s', e, exc_info=True)

    def reconnect(self):
        self.conn.reconnect()
        self._cursor = self.conn.cursor()._cursor

    def list_db_names(self):
        '''Return a list of database names always in lowercase.'''
        rows = self.execute_and_fetchall(self.make_list_db_names_sql())
        return [row[0].lower() for row in rows]

    def make_list_db_names_sql(self):
        return 'SHOW DATABASES'

    def create_db(self, db_name):
        LOG.info("Creating database %s", db_name)
        db_name = db_name.lower()
        self.execute('CREATE DATABASE ' + db_name)

    def drop_db_if_exists(self, db_name):
        '''This should not be called from a conn to the database being dropped.'''
        db_name = db_name.lower()
        if db_name not in self.list_db_names():
            return
        if self.db_name and self.db_name.lower() == db_name:
            raise Exception('Cannot drop database while still connected to it')
        self.drop_db(db_name)

    def drop_db(self, db_name):
        LOG.info("Dropping database %s", db_name)
        db_name = db_name.lower()
        self.execute('DROP DATABASE ' + db_name)

    def ensure_empty_db(self, db_name):
        self.drop_db_if_exists(db_name)
        self.create_db(db_name)

    def list_table_names(self):
        '''Return a list of table names always in lowercase.'''
        rows = self.execute_and_fetchall(self.make_list_table_names_sql())
        return [row[0].lower() for row in rows]

    def make_list_table_names_sql(self):
        return 'SHOW TABLES'

    def parse_col_desc(self, data_type):
        ''' Returns a parsed output based on type describe output.
        data_type is string that should look like this:
        "bigint"
        or like this:
        "array<struct<
          field_51:int,
          field_52:bigint,
          field_53:int,
          field_54:boolean
        >>"
        In the first case, this method would return: 'bigint'
        In the second case, it would return
        ['array',
          ['struct',
            ['field_51', 'int'],
            ['field_52', 'bigint'],
            ['field_53', 'int'],
            ['field_54', 'boolean']]]
        This output is used to create the appropriate columns by self.create_column().

        NOTE: alternatives are tried in order and parseString() does not require the whole
        input to be consumed, so at the top level 'integer', 'timestamp without time zone'
        and 'double precision' come back as 'int', 'timestamp' and 'double'.
        '''
        # Note the naming: LPAR/RPAR are the angle brackets of complex types and
        # LBRA/RBRA are the parentheses around sizes.
        COMMA, LPAR, RPAR, COLON, LBRA, RBRA = map(Suppress, ",<>:()")
        t_bigint = Literal('bigint')
        t_int = Literal('int')
        t_integer = Literal('integer')
        t_smallint = Literal('smallint')
        t_tinyint = Literal('tinyint')
        t_boolean = Literal('boolean')
        t_string = Literal('string')
        t_timestamp = Literal('timestamp')
        t_timestamp_without_time_zone = Literal('timestamp without time zone')
        t_float = Literal('float')
        t_double = Literal('double')
        t_real = Literal('real')
        t_double_precision = Literal('double precision')
        t_decimal = Group(Literal('decimal') + LBRA + Word(nums) + COMMA + Word(nums) + RBRA)
        t_numeric = Group(Literal('numeric') + LBRA + Word(nums) + COMMA + Word(nums) + RBRA)
        t_char = Group(Literal('char') + LBRA + Word(nums) + RBRA)
        t_character = Group(Literal('character') + LBRA + Word(nums) + RBRA)
        t_varchar = (Group(Literal('varchar') + LBRA + Word(nums) + RBRA) |
                     Literal('varchar'))
        t_character_varying = Group(Literal('character varying') + LBRA + Word(nums) + RBRA)
        t_struct = Forward()
        t_array = Forward()
        t_map = Forward()
        complex_type = (t_struct | t_array | t_map)
        any_type = (
            complex_type |
            t_bigint |
            t_int |
            t_integer |
            t_smallint |
            t_tinyint |
            t_boolean |
            t_string |
            t_timestamp |
            t_timestamp_without_time_zone |
            t_float |
            t_double |
            t_real |
            t_double_precision |
            t_decimal |
            t_numeric |
            t_char |
            t_character |
            t_character_varying |
            t_varchar)
        struct_field_name = Word(alphanums + '_')
        struct_field_pair = Group(struct_field_name + COLON + any_type)
        t_struct << Group(Literal('struct') + LPAR + delimitedList(struct_field_pair) + RPAR)
        t_array << Group(Literal('array') + LPAR + any_type + RPAR)
        t_map << Group(Literal('map') + LPAR + any_type + COMMA + any_type + RPAR)
        return any_type.parseString(data_type)[0]

    def create_column(self, col_name, col_type):
        ''' Takes the output from parse_col_desc and creates the right column type. This
        method returns one of Column, ArrayColumn, MapColumn, StructColumn.

        Raises:
            Exception: if the type is not recognised by this cursor's TYPE_NAME_ALIASES.
        '''
        if isinstance(col_type, str):
            # Scalar type name such as 'bigint'. An unsized VARCHAR is treated as STRING.
            if col_type.upper() == 'VARCHAR':
                col_type = 'STRING'
            type_name = self.TYPE_NAME_ALIASES.get(col_type.upper())
            if type_name is None:
                raise Exception('unable to parse: {0}, type: {1}'.format(col_name, col_type))
            return Column(owner=None,
                          name=col_name.lower(),
                          exact_type=self.TYPES_BY_NAME[type_name])
        general_class = col_type[0].upper()
        if general_class == 'ARRAY':
            return ArrayColumn(
                owner=None,
                name=col_name.lower(),
                item=self.create_column(col_name='item', col_type=col_type[1]))
        if general_class == 'MAP':
            return MapColumn(
                owner=None,
                name=col_name.lower(),
                key=self.create_column(col_name='key', col_type=col_type[1]),
                value=self.create_column(col_name='value', col_type=col_type[2]))
        if general_class == 'STRUCT':
            struct_col = StructColumn(owner=None, name=col_name.lower())
            for field_name, field_type in col_type[1:]:
                struct_col.add_col(self.create_column(field_name, field_type))
            return struct_col
        # Sized scalar types, e.g. ['decimal', '10', '2'] or ['varchar', '255'].
        type_name = self.TYPE_NAME_ALIASES.get(general_class)
        if type_name == 'DECIMAL':
            return Column(owner=None,
                          name=col_name.lower(),
                          exact_type=get_decimal_class(int(col_type[1]), int(col_type[2])))
        if type_name == 'CHAR':
            return Column(owner=None,
                          name=col_name.lower(),
                          exact_type=get_char_class(int(col_type[1])))
        if type_name == 'VARCHAR':
            type_size = int(col_type[1])
            if type_size <= VarChar.MAX:
                cur_type = get_varchar_class(type_size)
            else:
                cur_type = self.TYPES_BY_NAME['STRING']
            return Column(owner=None,
                          name=col_name.lower(),
                          exact_type=cur_type)
        raise Exception('unable to parse: {0}, type: {1}'.format(col_name, col_type))

    def create_table_from_describe(self, table_name, describe_rows):
        '''Build a Table from rows whose first two fields are (column name, type string).'''
        table = Table(table_name.lower())
        for row in describe_rows:
            col_name, data_type = row[:2]
            col_type = self.parse_col_desc(data_type)
            col = self.create_column(col_name, col_type)
            table.add_col(col)
        return table

    def describe_table(self, table_name):
        '''Return a Table with table and col names always in lowercase.'''
        describe_rows = self.execute_and_fetchall(self.make_describe_table_sql(table_name))
        table = self.create_table_from_describe(table_name, describe_rows)
        self.load_unique_col_metadata(table)
        return table

    def make_describe_table_sql(self, table_name):
        return 'DESCRIBE ' + table_name

    def parse_data_type(self, type_name, type_size):
        '''Map an upper-case type name plus a sequence of integer sizes to a type class.'''
        if type_name in ('DECIMAL', 'NUMERIC'):
            return get_decimal_class(*type_size)
        if type_name == 'CHAR':
            return get_char_class(*type_size)
        if type_name == 'VARCHAR':
            if type_size and type_size[0] <= VarChar.MAX:
                return get_varchar_class(*type_size)
            type_name = 'STRING'
        return self.TYPES_BY_NAME[type_name]

    def create_table(self, table):
        LOG.info('Creating table %s', table.name)
        if not table.cols:
            raise Exception('At least one col is required')
        table_sql = self.make_create_table_sql(table)
        LOG.debug(table_sql)
        self.execute(table_sql)
        LOG.debug('Created table %s', table.name)

    def make_create_table_sql(self, table):
        sql = 'CREATE TABLE %s (%s)' % (
            table.name,
            ', '.join('%s %s' %
                      (col.name, self.get_sql_for_data_type(col.exact_type)) +
                      ('' if self.conn.data_types_are_implictly_nullable else ' NULL')
                      for col in table.cols))
        return sql

    def get_sql_for_data_type(self, data_type):
        if issubclass(data_type, VarChar):
            return 'VARCHAR(%s)' % data_type.MAX
        if issubclass(data_type, Char):
            return 'CHAR(%s)' % data_type.MAX
        if issubclass(data_type, Decimal):
            return 'DECIMAL(%s, %s)' % (data_type.MAX_DIGITS, data_type.MAX_FRACTIONAL_DIGITS)
        return self.EXACT_TYPES_TO_SQL[data_type]

    def drop_table(self, table_name, if_exists=True):
        '''Drop ``table_name``; with ``if_exists`` (the default) a missing table is fine.'''
        LOG.info('Dropping table %s', table_name)
        self.execute('DROP TABLE %s%s'
                     % ('IF EXISTS ' if if_exists else '', table_name.lower()))

    def drop_view(self, view_name, if_exists=True):
        '''Drop ``view_name``; with ``if_exists`` (the default) a missing view is fine.'''
        LOG.info('Dropping view %s', view_name)
        self.execute('DROP VIEW %s%s'
                     % ('IF EXISTS ' if if_exists else '', view_name.lower()))

    def index_table(self, table_name):
        '''Create a single-column index on every column of ``table_name``.'''
        LOG.info('Indexing table %s', table_name)
        table = self.describe_table(table_name)
        for col in table.cols:
            index_name = '%s_%s' % (table_name, col.name)
            if self.db_name:
                index_name = '%s_%s' % (self.db_name, index_name)
            # Older versions of Postgres require index name to be included in the statement.
            # In the past, there were some failures when trying to index a table with long
            # column names because there is a limit to how long index name is allowed to be,
            # This is why index_name is hashed and truncated to 10 characters.
            index_name = 'ind_' + hashlib.sha256(index_name).hexdigest()[:10]
            self.execute('CREATE INDEX %s ON %s(%s)' % (index_name, table_name, col.name))

    def search_for_unique_cols(self, table=None, table_name=None, depth=2):
        '''Return a list of sets of columns whose combined values are unique in the table.

        Combinations of 1 up to ``depth`` columns are checked with a GROUP BY ... HAVING
        COUNT(*) > 1 query; a combination is unique when that query returns no row.
        Combinations that properly contain an already-unique combination are skipped. If
        ``table`` is not given it is described from ``table_name``.
        '''
        if not table:
            table = self.describe_table(table_name)
        sql_templ = 'SELECT COUNT(*) FROM %s GROUP BY %%s HAVING COUNT(*) > 1' % table.name
        unique_cols = list()
        for current_depth in xrange(1, depth + 1):
            for cols in combinations(table.cols, current_depth):  # redundant combos excluded
                cols = set(cols)
                if any(unique_subset < cols for unique_subset in unique_cols):
                    # cols contains a combo known to be unique
                    continue
                col_names = ', '.join(col.name for col in cols)
                LOG.debug('Checking column combo (%s) for uniqueness', col_names)
                self.execute(sql_templ % col_names)
                if not self.fetchone():
                    LOG.debug('Found unique column combo (%s)', col_names)
                    unique_cols.append(cols)
        return unique_cols

    def persist_unique_col_metadata(self, table):
        '''Store table.unique_cols (as lists of column names) in the shelve file, keyed by
        db type, db name and table name. Does nothing if there are no unique columns.'''
        if not table.unique_cols:
            return
        with closing(shelve.open(self._UNIQUE_COLS_SHELVE_PATH, writeback=True)) as store:
            if self.db_type not in store:
                store[self.db_type] = dict()
            db_type_store = store[self.db_type]
            if self.db_name not in db_type_store:
                db_type_store[self.db_name] = dict()
            db_store = db_type_store[self.db_name]
            db_store[table.name] = [[col.name for col in cols] for cols in table.unique_cols]

    def load_unique_col_metadata(self, table):
        '''Set table.unique_cols from previously persisted metadata, if any exists.

        Raises:
            Exception: if a stored column name no longer exists in ``table``.
        '''
        with closing(shelve.open(self._UNIQUE_COLS_SHELVE_PATH)) as store:
            db_type_store = store.get(self.db_type)
            if not db_type_store:
                return
            db_store = db_type_store.get(self.db_name)
            if not db_store:
                return
            unique_col_names = db_store.get(table.name)
            if not unique_col_names:
                return
            unique_cols = list()
            for entry in unique_col_names:
                col_names = set(entry)
                cols = set((col for col in table.cols if col.name in col_names))
                if len(col_names) != len(cols):
                    raise Exception("Incorrect unique column data for %s" % table.name)
                unique_cols.append(cols)
            table.unique_cols = unique_cols
class DbConnection(object):
    '''Base class for a connection to one database flavour.

    Subclasses implement _connect() (which must set self._conn to a DB API 2 connection)
    and set _DB_TYPE. Attributes not defined here are read from the wrapped connection.
    '''
    __metaclass__ = ABCMeta

    LOCK = Lock()
    PORT = None
    USER_NAME = None
    PASSWORD = None
    _CURSOR_CLASS = DbCursor

    def __init__(self, host_name="localhost", port=None, user_name=None, password=None,
                 db_name=None, log_sql=False):
        '''Connect immediately via _connect().

        Unset port/user_name/password fall back to the class-level PORT/USER_NAME/PASSWORD.
        With ``log_sql`` every executed statement is also written to a per-connection file
        in the temp dir, and a "sql_log_<db type>.sql" symlink is pointed at that file.
        '''
        self._host_name = host_name
        self._port = port or self.PORT
        self._user_name = user_name or self.USER_NAME
        self._password = password or self.PASSWORD
        self.db_name = db_name
        self._conn = None
        self._connect()
        if log_sql:
            import errno
            with DbConnection.LOCK:
                sql_log_path = gettempdir() + '/sql_log_%s_%s.sql' \
                    % (self.db_type.lower(), time())
                self.sql_log = open(sql_log_path, 'w')
                link = gettempdir() + '/sql_log_%s.sql' % self.db_type.lower()
                try:
                    unlink(link)
                except OSError as e:
                    # A missing link is expected the first time; anything else is an error.
                    if e.errno != errno.ENOENT:
                        raise
                symlink(sql_log_path, link)
        else:
            self.sql_log = None

    def _clone(self, db_name, **kwargs):
        return type(self)(host_name=self._host_name, port=self._port,
                          user_name=self._user_name, password=self._password,
                          db_name=db_name, **kwargs)

    def clone(self, db_name):
        '''Return a new connection with the same settings but to ``db_name``.'''
        return self._clone(db_name)

    def __getattr__(self, attr):
        # Only called when normal lookup fails; guard _conn to avoid infinite recursion.
        if attr == "_conn":
            raise AttributeError()
        return getattr(self._conn, attr)

    def __setattr__(self, attr, value):
        # Attributes that the wrapped connection already has are written through to it;
        # everything else, including _conn itself, is stored on this wrapper.
        _conn = getattr(self, "_conn", None)
        if attr == "_conn" or not _conn or not hasattr(_conn, attr):
            object.__setattr__(self, attr, value)
        else:
            setattr(self._conn, attr, value)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close(quiet=True)

    @property
    def db_type(self):
        return self._DB_TYPE

    @abstractmethod
    def _connect(self):
        pass

    def reconnect(self):
        self.close(quiet=True)
        self._connect()

    def close(self, quiet=False):
        '''Close the underlying conn. With ``quiet`` errors are logged instead of raised.'''
        if not self._conn:
            return
        try:
            self._conn.close()
            self._conn = None
        except Exception as e:
            if quiet:
                LOG.debug('Error closing connection: %s', e)
            else:
                raise  # bare raise keeps the original traceback

    def cursor(self):
        return self._CURSOR_CLASS(self, self._conn.cursor())

    @property
    def supports_kill(self):
        return False

    def kill(self):
        '''Kill the current connection and any currently running queries associated with the
        conn.
        '''
        raise Exception('Killing connection is not supported')

    @property
    def supports_index_creation(self):
        return True

    @property
    def data_types_are_implictly_nullable(self):
        return False
class ImpalaCursor(DbCursor):
    '''Cursor with Impala-specific DDL, metadata and statistics helpers.'''

    # Avro schemas longer than this cannot be stored inline as a table property.
    _MAX_INLINE_AVRO_SCHEMA_LEN = 4000

    @classmethod
    def make_insert_sql_from_data(cls, table, rows):
        '''Return one INSERT statement that inserts every row of ``rows`` into ``table``.

        Timestamps are written as plain quoted strings (no TIMESTAMP keyword); single
        quotes inside Char values are doubled.

        Raises:
            Exception: if ``rows`` is empty or ``table`` has no columns.
        '''
        if not rows:
            raise Exception('At least one row is required')
        if not table.cols:
            raise Exception('At least one col is required')
        row_sqls = []
        for row in rows:
            val_sqls = []
            for col_idx, col in enumerate(table.cols):
                val = row[col_idx]
                if val is None:
                    val_sqls.append('NULL')
                elif issubclass(col.type, Timestamp):
                    val_sqls.append("'%s'" % val)
                elif issubclass(col.type, Char):
                    val_sqls.append("'%s'" % val.replace("'", "''"))
                else:
                    val_sqls.append(str(val))
            row_sqls.append('(%s)' % ', '.join(val_sqls))
        return 'INSERT INTO %s VALUES %s' % (table.name, ', '.join(row_sqls))

    @property
    def cluster(self):
        return self.conn.cluster

    def invalidate_metadata(self, table_name=None):
        self.execute("INVALIDATE METADATA %s" % (table_name or ""))

    def drop_db(self, db_name):
        '''This should not be called when connected to the database being dropped.'''
        LOG.info("Dropping database %s", db_name)
        self.execute('DROP DATABASE %s CASCADE' % db_name)

    def compute_stats(self, table_name=None):
        '''Compute stats for ``table_name``, or for every table when none is given.'''
        table_names = [table_name] if table_name else self.list_table_names()
        for name in table_names:
            self.execute("COMPUTE STATS %s" % name)

    def create_table(self, table):
        # The Hive metastore has a limit to the amount of schema it can store inline.
        # Beyond this limit, the schema needs to be stored in HDFS and Hive is given a
        # URL instead.
        if table.storage_format == "AVRO" and not table.schema_location \
                and len(table.get_avro_schema()) > self._MAX_INLINE_AVRO_SCHEMA_LEN:
            self.upload_avro_schema(table)
        super(ImpalaCursor, self).create_table(table)

    def upload_avro_schema(self, table):
        '''Write the table's Avro schema to table.schema_location in HDFS.'''
        self.ensure_schema_location(table)
        avro_schema = table.get_avro_schema()
        self.cluster.hdfs.create_client().write(
            table.schema_location, data=avro_schema, overwrite=True)

    def ensure_schema_location(self, table):
        if not table.schema_location:
            db_name = self.conn.db_name or "default"
            table.schema_location = '%s/%s.db/%s.avsc' % (
                self.cluster.hive.warehouse_dir, db_name, table.name)

    def ensure_storage_location(self, table):
        if not table.storage_location:
            db_name = self.conn.db_name or "default"
            table.storage_location = '%s/%s.db/data/%s' % (
                self.cluster.hive.warehouse_dir, db_name, table.name)

    def make_create_table_sql(self, table):
        '''Extend the base CREATE TABLE with storage format, location and Avro schema.'''
        sql = super(ImpalaCursor, self).make_create_table_sql(table)
        if table.storage_format != 'TEXTFILE':
            sql += "\nSTORED AS " + table.storage_format
        if table.storage_location:
            sql = sql.replace("CREATE TABLE", "CREATE EXTERNAL TABLE")
            sql += "\nLOCATION '%s'" % table.storage_location
        if table.storage_format == 'AVRO':
            if table.schema_location:
                sql += "\nTBLPROPERTIES ('avro.schema.url' = '%s')" % table.schema_location
            else:
                avro_schema = table.get_avro_schema()
                if len(avro_schema) > self._MAX_INLINE_AVRO_SCHEMA_LEN:
                    raise Exception("Avro schema exceeds 4000 character limit. Create a file"
                                    " containing the schema instead and set 'table.schema_location'.")
                sql += "\nTBLPROPERTIES ('avro.schema.literal' = '%s')" % avro_schema
        return sql

    def get_sql_for_data_type(self, data_type):
        if issubclass(data_type, String):
            return 'STRING'
        return super(ImpalaCursor, self).get_sql_for_data_type(data_type)

    def close(self, quiet=False):
        '''Close the current operation and then the cursor.

        The cursor is closed even if closing the operation fails, so it is not leaked.
        With ``quiet`` errors are logged instead of raised.
        '''
        try:
            try:
                # Explicitly close the operation to avoid issues like
                # https://issues.cloudera.org/browse/IMPALA-2562.
                # This can be remove if https://github.com/cloudera/impyla/pull/142 is merged.
                self._cursor.close_operation()
            finally:
                self._cursor.close()
        except Exception as e:
            if quiet:
                LOG.debug('Error closing cursor: %s', e, exc_info=True)
            else:
                raise  # bare raise keeps the original traceback
class ImpalaConnection(DbConnection):
    '''HiveServer2 connection to Impala made through impala.dbapi (impyla).'''

    PORT = 21050  # For HS2
    _DB_TYPE = IMPALA
    _CURSOR_CLASS = ImpalaCursor
    _NON_KERBEROS_AUTH_MECH = 'NOSASL'

    def __init__(self, use_kerberos=False, **kwargs):
        # _use_kerberos must be set before the base constructor calls _connect().
        self._use_kerberos = use_kerberos
        self.cluster = None
        super(ImpalaConnection, self).__init__(**kwargs)

    def clone(self, db_name):
        '''Return a connection to ``db_name`` sharing this connection's cluster.'''
        twin = self._clone(db_name, use_kerberos=self._use_kerberos)
        twin.cluster = self.cluster
        return twin

    @property
    def data_types_are_implictly_nullable(self):
        return True

    @property
    def supports_index_creation(self):
        return False

    def cursor(self):
        impala_cursor = super(ImpalaConnection, self).cursor()
        impala_cursor.arraysize = 1024  # Try to match the default batch size
        return impala_cursor

    def _connect(self):
        if self._use_kerberos:
            auth_mechanism = 'GSSAPI'
        else:
            auth_mechanism = self._NON_KERBEROS_AUTH_MECH
        timeout = 60 * 60
        self._conn = impala.dbapi.connect(
            host=self._host_name,
            port=self._port,
            user=self._user_name,
            password=self._password,
            database=self.db_name,
            timeout=timeout,
            auth_mechanism=auth_mechanism)
class HiveConnection(ImpalaConnection):
    '''HiveServer2 connection to Hive. Reuses the Impala connection logic with Hive's
    type tag, port and non-Kerberos auth mechanism.'''

    _DB_TYPE = HIVE
    _NON_KERBEROS_AUTH_MECH = 'PLAIN'
    PORT = 11050
class PostgresqlCursor(DbCursor):
    '''Cursor with PostgreSQL-specific catalog queries, type names and literals.'''

    # PostgreSQL spellings of type names mapped to the canonical names of TYPES_BY_NAME.
    TYPE_NAME_ALIASES = dict(DbCursor.TYPE_NAME_ALIASES)
    TYPE_NAME_ALIASES.update({
        'INTEGER': 'INT',
        'NUMERIC': 'DECIMAL',
        'REAL': 'FLOAT',
        'DOUBLE PRECISION': 'DOUBLE',
        'CHARACTER': 'CHAR',
        'CHARACTER VARYING': 'VARCHAR',
        'TIMESTAMP WITHOUT TIME ZONE': 'TIMESTAMP'})
    # SQL spelling used when creating columns of these types in PostgreSQL.
    EXACT_TYPES_TO_SQL = dict(DbCursor.EXACT_TYPES_TO_SQL)
    EXACT_TYPES_TO_SQL.update({
        Double: 'DOUBLE PRECISION',
        Float: 'REAL',
        Int: 'INTEGER',
        Timestamp: 'TIMESTAMP WITHOUT TIME ZONE',
        TinyInt: 'SMALLINT'})

    @classmethod
    def make_insert_sql_from_data(cls, table, rows):
        '''Return one INSERT statement that inserts every row of ``rows`` into ``table``.

        Raises:
            Exception: if ``rows`` is empty or ``table`` has no columns.
        '''
        if not rows:
            raise Exception('At least one row is required')
        if not table.cols:
            raise Exception('At least one col is required')
        row_sqls = []
        for row in rows:
            val_sqls = []
            for col_idx, col in enumerate(table.cols):
                val = row[col_idx]
                if val is None:
                    val_sqls.append('NULL')
                elif issubclass(col.type, Timestamp):
                    val_sqls.append("TIMESTAMP '%s'" % val)
                elif issubclass(col.type, Char):
                    # Quotes are doubled and backslashes doubled too.
                    # NOTE(review): doubling backslashes assumes they act as escapes in '...'
                    # literals (standard_conforming_strings off) -- confirm for the server.
                    val = val.replace("'", "''")
                    val = val.replace('\\', '\\\\')
                    val_sqls.append("'%s'" % val)
                else:
                    val_sqls.append(str(val))
            row_sqls.append('(%s)' % ', '.join(val_sqls))
        return 'INSERT INTO %s VALUES %s' % (table.name, ', '.join(row_sqls))

    def make_list_db_names_sql(self):
        return 'SELECT datname FROM pg_database'

    def make_list_table_names_sql(self):
        return '''
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = 'public' '''

    def make_describe_table_sql(self, table_name):
        '''Return SQL yielding (column name, type string) rows in column order.

        Sized types are rendered with their sizes, e.g. "numeric(10, 2)".
        '''
        # When doing a CREATE TABLE AS SELECT... a column may end up with type "Numeric"
        # without a precision. We'll assume that's a DOUBLE. (Concatenating the NULL
        # numeric_precision with || would otherwise make the whole type string NULL.)
        return '''
            SELECT column_name,
                   CASE data_type
                     WHEN 'character' THEN
                       data_type || '(' || character_maximum_length || ')'
                     WHEN 'character varying' THEN
                       data_type || '(' || character_maximum_length || ')'
                     WHEN 'numeric' THEN
                       CASE
                         WHEN numeric_precision IS NULL THEN 'double precision'
                         ELSE data_type
                           || '('
                           || numeric_precision
                           || ', '
                           || numeric_scale
                           || ')'
                       END
                     ELSE data_type
                   END data_type
            FROM information_schema.columns
            WHERE table_name = '%s'
            ORDER BY ordinal_position''' % \
            table_name

    def get_sql_for_data_type(self, data_type):
        if issubclass(data_type, String):
            return 'VARCHAR(%s)' % String.MAX
        return super(PostgresqlCursor, self).get_sql_for_data_type(data_type)
class PostgresqlConnection(DbConnection):
    '''Connection to PostgreSQL via psycopg2, which is pip-installed on demand if missing.'''

    PORT = 5432
    USER_NAME = "postgres"
    _DB_TYPE = POSTGRESQL
    _CURSOR_CLASS = PostgresqlCursor

    @property
    def supports_kill(self):
        return True

    def kill(self):
        '''Cancel the operation currently running on this connection.'''
        self._conn.cancel()

    def _connect(self):
        try:
            import psycopg2
        except ImportError as e:
            if "No module named psycopg2" not in str(e):
                raise  # bare raise keeps the original traceback
            import os
            import subprocess
            from tests.util.shell_util import shell
            LOG.info("psycopg2 module not found; attempting to install it...")
            pip_path = os.path.join(os.environ["IMPALA_HOME"], "infra", "python", "env",
                                    "bin", "pip")
            try:
                shell(pip_path + " install psycopg2", stdout=subprocess.PIPE,
                      stderr=subprocess.STDOUT)
                LOG.info("psycopg2 installation complete.")
            except Exception:
                LOG.error("psycopg2 installation failed. Try installing python-dev and"
                          " libpq-dev then try again.")
                raise
            import psycopg2
        self._conn = psycopg2.connect(
            host=self._host_name,
            port=self._port,
            user=self._user_name,
            password=self._password,
            database=self.db_name)
        self._conn.autocommit = True
class MySQLConnection(DbConnection):
    '''Connection to a MySQL server via the MySQLdb module.'''

    PORT = 3306
    USER_NAME = "root"
    # Required by DbConnection.db_type, which cursors read on every execute().
    _DB_TYPE = MYSQL

    @property
    def supports_kill(self):
        return True

    def kill(self):
        '''Kill this connection's server session, and any query it is running, by issuing
        KILL from a separate connection to the same server.'''
        with self.clone(self.db_name) as killer:
            with killer.cursor() as cursor:
                cursor.execute('KILL %s' % self._session_id)

    def _connect(self):
        try:
            import MySQLdb
        except Exception:
            print('Error importing MySQLdb. Please make sure it is installed. '
                  'See the README for details.')
            raise
        self._conn = MySQLdb.connect(
            host=self._host_name,
            port=self._port,
            user=self._user_name,
            passwd=self._password,
            db=self.db_name)
        # MySQLdb exposes autocommit as a method; assigning to the attribute would only
        # shadow that method and leave autocommit disabled.
        self._conn.autocommit(True)
        # Remember the server-side connection id so kill() can target this session. It
        # is refreshed on every (re)connect.
        with closing(self._conn.cursor()) as cursor:
            cursor.execute('SELECT connection_id()')
            self._session_id = cursor.fetchone()[0]
# NOTE: A second ``class MySQLConnection(DbConnection)`` definition used to follow the
# one above. Being later in the module, it silently replaced the working definition at
# import time. It also called DbConnection.__init__ with an incompatible argument list
# (client, conn, db_name) and relied on execute_and_fetchall and _client, neither of which
# DbConnection defines. It never implemented the abstract _connect(), so under the
# ABCMeta declared on DbConnection it could not be instantiated. It has been removed so
# that MySQLConnection refers to the working definition above.
class MySQLCursor(DbCursor):
    '''Cursor with MySQL-specific table description, indexing and DDL.

    NOTE(review): MySQLConnection as written in this module does not set
    _CURSOR_CLASS = MySQLCursor -- confirm how this cursor is meant to be wired in.
    '''

    # Types whose parenthesised sizes are passed on to parse_data_type().
    _SIZED_TYPE_NAMES = ('DECIMAL', 'NUMERIC', 'CHAR', 'VARCHAR')

    def describe_table(self, table_name):
        '''Return a Table with table and col names always in lowercase.'''
        rows = self.execute_and_fetchall(self.make_describe_table_sql(table_name))
        table = Table(table_name.lower())
        cols = table.cols  # This is a copy
        for row in rows:
            col_name, data_type = row[:2]
            cols.append(Column(table, col_name.lower(), self._parse_mysql_type(data_type)))
        table.cols = cols
        return table

    def _parse_mysql_type(self, data_type):
        '''Map a MySQL DESCRIBE type string such as "int(11)" or "decimal(10,2)" to a type
        class via parse_data_type().'''
        if data_type == 'tinyint(1)':
            # Just assume this is a boolean...
            data_type = 'boolean'
        type_name, _, size_spec = data_type.partition('(')
        type_name = type_name.strip().upper()
        type_name = self.TYPE_NAME_ALIASES.get(type_name, type_name)
        if type_name in self._SIZED_TYPE_NAMES:
            # Drop anything after the closing parenthesis, e.g. " unsigned".
            size_spec = size_spec.split(')', 1)[0]
            type_size = tuple(int(size) for size in size_spec.split(',') if size.strip())
        else:
            # For other types the size is dropped, e.g. the display width in "int(11)".
            type_size = ()
        return self.parse_data_type(type_name, type_size)

    def index_table(self, table_name):
        '''Add a single-column index on every column of ``table_name``.'''
        table = self.describe_table(table_name)
        with self.conn.cursor() as cursor:
            for col in table.cols:
                try:
                    cursor.execute('ALTER TABLE %s ADD INDEX (%s)' % (table_name, col.name))
                except Exception as e:
                    if 'Incorrect index name' not in str(e):
                        raise
                    # Some sort of MySQL bug...
                    LOG.warning('Could not create index on %s.%s: %s', table_name, col.name, e)

    def make_create_table_sql(self, table):
        table_sql = super(MySQLCursor, self).make_create_table_sql(table)
        table_sql += ' ENGINE = MYISAM'
        return table_sql
class OracleCursor(DbCursor):
    '''Cursor with Oracle-specific catalog queries. A "database" is modelled as an Oracle
    user, so creating/dropping a database creates/drops a user.'''

    def make_list_table_names_sql(self):
        return 'SELECT table_name FROM user_tables'

    def drop_db(self, db_name):
        self.execute('DROP USER %s CASCADE' % db_name)

    def drop_db_if_exists(self, db_name):
        '''This should not be called from a conn to the database being dropped.'''
        try:
            self.drop_db(db_name)
        except Exception as e:
            if 'ORA-01918' not in str(e):  # Ignore if the user doesn't exist
                raise

    def create_db(self, db_name):
        self.execute(
            'CREATE USER %s IDENTIFIED BY %s DEFAULT TABLESPACE USERS' % (db_name, db_name))
        self.execute('GRANT ALL PRIVILEGES TO %s' % db_name)

    def make_describe_table_sql(self, table_name):
        '''Return SQL yielding (column name, type string) rows for ``table_name`` in the
        connection's schema, in column order.'''
        # The schema lives on the connection; the cursor itself has no "schema" attribute
        # (looking it up here would be delegated to the raw DB API cursor and fail).
        schema = self.conn.schema
        # Recreate the data types as defined in the model
        return '''
            SELECT
              column_name,
              CASE
                WHEN data_type = 'NUMBER' AND data_scale = 0 THEN
                  data_type || '(' || data_precision || ')'
                WHEN data_type = 'NUMBER' THEN
                  data_type || '(' || data_precision || ', ' || data_scale || ')'
                WHEN data_type IN ('CHAR', 'VARCHAR2') THEN
                  data_type || '(' || data_length || ')'
                WHEN data_type LIKE 'TIMESTAMP%%' THEN
                  'TIMESTAMP'
                ELSE data_type
              END
            FROM all_tab_columns
            WHERE owner = '%s' AND table_name = '%s'
            ORDER BY column_id''' \
            % (schema.upper(), table_name.upper())
class OracleConnection(DbConnection):
    '''Connection to an Oracle database via cx_Oracle.'''

    PORT = 1521
    USER_NAME = 'system'
    PASSWORD = 'oracle'
    # Required by DbConnection.db_type, which cursors read on every execute().
    _DB_TYPE = ORACLE
    _CURSOR_CLASS = OracleCursor

    def __init__(self, service='XE', **kwargs):
        # _service must be set before the base constructor calls _connect().
        self._service = service
        DbConnection.__init__(self, **kwargs)

    def clone(self, db_name):
        return self._clone(db_name, service=self._service)

    @property
    def schema(self):
        return self.db_name or self._user_name

    def _connect(self):
        try:
            import cx_Oracle
        except Exception:
            print('Error importing cx_Oracle. Please make sure it is installed. '
                  'See the README for details.')
            raise
        # The connect string uses named placeholders, so it must be filled from a mapping.
        self._conn = cx_Oracle.connect(
            '%(user)s/%(password)s@%(host)s:%(port)s/%(service)s' % {
                'user': self._user_name,
                'password': self._password,
                'host': self._host_name,
                'port': self._port,
                'service': self._service})
        # Configured here rather than in __init__ so that reconnect() keeps them.
        self._conn.outputtypehandler = OracleTypeConverter.convert_type
        self._conn.autocommit = True

    def cursor(self):
        cursor = super(OracleConnection, self).cursor()
        if self.db_name and self.db_name != self._user_name:
            cursor.execute('ALTER SESSION SET CURRENT_SCHEMA = %s' % self.db_name)
        return cursor
class OracleTypeConverter(object):
    '''Output type handler for cx_Oracle connections.

    Single-character FIXED_CHAR values are converted to booleans, and NUMBER values with a
    non-zero scale are fetched as strings and converted to decimal.Decimal. cx_Oracle is
    imported lazily, on first use.
    '''

    __imported_types = False
    __char_type = None
    __number_type = None

    @classmethod
    def _ensure_types_imported(cls):
        '''Import the cx_Oracle type markers once and cache them on the class.'''
        if cls.__imported_types:
            return
        from cx_Oracle import FIXED_CHAR, NUMBER
        cls.__char_type, cls.__number_type = FIXED_CHAR, NUMBER
        cls.__imported_types = True

    @classmethod
    def convert_type(cls, cursor, name, default_type, size, precision, scale):
        '''Return a cursor variable for the special cases above, otherwise None (which
        leaves cx_Oracle's default handling in place).'''
        cls._ensure_types_imported()
        is_flag_char = default_type == cls.__char_type and size == 1
        if is_flag_char:
            return cursor.var(str, 1, cursor.arraysize, outconverter=cls.convert_boolean)
        is_scaled_number = default_type == cls.__number_type and scale
        if is_scaled_number:
            return cursor.var(str, 100, cursor.arraysize, outconverter=PyDecimal)
        return None

    @classmethod
    def convert_boolean(cls, val):
        '''Map 'T' to True, any other non-empty value to False, and empty/None to None.'''
        return (val == 'T') if val else None