mirror of
https://github.com/apache/impala.git
synced 2025-12-19 18:12:08 -05:00
test_load_table_with_primary_key_attr has been failing intermittently with the following error message since the Python 3 switch in IMPALA-14333: "dbm.error: db type could not be determined". The error happens inside load_unique_col_metadata. However, both load_unique_col_metadata and persist_unique_col_metadata seem to be unused now. While the root cause is still unknown, commenting out the load_unique_col_metadata call in db_connection.py makes the tests pass. Tests are not affected because they do not assert against any unique columns. This patch implements that workaround until we can revisit the issue with the query generator infrastructure. It also fixes a few flake8 warnings in db_connection.py. Testing: Looped and passed test_show_create_table 10 times using downstream jenkins. Change-Id: I2caf6b9cc5a4f06237323cb32bbd47afef576fa1 Reviewed-on: http://gerrit.cloudera.org:8080/23412 Reviewed-by: Impala Public Jenkins <impala-public-jenkins@cloudera.com> Tested-by: Impala Public Jenkins <impala-public-jenkins@cloudera.com>
1258 lines
40 KiB
Python
1258 lines
40 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
'''This module is intended to standardize workflows when working with various databases
|
|
such as Impala, Postgresql, etc. Even with pep-249 (DB API 2), workflows differ
|
|
slightly. For example Postgresql does not allow changing databases from within a
|
|
connection.
|
|
|
|
'''
|
|
from __future__ import absolute_import, division, print_function
|
|
from builtins import filter, map, range, zip
|
|
from future.utils import with_metaclass
|
|
import hashlib
|
|
import impala.dbapi
|
|
import re
|
|
import shelve
|
|
from abc import ABCMeta, abstractmethod
|
|
from contextlib import closing
|
|
from decimal import Decimal as PyDecimal
|
|
from itertools import combinations
|
|
from logging import getLogger
|
|
from os import symlink, unlink
|
|
from pyparsing import (
|
|
alphanums,
|
|
delimitedList,
|
|
Forward,
|
|
Group,
|
|
Literal,
|
|
nums,
|
|
Suppress,
|
|
Word)
|
|
from tempfile import gettempdir
|
|
from textwrap import dedent
|
|
from threading import Lock
|
|
from time import time
|
|
|
|
from tests.comparison.common import (
|
|
ArrayColumn,
|
|
Column,
|
|
MapColumn,
|
|
StructColumn,
|
|
Table,
|
|
TableExprList)
|
|
from tests.comparison.db_types import (
|
|
Char,
|
|
Decimal,
|
|
Double,
|
|
EXACT_TYPES,
|
|
Float,
|
|
get_char_class,
|
|
get_decimal_class,
|
|
get_varchar_class,
|
|
Int,
|
|
String,
|
|
Timestamp,
|
|
TinyInt,
|
|
VarChar)
|
|
|
|
# Module-level logger, named after this module.
LOG = getLogger(__name__)

# Database type identifiers. Connection classes in this module expose one of these
# through their db_type property (for example IMPALA and HIVE).
HIVE = "HIVE"
IMPALA = "IMPALA"
MYSQL = "MYSQL"
ORACLE = "ORACLE"
POSTGRESQL = "POSTGRESQL"
|
|
|
|
|
|
class DbCursor(object):
|
|
'''Wraps a DB API 2 cursor to provide access to the related conn. This class
|
|
implements the DB API 2 interface by delegation.
|
|
|
|
'''
|
|
|
|
@staticmethod
def describe_common_tables(cursors):
  '''Find and return a TableExprList containing Table objects that the given conns
     have in common.

     A table is kept only if every cursor lists it and all cursors describe it with
     the same number of columns, the same primary key names and the same column
     names and types, in the same order. The Table kept is the one described by the
     first cursor. "cursors" is iterated more than once, so it must be a re-iterable
     collection. An empty collection yields an empty TableExprList.
  '''
  tables = TableExprList()

  common_table_names = None
  for cursor in cursors:
    table_names = set(cursor.list_table_names())
    if common_table_names is None:
      common_table_names = table_names
    else:
      common_table_names &= table_names
  if common_table_names is None:
    # No cursors were given, so there are no common tables.
    return tables

  for table_name in sorted(common_table_names):
    common_table = None
    for cursor in cursors:
      table = cursor.describe_table(table_name)
      if common_table is None:
        # The first cursor's description is the reference for all the others.
        common_table = table
        continue
      if not table.cols:
        LOG.debug('%s has no remaining columns', table_name)
        break
      if len(common_table.cols) != len(table.cols):
        LOG.debug('Ignoring table %s.'
            ' It has a different number of columns across databases.', table_name)
        break
      if common_table.primary_key_names != table.primary_key_names:
        LOG.debug('Ignoring table %s because of differing primary keys: %s vs. %s',
            table_name, common_table.primary_key_names, table.primary_key_names)
        break
      differing = [
          (left, right) for left, right in zip(common_table.cols, table.cols)
          if not (left.name == right.name and left.type == right.type)]
      if differing:
        left, right = differing[0]
        LOG.debug('Ignoring table %s. It has different columns %s vs %s.',
            table_name, left, right)
        break
    else:
      # No cursor disagreed with the reference description.
      tables.append(common_table)

  return tables
|
|
|
|
# Matches a SQL type name optionally followed by a parenthesized list of integers.
SQL_TYPE_PATTERN = re.compile(r'([^()]+)(\((\d+,? ?)*\))?')

# Upper-cased type name -> exact type class.
TYPES_BY_NAME = {type_.name().upper(): type_ for type_ in EXACT_TYPES}

# Upper-cased type name -> canonical upper-cased type name. Here every name maps to
# itself.
TYPE_NAME_ALIASES = {
    type_.name().upper(): type_.name().upper() for type_ in EXACT_TYPES}

# Exact type class -> upper-cased name used when generating SQL.
EXACT_TYPES_TO_SQL = {type_: type_.name().upper() for type_ in EXACT_TYPES}
|
|
|
|
@classmethod
def make_insert_sql_from_data(cls, table, rows):
  '''Return a single multi-row INSERT statement for "table".

     Each row is indexed positionally, one value per table column. None becomes
     NULL, Timestamp columns become TIMESTAMP literals, Char-derived columns are
     single-quoted with embedded quotes doubled, and everything else is rendered
     with str(). Raises an Exception if "rows" or the table's columns are empty.
  '''
  if not rows:
    raise Exception('At least one row is required')
  if not table.cols:
    raise Exception('At least one col is required')

  def render(col, val):
    if val is None:
      return 'NULL'
    if issubclass(col.type, Timestamp):
      return "TIMESTAMP '%s'" % val
    if issubclass(col.type, Char):
      return "'%s'" % val.replace("'", "''")
    return str(val)

  row_literals = []
  for row in rows:
    values = [render(col, row[col_idx]) for col_idx, col in enumerate(table.cols)]
    row_literals.append('(' + ', '.join(values) + ')')
  return 'INSERT INTO %s VALUES ' % table.name + ', '.join(row_literals)
|
|
|
|
def __init__(self, conn, cursor):
  '''Keep the owning connection wrapper and the DB API 2 cursor to delegate to.'''
  self._conn = conn
  self._cursor = cursor
|
|
|
|
def __getattr__(self, attr):
  '''Look up attributes that are not found on this wrapper on the wrapped cursor.'''
  return getattr(self._cursor, attr)
|
|
|
|
def __setattr__(self, attr, value):
  '''Set "attr" on this wrapper and, unless it is _conn or _cursor, also on the
     wrapped cursor.
  '''
  # Transfer unknown attributes to the underlying cursor.
  if attr not in ["_conn", "_cursor"]:
    setattr(self._cursor, attr, value)
  # The value is always stored on the wrapper itself as well.
  object.__setattr__(self, attr, value)
|
|
|
|
def __enter__(self):
  '''Enter a "with" block; the cursor itself is the managed object.'''
  return self
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
  '''Leave a "with" block by closing the cursor quietly (quiet=True).'''
  self.close(quiet=True)
|
|
|
|
@property
def db_type(self):
  '''The db_type of the owning connection.'''
  return self._conn.db_type
|
|
|
|
@property
def conn(self):
  '''The connection wrapper this cursor was created with.'''
  return self._conn
|
|
|
|
@property
def db_name(self):
  '''The db_name of the owning connection.'''
  return self._conn.db_name
|
|
|
|
def execute(self, sql, *args, **kwargs):
  '''Run "sql" on the wrapped cursor and return whatever that cursor returns.

     The statement is logged at debug level and, when the connection has a SQL log
     file, appended to it. Extra arguments are passed through to the wrapped
     cursor's execute().
  '''
  # Pass the values as logging arguments so the message is only formatted when
  # debug logging is enabled.
  LOG.debug('%s: %s', self.db_type, sql)
  sql_log = self.conn.sql_log
  if sql_log:
    sql_log.write('\nQuery: %s' % sql)
  return self._cursor.execute(sql, *args, **kwargs)
|
|
|
|
def execute_and_fetchall(self, sql, *args, **kwargs):
  '''Execute "sql" (see execute()) and return the result of fetchall().'''
  self.execute(sql, *args, **kwargs)
  return self.fetchall()
|
|
|
|
def close(self, quiet=False):
  '''Close the wrapped cursor.

     With quiet set, an error raised while closing is logged at debug level (with
     traceback) instead of being propagated.
  '''
  try:
    self._cursor.close()
  except Exception as e:
    if not quiet:
      raise
    LOG.debug('Error closing cursor: %s', e, exc_info=True)
|
|
|
|
def reconnect(self):
  '''Reconnect the owning connection and switch to a cursor from the new conn.'''
  self.conn.reconnect()
  # conn.cursor() returns a wrapper; only its raw cursor is adopted here.
  self._cursor = self.conn.cursor()._cursor
|
|
|
|
def list_db_names(self):
  '''Return a list of database names always in lowercase.'''
  sql = self.make_list_db_names_sql()
  names = []
  for row in self.execute_and_fetchall(sql):
    # The name is expected in the first column of each row.
    names.append(row[0].lower())
  return names
|
|
|
|
def make_list_db_names_sql(self):
  '''Return the SQL used by list_db_names() to list databases.'''
  return 'SHOW DATABASES'
|
|
|
|
def create_db(self, db_name):
  '''Create the database "db_name"; the name is lowercased in the statement.'''
  LOG.info("Creating database %s", db_name)
  self.execute('CREATE DATABASE ' + db_name.lower())
|
|
|
|
def create_db_if_not_exists(self, db_name):
  '''Create the database "db_name" unless it already exists; the name is lowercased
     in the statement.
  '''
  LOG.info("Creating database %s", db_name)
  self.execute('CREATE DATABASE IF NOT EXISTS ' + db_name.lower())
|
|
|
|
def drop_db_if_exists(self, db_name):
  '''This should not be called from a conn to the database being dropped.

     Does nothing if the (lowercased) database is not listed. Raises an Exception
     if this cursor's connection is using that database.
  '''
  target = db_name.lower()
  if target in self.list_db_names():
    current = self.db_name
    if current and current.lower() == target:
      raise Exception('Cannot drop database while still connected to it')
    self.drop_db(target)
|
|
|
|
def drop_db(self, db_name):
  '''Drop the database "db_name"; the name is lowercased in the statement.'''
  LOG.info("Dropping database %s", db_name)
  self.execute('DROP DATABASE ' + db_name.lower())
|
|
|
|
def ensure_empty_db(self, db_name):
  '''Drop "db_name" if it exists, then create it again.'''
  self.drop_db_if_exists(db_name)
  self.create_db(db_name)
|
|
|
|
def list_table_names(self):
  '''Return a list of table names always in lowercase.'''
  sql = self.make_list_table_names_sql()
  names = []
  for row in self.execute_and_fetchall(sql):
    # The name is expected in the first column of each row.
    names.append(row[0].lower())
  return names
|
|
|
|
def make_list_table_names_sql(self):
  '''Return the SQL used by list_table_names() to list tables.'''
  return 'SHOW TABLES'
|
|
|
|
def parse_col_desc(self, data_type):
  ''' Returns a parsed output based on type describe output.
      data_type is string that should look like this:
      "bigint"
      or like this:
      "array<struct<
         field_51:int,
         field_52:bigint,
         field_53:int,
         field_54:boolean
      >>"
      In the first case, this method would return: 'bigint'
      In the second case, it would return
      ['array',
        ['struct',
          ['field_51', 'int'],
          ['field_52', 'bigint'],
          ['field_53', 'int'],
          ['field_54', 'boolean']]]
      This output is used to create the appropriate columns by self.create_column().
  '''

  # Punctuation is matched but left out of the result. Note the naming: LPAR/RPAR
  # are the angle brackets around complex type arguments, LBRA/RBRA are the
  # parentheses around size arguments.
  COMMA, LPAR, RPAR, COLON, LBRA, RBRA = map(Suppress, ",<>:()")

  # Scalar types that take no arguments.
  t_bigint = Literal('bigint')
  t_int = Literal('int')
  t_integer = Literal('integer')
  t_smallint = Literal('smallint')
  t_tinyint = Literal('tinyint')
  t_boolean = Literal('boolean')
  t_string = Literal('string')
  t_timestamp = Literal('timestamp')
  t_timestamp_without_time_zone = Literal('timestamp without time zone')
  t_float = Literal('float')
  t_double = Literal('double')
  t_real = Literal('real')
  t_double_precision = Literal('double precision')

  # Types with size arguments; each is grouped as [name, arg, ...]. varchar is also
  # accepted without a length, in which case it parses to the plain string 'varchar'.
  t_decimal = Group(Literal('decimal') + LBRA + Word(nums) + COMMA + Word(nums) + RBRA)
  t_numeric = Group(Literal('numeric') + LBRA + Word(nums) + COMMA + Word(nums) + RBRA)
  t_char = Group(Literal('char') + LBRA + Word(nums) + RBRA)
  t_character = Group(Literal('character') + LBRA + Word(nums) + RBRA)
  t_varchar = (Group(Literal('varchar') + LBRA + Word(nums) + RBRA) |
               Literal('varchar'))
  t_character_varying = Group(Literal('character varying') + LBRA + Word(nums) + RBRA)

  # Complex types refer to any_type recursively, so they are declared first and
  # defined further down.
  t_struct = Forward()
  t_array = Forward()
  t_map = Forward()

  complex_type = (t_struct | t_array | t_map)

  # NOTE(review): "|" tries the alternatives in the listed order and the input is not
  # required to be consumed completely, so e.g. 'integer' appears to be matched by the
  # earlier 'int' literal. Confirm before reordering.
  any_type = (
    complex_type |
    t_bigint |
    t_int |
    t_integer |
    t_smallint |
    t_tinyint |
    t_boolean |
    t_string |
    t_timestamp |
    t_timestamp_without_time_zone |
    t_float |
    t_double |
    t_real |
    t_double_precision |
    t_decimal |
    t_numeric |
    t_char |
    t_character |
    t_character_varying |
    t_varchar)

  struct_field_name = Word(alphanums + '_')
  struct_field_pair = Group(struct_field_name + COLON + any_type)

  t_struct << Group(Literal('struct') + LPAR + delimitedList(struct_field_pair) + RPAR)
  t_array << Group(Literal('array') + LPAR + any_type + RPAR)
  t_map << Group(Literal('map') + LPAR + any_type + COMMA + any_type + RPAR)

  # Only the first (top-level) parsed item is of interest.
  return any_type.parseString(data_type)[0]
|
|
|
|
def create_column(self, col_name, col_type):
  ''' Takes the output from parse_col_desc and creates the right column type. This
      method returns one of Column, ArrayColumn, MapColumn, StructColumn.

      col_type is either a plain type name string or a nested sequence whose first
      item is the type class name ('array', 'map', 'struct' or a sized type such as
      'decimal') followed by its arguments. Column names are lowercased.

      Raises an Exception naming the column and the type when the type is not
      recognized.
  '''
  name = col_name.lower()

  if isinstance(col_type, str):
    # A bare VARCHAR (no length) is treated as STRING.
    if col_type.upper() == 'VARCHAR':
      col_type = 'STRING'
    type_name = self.TYPE_NAME_ALIASES.get(col_type.upper())
    if type_name not in self.TYPES_BY_NAME:
      # Report unknown names with the same error as every other unparsable type
      # rather than an opaque KeyError.
      raise Exception('unable to parse: {0}, type: {1}'.format(col_name, col_type))
    return Column(owner=None, name=name, exact_type=self.TYPES_BY_NAME[type_name])

  general_class = col_type[0].upper()

  if general_class == 'ARRAY':
    return ArrayColumn(
        owner=None,
        name=name,
        item=self.create_column(col_name='item', col_type=col_type[1]))

  if general_class == 'MAP':
    return MapColumn(
        owner=None,
        name=name,
        key=self.create_column(col_name='key', col_type=col_type[1]),
        value=self.create_column(col_name='value', col_type=col_type[2]))

  if general_class == 'STRUCT':
    struct_col = StructColumn(owner=None, name=name)
    for field_name, field_type in col_type[1:]:
      struct_col.add_col(self.create_column(field_name, field_type))
    return struct_col

  # Sized scalar types: resolve aliases first. An unknown name resolves to '' and
  # reaches the error at the end instead of failing on None.
  sized_class = (self.TYPE_NAME_ALIASES.get(general_class) or '').upper()

  if sized_class == 'DECIMAL':
    return Column(owner=None,
                  name=name,
                  exact_type=get_decimal_class(int(col_type[1]), int(col_type[2])))

  if sized_class == 'CHAR':
    return Column(owner=None,
                  name=name,
                  exact_type=get_char_class(int(col_type[1])))

  if sized_class == 'VARCHAR':
    type_size = int(col_type[1])
    # Lengths beyond VarChar.MAX are represented as STRING.
    if type_size <= VarChar.MAX:
      cur_type = get_varchar_class(type_size)
    else:
      cur_type = self.TYPES_BY_NAME['STRING']
    return Column(owner=None, name=name, exact_type=cur_type)

  raise Exception('unable to parse: {0}, type: {1}'.format(col_name, col_type))
|
|
|
|
def create_table_from_describe(self, table_name, describe_rows):
  '''Build a Table (named with the lowercased table_name) from DESCRIBE output.

     The first two items of each row are taken as the column name and its type
     description. Columns whose name is among the table's primary key names are
     flagged as primary keys.
  '''
  primary_key_names = self._fetch_primary_key_names(table_name)
  table = Table(table_name.lower())
  for col_name, data_type in (row[:2] for row in describe_rows):
    col = self.create_column(col_name, self.parse_col_desc(data_type))
    col.is_primary_key = col_name in primary_key_names
    table.add_col(col)
  return table
|
|
|
|
def describe_table(self, table_name):
  '''Return a Table with table and col names always in lowercase.

     The table is described with the SQL from make_describe_table_sql() and the
     resulting rows are converted by create_table_from_describe().
  '''
  describe_rows = self.execute_and_fetchall(self.make_describe_table_sql(table_name))
  table = self.create_table_from_describe(table_name, describe_rows)
  # IMPALA-14398: Commenting this because, testing with Python 3.8, it hit
  # dbm.error: db type could not be determined
  # Commenting it does not affect any existing infra tests.
  # Need to revisit if persist_unique_col_metadata and load_unique_col_metadata
  # is still needed.
  # self.load_unique_col_metadata(table)
  return table
|
|
|
|
def make_describe_table_sql(self, table_name):
  '''Return the SQL used by describe_table() to describe "table_name".'''
  return 'DESCRIBE ' + table_name
|
|
|
|
def parse_data_type(self, type_name, type_size):
  '''Return the type class for "type_name" with the optional "type_size" arguments.

     DECIMAL/NUMERIC and CHAR build a sized class from type_size. VARCHAR does the
     same when a size is given that does not exceed VarChar.MAX, and is otherwise
     treated as STRING. Any other name is looked up in TYPES_BY_NAME.
  '''
  if type_name == 'VARCHAR':
    fits_varchar = bool(type_size) and type_size[0] <= VarChar.MAX
    if fits_varchar:
      return get_varchar_class(*type_size)
    return self.TYPES_BY_NAME['STRING']
  if type_name == 'CHAR':
    return get_char_class(*type_size)
  if type_name in ('DECIMAL', 'NUMERIC'):
    return get_decimal_class(*type_size)
  return self.TYPES_BY_NAME[type_name]
|
|
|
|
def create_table(self, table):
  '''Create "table" using the SQL from make_create_table_sql().

     Raises an Exception if the table has no columns.
  '''
  LOG.info('Creating table %s', table.name)
  if not table.cols:
    raise Exception('At least one col is required')
  table_sql = self.make_create_table_sql(table)
  LOG.debug(table_sql)
  self.execute(table_sql)
  LOG.debug('Created table %s', table.name)
|
|
|
|
def make_create_table_sql(self, table):
  '''Return the CREATE TABLE statement for "table".

     Primary key columns are declared without a null constraint and listed in a
     trailing PRIMARY KEY clause. Other columns get an explicit " NULL" unless the
     connection reports that data types are implicitly nullable.
  '''
  primary_key_names = [col.name for col in table.cols if col.is_primary_key]

  column_declarations = []
  for col in table.cols:
    if col.is_primary_key or self.conn.data_types_are_implictly_nullable:
      null_constraint = ''
    else:
      null_constraint = ' NULL'
    column_declarations.append('{name} {col_type}{null_constraint}'.format(
        name=col.name, col_type=self.get_sql_for_data_type(col.exact_type),
        null_constraint=null_constraint))

  primary_key_constraint = ''
  if primary_key_names:
    primary_key_constraint = ', PRIMARY KEY ({keys})'.format(
        keys=', '.join(primary_key_names))

  return ('CREATE TABLE {table_name} ('
          '{all_columns}'
          '{primary_key_constraint}'
          ')'.format(
              table_name=table.name, all_columns=', '.join(column_declarations),
              primary_key_constraint=primary_key_constraint))
|
|
|
|
def get_sql_for_data_type(self, data_type):
  '''Return the SQL type text used to declare a column of "data_type".

     VarChar, Char and Decimal classes are rendered with their size attributes;
     all other types are looked up in EXACT_TYPES_TO_SQL.
  '''
  # VarChar is tested before Char; the order of these checks is kept as is.
  if issubclass(data_type, VarChar):
    sql_type = 'VARCHAR(%s)' % data_type.MAX
  elif issubclass(data_type, Char):
    sql_type = 'CHAR(%s)' % data_type.MAX
  elif issubclass(data_type, Decimal):
    sql_type = 'DECIMAL(%s, %s)' % (
        data_type.MAX_DIGITS, data_type.MAX_FRACTIONAL_DIGITS)
  else:
    sql_type = self.EXACT_TYPES_TO_SQL[data_type]
  return sql_type
|
|
|
|
def drop_table(self, table_name, if_exists=True):
  '''Drop the table "table_name"; the name is lowercased in the statement.

     With if_exists set (the default) a missing table is not an error because the
     statement uses IF EXISTS. With if_exists=False a plain DROP TABLE is issued.
  '''
  LOG.info('Dropping table %s', table_name)
  if_exists_clause = 'IF EXISTS ' if if_exists else ''
  self.execute('DROP TABLE ' + if_exists_clause + table_name.lower())
|
|
|
|
def drop_view(self, view_name, if_exists=True):
  '''Drop the view "view_name"; the name is lowercased in the statement.

     With if_exists set (the default) a missing view is not an error because the
     statement uses IF EXISTS. With if_exists=False a plain DROP VIEW is issued.
  '''
  LOG.info('Dropping view %s', view_name)
  if_exists_clause = 'IF EXISTS ' if if_exists else ''
  self.execute('DROP VIEW ' + if_exists_clause + view_name.lower())
|
|
|
|
def index_table(self, table_name):
  '''Create one single-column index for every column of "table_name".'''
  LOG.info('Indexing table %s', table_name)
  table = self.describe_table(table_name)
  for col in table.cols:
    index_name = '%s_%s' % (table_name, col.name)
    if self.db_name:
      index_name = '%s_%s' % (self.db_name, index_name)
    # Older versions of Postgres require index name to be included in the statement.
    # In the past, there were some failures when trying to index a table with long
    # column names because there is a limit to how long index name is allowed to be,
    # This is why index_name is hashed and truncated to 10 characters.
    # hashlib only accepts bytes, so the name is encoded before hashing.
    digest = hashlib.sha256(index_name.encode('utf-8')).hexdigest()
    index_name = 'ind_' + digest[:10]
    self.execute('CREATE INDEX %s ON %s(%s)' % (index_name, table_name, col.name))
|
|
|
|
def search_for_unique_cols(self, table=None, table_name=None, depth=2):
  '''Return a list of column sets whose combined values are unique in the table.

     Combos of 1 up to "depth" columns are checked with a GROUP BY ... HAVING
     COUNT(*) > 1 query; a combo is unique when that query returns no row. Combos
     that strictly contain an already found unique combo are not queried. When
     "table" is not given, "table_name" is described first.
  '''
  table = table or self.describe_table(table_name)
  sql_templ = 'SELECT COUNT(*) FROM %s GROUP BY %%s HAVING COUNT(*) > 1' % table.name
  unique_cols = []
  for combo_size in range(1, depth + 1):
    for combo in combinations(table.cols, combo_size):  # redundant combos excluded
      candidate = set(combo)
      if any(known < candidate for known in unique_cols):
        # candidate contains a combo known to be unique
        continue
      col_names = ', '.join(col.name for col in candidate)
      LOG.debug('Checking column combo (%s) for uniqueness' % col_names)
      self.execute(sql_templ % col_names)
      if self.fetchone():
        continue
      LOG.debug('Found unique column combo (%s)' % col_names)
      unique_cols.append(candidate)
  return unique_cols
|
|
|
|
def persist_unique_col_metadata(self, table):
  '''Save the column names of the table's unique column combos to the shelve file
     /tmp/query_generator.shelve, keyed by db type, then db name, then table name.

     Does nothing when the table has no unique_cols.
  '''
  if not table.unique_cols:
    return
  # The shelf is opened with writeback=True and the nested dicts below are updated
  # in place. NOTE(review): this presumably relies on writeback to persist those
  # in-place changes when the shelf is closed; confirm against the shelve docs.
  with closing(shelve.open('/tmp/query_generator.shelve', writeback=True)) as store:
    if self.db_type not in store:
      store[self.db_type] = dict()
    db_type_store = store[self.db_type]
    if self.db_name not in db_type_store:
      db_type_store[self.db_name] = dict()
    db_store = db_type_store[self.db_name]
    # Only column names are stored, as a list of lists.
    db_store[table.name] = [[col.name for col in cols] for cols in table.unique_cols]
|
|
|
|
def load_unique_col_metadata(self, table):
  '''Set table.unique_cols from the shelve file /tmp/query_generator.shelve.

     The file is read by db type, db name and table name. If nothing is stored at
     any of those levels the table is left untouched. Raises an Exception when a
     stored column name does not match a column of the table.
  '''
  with closing(shelve.open('/tmp/query_generator.shelve')) as store:
    per_db_type = store.get(self.db_type)
    if not per_db_type:
      return
    per_db = per_db_type.get(self.db_name)
    if not per_db:
      return
    stored_combos = per_db.get(table.name)
    if not stored_combos:
      return
    loaded = []
    for stored_names in stored_combos:
      wanted = set(stored_names)
      matching = {col for col in table.cols if col.name in wanted}
      if len(wanted) != len(matching):
        raise Exception("Incorrect unique column data for %s" % table.name)
      loaded.append(matching)
    table.unique_cols = loaded
|
|
|
|
def _fetch_primary_key_names(self, table_name):
  """
  This must return a tuple of strings representing the primary keys of table_name,
  or an empty tuple if there are no primary keys.

  The base implementation does not query the database; table_name is unused.
  """
  # This is the base method. Since we haven't tested this on Oracle or Mysql or plan
  # to implement this for those databases, the base method needs to return an empty
  # tuple.
  return ()
|
|
|
|
|
|
class DbConnection(with_metaclass(ABCMeta, object)):
  '''Abstract wrapper around a DB API 2 connection.

     Subclasses implement _connect() to populate self._conn and are expected to
     define _DB_TYPE, which the db_type property returns. Attributes that are not
     found on this object are looked up on the wrapped connection.
  '''

  # Held while the SQL log file and its symlink are being set up.
  LOCK = Lock()

  # Fallbacks used when the matching constructor argument is not given (falsy).
  PORT = None
  USER_NAME = None
  PASSWORD = None

  # Class used by cursor() to wrap raw cursors.
  _CURSOR_CLASS = DbCursor

  def __init__(self, host_name="localhost", port=None, user_name=None, password=None,
      db_name=None, log_sql=False):
    '''Store the connection settings and connect right away through _connect().

       If log_sql is true, a file named sql_log_<db type>_<timestamp>.sql is opened
       for writing in the temp dir as self.sql_log, and the symlink
       sql_log_<db type>.sql in the same dir is re-pointed at it. Otherwise
       self.sql_log is None.
    '''
    self._host_name = host_name
    self._port = port or self.PORT
    self._user_name = user_name or self.USER_NAME
    self._password = password or self.PASSWORD
    self.db_name = db_name
    self._conn = None
    self._connect()

    if log_sql:
      import errno
      with DbConnection.LOCK:
        sql_log_path = gettempdir() + '/sql_log_%s_%s.sql' \
            % (self.db_type.lower(), time())
        self.sql_log = open(sql_log_path, 'w')
        link = gettempdir() + '/sql_log_%s.sql' % self.db_type.lower()
        try:
          unlink(link)
        except OSError as e:
          # Only a missing link is expected. Check the error code rather than the
          # message text, which depends on the platform and locale.
          if e.errno != errno.ENOENT:
            raise
        symlink(sql_log_path, link)
    else:
      self.sql_log = None

  def _clone(self, db_name, **kwargs):
    '''Return a new connection of the same class with the same host, port and
       credentials but for "db_name". Extra kwargs go to the constructor.
    '''
    return type(self)(host_name=self._host_name, port=self._port,
        user_name=self._user_name, password=self._password, db_name=db_name, **kwargs)

  def clone(self, db_name):
    '''Return a new connection like this one but for "db_name".'''
    return self._clone(db_name)

  def __getattr__(self, attr):
    # "_conn" can be missing before __init__ assigns it. Raising here lets
    # getattr(self, "_conn", None) use its default instead of recursing.
    if attr == "_conn":
      raise AttributeError()
    return getattr(self._conn, attr)

  def __setattr__(self, attr, value):
    # Attributes that already exist on the wrapped conn are set there; anything
    # else (or everything, while there is no conn) is set on this wrapper.
    _conn = getattr(self, "_conn", None)
    if not _conn or not hasattr(_conn, attr):
      object.__setattr__(self, attr, value)
    else:
      setattr(self._conn, attr, value)

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    self.close(quiet=True)

  @property
  def db_type(self):
    '''The database type identifier defined by the subclass as _DB_TYPE.'''
    return self._DB_TYPE

  @abstractmethod
  def _connect(self):
    '''Open the underlying connection; called by __init__() and reconnect().'''
    pass

  def reconnect(self):
    '''Quietly close the current conn, then connect again.'''
    self.close(quiet=True)
    self._connect()

  def close(self, quiet=False):
    '''Close the underlying conn.

       Does nothing when there is no conn. With quiet set, an error raised while
       closing is logged at debug level instead of being propagated; the conn
       reference is then left in place.
    '''
    if not self._conn:
      return
    try:
      self._conn.close()
      self._conn = None
    except Exception as e:
      if not quiet:
        raise
      LOG.debug('Error closing connection: %s', e)

  def cursor(self):
    '''Return a new raw cursor wrapped in _CURSOR_CLASS.'''
    return self._CURSOR_CLASS(self, self._conn.cursor())

  @property
  def supports_kill(self):
    return False

  def kill(self):
    '''Kill the current connection and any currently running queries associated with the
       conn.
    '''
    raise Exception('Killing connection is not supported')

  @property
  def supports_index_creation(self):
    return True

  @property
  def data_types_are_implictly_nullable(self):
    return False
|
|
|
|
|
|
class ImpalaCursor(DbCursor):
  '''DbCursor with Impala specific SQL generation and statements.'''

  # Extracts the comma separated key list of a PRIMARY KEY (...) clause.
  PK_SEARCH_PATTERN = re.compile(r'PRIMARY KEY \((?P<keys>.*?)\)')
  # Storage formats for which tables are created with primary keys.
  STORAGE_FORMATS_WITH_PRIMARY_KEYS = ('KUDU',)

  @classmethod
  def make_insert_sql_from_data(cls, table, rows):
    '''Return a single multi-row INSERT statement for "table".

       Each row is indexed positionally, one value per table column. None becomes
       NULL, Timestamp and Char-derived columns are single-quoted (Char values with
       embedded quotes doubled) and everything else is rendered with str(). Raises
       an Exception if "rows" or the table's columns are empty.
    '''
    if not rows:
      raise Exception('At least one row is required')
    if not table.cols:
      raise Exception('At least one col is required')

    def to_literal(col, val):
      if val is None:
        return 'NULL'
      if issubclass(col.type, Timestamp):
        return "'%s'" % val
      if issubclass(col.type, Char):
        return "'%s'" % val.replace("'", "''")
      return str(val)

    row_literals = []
    for row in rows:
      literals = [to_literal(col, row[idx]) for idx, col in enumerate(table.cols)]
      row_literals.append('(' + ', '.join(literals) + ')')
    return 'INSERT INTO %s VALUES ' % table.name + ', '.join(row_literals)

  @property
  def cluster(self):
    '''The cluster object of the owning connection.'''
    return self.conn.cluster

  def _fetch_primary_key_names(self, table_name):
    '''Return the primary key column names of "table_name" as a tuple, or an empty
       tuple if its SHOW CREATE TABLE output has no PRIMARY KEY clause.
    '''
    self.execute("SHOW CREATE TABLE {0}".format(table_name))
    # This returns 1 column with 1 multiline string row, resembling:
    #
    # CREATE TABLE db.table (
    #   pk1 BIGINT,
    #   pk2 BIGINT,
    #   col BIGINT,
    #   PRIMARY KEY (pk1, pk2)
    # )
    #
    # Even a 1-column primary key will be shown as a PRIMARY KEY constraint, like:
    #
    # CREATE TABLE db.table (
    #   pk1 BIGINT,
    #   col BIGINT,
    #   PRIMARY KEY (pk1)
    # )
    [create_statement] = self.fetchone()
    match = ImpalaCursor.PK_SEARCH_PATTERN.search(create_statement)
    if match is not None:
      return tuple(match.group("keys").split(", "))
    return ()

  def invalidate_metadata(self, table_name=None):
    '''Invalidate the metadata of "table_name", or of everything if not given.'''
    self.execute("INVALIDATE METADATA %s" % (table_name or ""))

  def drop_db(self, db_name):
    '''This should not be called when connected to the database being dropped.'''
    LOG.info("Dropping database %s", db_name)
    self.execute('DROP DATABASE %s CASCADE' % db_name)

  def compute_stats(self, table_name=None):
    '''Compute stats for "table_name", or for every listed table if not given.'''
    targets = [table_name] if table_name else self.list_table_names()
    for target in targets:
      self.execute("COMPUTE STATS %s" % target)

  def create_table(self, table):
    '''Create "table", uploading its Avro schema first when it is too large to be
       stored inline.
    '''
    # The Hive metastore has a limit to the amount of schema it can store inline.
    # Beyond this limit, the schema needs to be stored in HDFS and Hive is given a
    # URL instead.
    needs_schema_upload = (
        table.storage_format == "AVRO"
        and not table.schema_location
        and len(table.get_avro_schema()) > 4000)
    if needs_schema_upload:
      self.upload_avro_schema(table)
    super(ImpalaCursor, self).create_table(table)

  def upload_avro_schema(self, table):
    '''Write the table's Avro schema to its schema location through the cluster's
       HDFS client, overwriting any existing file.
    '''
    self.ensure_schema_location(table)
    avro_schema = table.get_avro_schema()
    hdfs_client = self.cluster.hdfs.create_client()
    hdfs_client.write(table.schema_location, data=avro_schema, overwrite=True)

  def ensure_schema_location(self, table):
    '''Give the table a default schema location under the Hive warehouse dir if it
       has none.
    '''
    if table.schema_location:
      return
    db_name = self.conn.db_name or "default"
    table.schema_location = '%s/%s.db/%s.avsc' % (
        self.cluster.hive.warehouse_dir, db_name, table.name)

  def ensure_storage_location(self, table):
    '''Give the table a default storage location under the Hive warehouse dir if it
       has none.
    '''
    if table.storage_location:
      return
    db_name = self.conn.db_name or "default"
    table.storage_location = '%s/%s.db/data/%s' % (
        self.cluster.hive.warehouse_dir, db_name, table.name)

  def make_create_table_sql(self, table):
    '''Return the CREATE TABLE statement for "table", extended with the partitioning,
       storage format, location and Avro schema clauses that apply.

       Raises an Exception when primary keys and storage format do not go together,
       or when an inline Avro schema exceeds 4000 characters.
    '''
    sql = super(ImpalaCursor, self).make_create_table_sql(table)

    format_has_primary_keys = \
        table.storage_format in ImpalaCursor.STORAGE_FORMATS_WITH_PRIMARY_KEYS
    has_primary_keys = bool(table.primary_keys)

    if has_primary_keys and not format_has_primary_keys:
      raise Exception(
          'table representation has primary keys {keys} but is not in a format that '
          'supports them: {storage_format}'.format(
              keys=str(table.primary_key_names),
              storage_format=table.storage_format))
    if format_has_primary_keys and not has_primary_keys:
      raise Exception(
          'table representation has storage format {storage_format} '
          'but does not have any primary keys'.format(
              storage_format=table.storage_format))
    if has_primary_keys:
      # IMPALA-4424 adds support for parametrizing the partitions; for now, on our
      # small scale, this is ok, especially since the model is to migrate tables from
      # Impala into Postgres anyway. 3 was chosen for the buckets because our
      # minicluster tends to have 3 tablet servers, but otherwise it's arbitrary and
      # provides valid syntax for creating Kudu tables in Impala.
      sql += '\nPARTITION BY HASH ({col}) PARTITIONS 3'.format(
          col=table.primary_key_names[0])

    if table.storage_format != 'TEXTFILE':
      sql += "\nSTORED AS " + table.storage_format
    if table.storage_location:
      sql = sql.replace("CREATE TABLE", "CREATE EXTERNAL TABLE")
      sql += "\nLOCATION '%s'" % table.storage_location
    if table.storage_format != 'AVRO':
      return sql

    if table.schema_location:
      return sql + "\nTBLPROPERTIES ('avro.schema.url' = '%s')" % table.schema_location
    avro_schema = table.get_avro_schema()
    if len(avro_schema) > 4000:
      raise Exception("Avro schema exceeds 4000 character limit. Create a file"
          " containing the schema instead and set 'table.schema_location'.")
    return sql + "\nTBLPROPERTIES ('avro.schema.literal' = '%s')" % avro_schema

  def get_sql_for_data_type(self, data_type):
    '''Return the SQL type text for "data_type"; String types are declared STRING.'''
    if not issubclass(data_type, String):
      return super(ImpalaCursor, self).get_sql_for_data_type(data_type)
    return 'STRING'

  def close(self, quiet=False):
    '''Close the current operation and the wrapped cursor.

       With quiet set, an error raised while closing is logged at debug level (with
       traceback) instead of being propagated.
    '''
    try:
      # Explicitly close the operation to avoid issues like IMPALA-2562.
      # This can be remove if https://github.com/cloudera/impyla/pull/142 is merged.
      self._cursor.close_operation()
      self._cursor.close()
    except Exception as e:
      if not quiet:
        raise
      LOG.debug('Error closing cursor: %s', e, exc_info=True)
|
|
|
|
|
|
class ImpalaConnection(DbConnection):
  '''DbConnection to Impala, made with impala.dbapi.connect().'''

  PORT = 21050  # For HS2

  _DB_TYPE = IMPALA
  _CURSOR_CLASS = ImpalaCursor
  _KERBEROS_SERVICE_NAME = 'impala'
  # Auth mechanism used when Kerberos is not enabled.
  _NON_KERBEROS_AUTH_MECH = 'NOSASL'

  def __init__(self, use_kerberos=False, use_ssl=False, ca_cert=None, **kwargs):
    '''Store the security settings, then connect through DbConnection.__init__().

       The remaining keyword arguments are those of DbConnection.__init__().
    '''
    # These must be set before the base constructor runs, because it calls
    # _connect(), which reads them.
    self._use_kerberos = use_kerberos
    self.cluster = None
    self._use_ssl = use_ssl
    self._ca_cert = ca_cert
    DbConnection.__init__(self, **kwargs)

  def clone(self, db_name):
    '''Return a new connection for "db_name" with the same host, credentials,
       security settings (Kerberos, SSL, CA cert) and cluster as this one.
    '''
    # Pass every security setting along so the clone connects the same way.
    clone = self._clone(db_name, use_kerberos=self._use_kerberos,
        use_ssl=self._use_ssl, ca_cert=self._ca_cert)
    clone.cluster = self.cluster
    return clone

  @property
  def data_types_are_implictly_nullable(self):
    return True

  @property
  def supports_index_creation(self):
    return False

  def cursor(self):
    '''Return a new cursor with arraysize set to 1024.'''
    cursor = super(ImpalaConnection, self).cursor()
    cursor.arraysize = 1024  # Try to match the default batch size
    return cursor

  def _connect(self):
    '''Open the underlying connection with the stored settings and a timeout value
       of 60 * 60.
    '''
    self._conn = impala.dbapi.connect(
        host=self._host_name,
        port=self._port,
        user=self._user_name,
        password=self._password,
        database=self.db_name,
        timeout=(60 * 60),
        auth_mechanism=('GSSAPI' if self._use_kerberos else self._NON_KERBEROS_AUTH_MECH),
        kerberos_service_name=self._KERBEROS_SERVICE_NAME,
        use_ssl=self._use_ssl,
        ca_cert=self._ca_cert)
|
|
|
|
|
|
class HiveCursor(ImpalaCursor):
  '''ImpalaCursor variant that issues Hive's statements for stats and metadata.'''

  def invalidate_metadata(self, table_name=None):
    '''Do nothing; the argument is accepted for interface compatibility.'''
    # There is no equivalent of "INVALIDATE METADATA" in Hive
    pass

  def compute_stats(self, table_name=None):
    '''Compute table and column statistics for "table_name", or for every listed
       table if not given.
    '''
    targets = [table_name] if table_name else self.list_table_names()
    for target in targets:
      self.execute("ANALYZE TABLE %s COMPUTE STATISTICS" % target)
      self.execute("ANALYZE TABLE %s COMPUTE STATISTICS FOR COLUMNS" % target)
|
|
|
|
|
|
class HiveConnection(ImpalaConnection):
  '''ImpalaConnection variant for Hive: it only overrides the default port, the db
     type, the cursor class, the Kerberos service name and the non-Kerberos auth
     mechanism.
  '''

  PORT = 11050

  _DB_TYPE = HIVE
  _CURSOR_CLASS = HiveCursor
  _KERBEROS_SERVICE_NAME = 'hive'
  # Auth mechanism used when Kerberos is not enabled.
  _NON_KERBEROS_AUTH_MECH = 'PLAIN'
|
|
|
|
|
|
class PostgresqlCursor(DbCursor):
|
|
|
|
# Copy of the DbCursor aliases extended with additional type spellings, each mapped
# to the type name used in DbCursor's tables.
TYPE_NAME_ALIASES = dict(DbCursor.TYPE_NAME_ALIASES)
TYPE_NAME_ALIASES.update({
    'INTEGER': 'INT',
    'NUMERIC': 'DECIMAL',
    'REAL': 'FLOAT',
    'DOUBLE PRECISION': 'DOUBLE',
    'CHARACTER': 'CHAR',
    'CHARACTER VARYING': 'VARCHAR',
    'TIMESTAMP WITHOUT TIME ZONE': 'TIMESTAMP'})

# Copy of the DbCursor type-to-SQL table with the entries that are declared
# differently by this cursor. Note that TinyInt is declared as SMALLINT.
EXACT_TYPES_TO_SQL = dict(DbCursor.EXACT_TYPES_TO_SQL)
EXACT_TYPES_TO_SQL.update({
    Double: 'DOUBLE PRECISION',
    Float: 'REAL',
    Int: 'INTEGER',
    Timestamp: 'TIMESTAMP WITHOUT TIME ZONE',
    TinyInt: 'SMALLINT'})
|
|
|
|
@classmethod
|
|
def make_insert_sql_from_data(cls, table, rows):
|
|
if not rows:
|
|
raise Exception('At least one row is required')
|
|
if not table.cols:
|
|
raise Exception('At least one col is required')
|
|
|
|
sql = 'INSERT INTO %s VALUES ' % table.name
|
|
for row_idx, row in enumerate(rows):
|
|
if row_idx > 0:
|
|
sql += ', '
|
|
sql += '('
|
|
for col_idx, col in enumerate(table.cols):
|
|
if col_idx > 0:
|
|
sql += ', '
|
|
val = row[col_idx]
|
|
if val is None:
|
|
sql += 'NULL'
|
|
elif issubclass(col.type, Timestamp):
|
|
sql += "TIMESTAMP '%s'" % val
|
|
elif issubclass(col.type, Char):
|
|
val = val.replace("'", "''")
|
|
val = val.replace('\\', '\\\\')
|
|
sql += "'%s'" % val
|
|
else:
|
|
sql += str(val)
|
|
sql += ')'
|
|
|
|
return sql
|
|
|
|
def make_list_db_names_sql(self):
|
|
return 'SELECT datname FROM pg_database'
|
|
|
|
def make_list_table_names_sql(self):
|
|
return '''
|
|
SELECT table_name
|
|
FROM information_schema.tables
|
|
WHERE table_schema = 'public' '''
|
|
|
|
def make_describe_table_sql(self, table_name):
|
|
# When doing a CREATE TABLE AS SELECT... a column may end up with type "Numeric".
|
|
# We'll assume that's a DOUBLE.
|
|
return '''
|
|
SELECT column_name,
|
|
CASE data_type
|
|
WHEN 'character' THEN
|
|
data_type || '(' || character_maximum_length || ')'
|
|
WHEN 'character varying' THEN
|
|
data_type || '(' || character_maximum_length || ')'
|
|
WHEN 'numeric' THEN
|
|
data_type
|
|
|| '('
|
|
|| numeric_precision
|
|
|| ', '
|
|
|| numeric_scale
|
|
|| ')'
|
|
ELSE data_type
|
|
END data_type
|
|
FROM information_schema.columns
|
|
WHERE table_name = '%s'
|
|
ORDER BY ordinal_position''' % \
|
|
table_name
|
|
|
|
def get_sql_for_data_type(self, data_type):
|
|
if issubclass(data_type, String):
|
|
return 'VARCHAR(%s)' % String.MAX
|
|
return super(PostgresqlCursor, self).get_sql_for_data_type(data_type)
|
|
|
|
def _fetch_primary_key_names(self, table_name):
|
|
# see:
|
|
# https://www.postgresql.org/docs/9.5/static/infoschema-key-column-usage.html
|
|
# https://www.postgresql.org/docs/9.5/static/infoschema-table-constraints.html
|
|
sql = dedent('''
|
|
SELECT
|
|
key_cols.column_name AS column_name
|
|
FROM
|
|
information_schema.key_column_usage key_cols,
|
|
information_schema.table_constraints table_constraints
|
|
WHERE
|
|
key_cols.constraint_catalog = table_constraints.constraint_catalog AND
|
|
key_cols.table_name = table_constraints.table_name AND
|
|
key_cols.constraint_name = table_constraints.constraint_name AND
|
|
table_constraints.constraint_type = 'PRIMARY KEY' AND
|
|
key_cols.table_name = '{table_name}'
|
|
ORDER BY key_cols.ordinal_position'''.format(table_name=table_name))
|
|
self.execute(sql)
|
|
rows = self.fetchall()
|
|
return tuple(row[0] for row in rows)
|
|
|
|
|
|
class PostgresqlConnection(DbConnection):
  '''DbConnection backed by psycopg2.'''

  PORT = 5432
  USER_NAME = "postgres"

  _DB_TYPE = POSTGRESQL
  _CURSOR_CLASS = PostgresqlCursor

  @property
  def supports_kill(self):
    '''Always True; see kill().'''
    return True

  def kill(self):
    '''Call cancel() on the underlying psycopg2 connection.'''
    self._conn.cancel()

  @staticmethod
  def _is_missing_psycopg2(error):
    '''Return True if the ImportError 'error' says psycopg2 itself is not installed.

    Python 3 quotes the module name in the message ("No module named 'psycopg2'")
    while Python 2 did not, so quotes are stripped before matching the message.
    '''
    if getattr(error, 'name', None) == 'psycopg2':
      return True
    return "No module named psycopg2" in str(error).replace("'", "")

  def _connect(self):
    '''Connect with psycopg2, store the connection in self._conn and enable
    autocommit on it.

    If psycopg2 is not installed, an attempt is made to install it with the pip
    found under $IMPALA_HOME/infra/python/env/bin before importing it again. Any
    other ImportError, and any failure of the installation, is re-raised.
    '''
    try:
      import psycopg2
    except ImportError as e:
      if not self._is_missing_psycopg2(e):
        raise
      import os
      import subprocess
      from tests.util.shell_util import shell
      LOG.info("psycopg2 module not found; attempting to install it...")
      pip_path = os.path.join(os.environ["IMPALA_HOME"], "infra", "python", "env",
                              "bin", "pip")
      try:
        shell(pip_path + " install psycopg2", stdout=subprocess.PIPE,
              stderr=subprocess.STDOUT)
        LOG.info("psycopg2 installation complete.")
      except Exception:
        LOG.error("psycopg2 installation failed. Try installing python-dev and"
                  " libpq-dev then try again.")
        raise
      import psycopg2
    self._conn = psycopg2.connect(
        host=self._host_name,
        port=self._port,
        user=self._user_name,
        password=self._password,
        database=self.db_name)
    self._conn.autocommit = True
|
|
|
|
|
|
class MySQLConnection(DbConnection):
  '''DbConnection backed by MySQLdb.'''

  PORT = 3306
  USER_NAME = "root"

  def __init__(self, client, conn, db_name=None):
    '''Initialize the base connection, then record this session's connection id so
    that kill_connection() can target it later.
    '''
    DbConnection.__init__(self, client, conn, db_name=db_name)
    self._session_id = self.execute_and_fetchall('SELECT connection_id()')[0][0]

  def _connect(self):
    '''Connect with MySQLdb, store the connection in self._conn and turn on
    autocommit. An import failure is reported on stdout and re-raised.
    '''
    try:
      import MySQLdb
    except Exception:
      print('Error importing MySQLdb. Please make sure it is installed. '
            'See the README for details.')
      raise
    self._conn = MySQLdb.connect(
        host=self._host_name,
        port=self._port,
        user=self._user_name,
        passwd=self._password,
        db=self.db_name)
    # On MySQLdb connections autocommit is a method. Assigning True to the name
    # would only shadow that method on the instance and leave autocommit off.
    self._conn.autocommit(True)

  @property
  def supports_kill_connection(self):
    '''Always True; see kill_connection().'''
    return True

  def kill_connection(self):
    '''Issue KILL for this connection's session id from a separate connection
    opened through the client.
    '''
    with self._client.open_connection(db_name=self.db_name) as conn:
      conn.execute('KILL %s' % (self._session_id))
|
|
|
|
|
|
class MySQLCursor(DbCursor):
  '''DbCursor with MySQL-specific table description, indexing and DDL.'''

  def describe_table(self, table_name):
    '''Return a Table with table and col names always in lowercase.'''
    rows = self.conn.execute_and_fetchall(
        self.make_describe_table_sql(table_name))
    table = Table(table_name.lower())
    cols = table.cols   # This is a copy
    for row in rows:
      col_name, data_type = row[:2]
      if data_type == 'tinyint(1)':
        # Just assume this is a boolean...
        data_type = 'boolean'
      if 'decimal' not in data_type and '(' in data_type:
        # Strip the size of the data type
        data_type = data_type[:data_type.index('(')]
      cols.append(Column(table, col_name.lower(), self.parse_data_type(data_type)))
    table.cols = cols
    return table

  def index_table(self, table_name):
    '''Add a single-column index for every column of 'table_name'.

    A failure whose message contains 'Incorrect index name' is logged as a warning
    and skipped; any other failure is re-raised.
    '''
    table = self.describe_table(table_name)
    with self.conn.open_cursor() as cursor:
      for col in table.cols:
        try:
          cursor.execute('ALTER TABLE %s ADD INDEX (%s)' % (table_name, col.name))
        except Exception as e:
          if 'Incorrect index name' not in str(e):
            raise
          # Some sort of MySQL bug...
          LOG.warning('Could not create index on %s.%s: %s', table_name, col.name, e)

  def make_create_table_sql(self, table):
    '''Return the inherited CREATE TABLE statement with ' ENGINE = MYISAM' appended.'''
    # super() must name this class. Naming MySQLConnection here raised TypeError
    # because a cursor instance is not an instance of the connection class.
    table_sql = super(MySQLCursor, self).make_create_table_sql(table)
    table_sql += ' ENGINE = MYISAM'
    return table_sql
|
|
|
|
|
|
class OracleCursor(DbCursor):
  '''DbCursor for Oracle, where each "database" is handled as a user/schema.'''

  def make_list_table_names_sql(self):
    '''Return the query that lists table names from user_tables.'''
    return 'SELECT table_name FROM user_tables'

  def drop_db(self, db_name):
    '''Drop the user named 'db_name' with CASCADE.'''
    statement = 'DROP USER %s CASCADE' % db_name
    self.execute(statement)

  def drop_db_if_exists(self, db_name):
    '''This should not be called from a conn to the database being dropped.'''
    try:
      self.drop_db(db_name)
    except Exception as e:
      # Ignore if the user doesn't exist
      if 'ORA-01918' in str(e):
        return
      raise

  def create_db(self, db_name):
    '''Create a user named 'db_name', with the same string as its password, and
    grant it all privileges.
    '''
    create_user_sql = \
        'CREATE USER %s IDENTIFIED BY %s DEFAULT TABLESPACE USERS' % (db_name, db_name)
    self.execute(create_user_sql)
    self.execute('GRANT ALL PRIVILEGES TO %s' % db_name)

  def make_describe_table_sql(self, table_name):
    '''Return a query yielding each column's name and type, ordered by column_id.

    The owner and table name are upper-cased before being placed in the query.
    '''
    # Recreate the data types as defined in the model
    query = '''
        SELECT
          column_name,
          CASE
            WHEN data_type = 'NUMBER' AND data_scale = 0 THEN
              data_type || '(' || data_precision || ')'
            WHEN data_type = 'NUMBER' THEN
              data_type || '(' || data_precision || ', ' || data_scale || ')'
            WHEN data_type IN ('CHAR', 'VARCHAR2') THEN
              data_type || '(' || data_length || ')'
            WHEN data_type LIKE 'TIMESTAMP%%' THEN
              'TIMESTAMP'
            ELSE data_type
          END
        FROM all_tab_columns
        WHERE owner = '%s' AND table_name = '%s'
        ORDER BY column_id'''
    # NOTE(review): 'schema' is not defined on this class; presumably it is provided
    # by the base cursor or delegated to the connection -- confirm.
    owner = self.schema.upper()
    return query % (owner, table_name.upper())
|
|
|
|
|
|
class OracleConnection(DbConnection):
  '''DbConnection backed by cx_Oracle.'''

  PORT = 1521
  USER_NAME = 'system'
  PASSWORD = 'oracle'

  _CURSOR_CLASS = OracleCursor

  def __init__(self, service='XE', **kwargs):
    '''Remember the service name, run the base initializer with 'kwargs', then
    install the output type handler and enable autocommit on the connection.
    '''
    # Set before the base initializer runs; presumably that is what ends up calling
    # _connect(), which reads the service name -- confirm.
    self._service = service
    DbConnection.__init__(self, **kwargs)
    oracle_conn = self._conn
    oracle_conn.outputtypehandler = OracleTypeConverter.convert_type
    oracle_conn.autocommit = True

  def clone(self, db_name):
    '''Return the result of _clone() for 'db_name', passing along the service.'''
    return self._clone(db_name, service=self._service)

  @property
  def schema(self):
    '''The db name if one is set, otherwise the user name.'''
    if self.db_name:
      return self.db_name
    return self._user_name

  def _connect(self):
    '''Connect with cx_Oracle using a user/password@host:port/service string and
    store the connection in self._conn. An import failure is reported on stdout and
    re-raised.
    '''
    try:
      import cx_Oracle
    except Exception:
      print('Error importing cx_Oracle. Please make sure it is installed. '
            'See the README for details.')
      raise
    connect_string = '{}/{}@{}:{}/{}'.format(
        self._user_name, self._password, self._host_name, self._port, self._service)
    self._conn = cx_Oracle.connect(connect_string)

  def cursor(self):
    '''Return a cursor from the parent class, first switching the session's current
    schema to the db name when one is set and it differs from the user name.
    '''
    new_cursor = super(OracleConnection, self).cursor()
    needs_schema_switch = bool(self.db_name) and self.db_name != self._user_name
    if needs_schema_switch:
      new_cursor.execute('ALTER SESSION SET CURRENT_SCHEMA = %s' % self.db_name)
    return new_cursor
|
|
|
|
|
|
class OracleTypeConverter(object):
  '''Callbacks that adjust how values fetched through cx_Oracle are converted.'''

  # The cx_Oracle type objects are looked up on the first convert_type() call so
  # that cx_Oracle is only imported when it is actually needed.
  __imported_types = False
  __char_type = None
  __number_type = None

  @classmethod
  def convert_type(cls, cursor, name, default_type, size, precision, scale):
    '''Return a cursor variable with a custom out-converter, or None.

    Fixed-width chars of size 1 are converted with convert_boolean(), and numbers
    with a truthy scale are fetched as strings and converted with PyDecimal. For
    every other column None is returned.
    '''
    if not cls.__imported_types:
      import cx_Oracle
      cls.__char_type = cx_Oracle.FIXED_CHAR
      cls.__number_type = cx_Oracle.NUMBER
      cls.__imported_types = True

    if size == 1 and default_type == cls.__char_type:
      return cursor.var(str, 1, cursor.arraysize, outconverter=cls.convert_boolean)
    if scale and default_type == cls.__number_type:
      return cursor.var(str, 100, cursor.arraysize, outconverter=PyDecimal)
    return None

  @classmethod
  def convert_boolean(cls, val):
    '''Return None for a falsy 'val'; otherwise return whether 'val' equals 'T'.'''
    return (val == 'T') if val else None
|