Testing: Generate queries and compare results against other databases

This is the initial commit and is a work in progress. See the README for a
list of possible improvements.

As an overview of how the files are related:

model.py: This is the base upon which the other files are built. It
      contains something like a grammar for queries.

  query_generator.py: Generates random permutations of the model.

  model_translator.py: Produces SQL based on the model.

  discrepancy_searcher.py: Uses the above to generate, run, and compare
      query results.

Change-Id: Iaca6277766f5a86568eaa3f05b99c832942ab38b
Reviewed-on: http://gerrit.ent.cloudera.com:8080/1648
Reviewed-by: Casey Ching <casey@cloudera.com>
Tested-by: Casey Ching <casey@cloudera.com>
Author: casey
Date: 2014-02-06 19:38:59 -08:00
Committed by: ishaan
Parent: 8d8ea28820
Commit: 192d52c258
9 changed files with 3178 additions and 1 deletion

@@ -14,4 +14,4 @@ DATASRC="a2226.halxg.cloudera.com:/data/1/workspace/impala-data"
DATADST=$IMPALA_HOME/testdata/impala-data
mkdir -p $DATADST
scp -i $IMPALA_HOME/ssh_keys/id_rsa_impala -o "StrictHostKeyChecking=no" -r $DATASRC/* $DATADST
scp -i $HOME/.ssh/id_rsa_jenkins -o "StrictHostKeyChecking=no" -r systest@$DATASRC/* $DATADST

164
tests/comparison/README Normal file

@@ -0,0 +1,164 @@
Purpose:
This package is intended to augment the standard test suite. The standard tests are
more efficient in terms of features tested per unit of execution time, but as a suite
they still leave gaps in query coverage. This package provides a random query generator
to compare the results of a wide range of queries against a reference database engine.
The queries range from very simple single-table selects to extremely complicated
queries with multiple levels of nesting. This method of testing is slower but has a
larger coverage area.
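As an illustrative sketch of the core flow (not a supported entry point; it assumes
the databases were already populated via data_generator.py under the default
"randomness" database name):
  from tests.comparison.db_connector import DbConnection, DbConnector, IMPALA, MYSQL
  from tests.comparison.model_translator import SqlWriter
  from tests.comparison.query_generator import QueryGenerator
  impala = DbConnector(IMPALA).create_connection('randomness')
  mysql = DbConnector(MYSQL).create_connection('randomness')
  tables = DbConnection.describe_common_tables([impala, mysql])
  query = QueryGenerator().create_query(tables)   # a random query over common tables
  for connection, dialect in ((mysql, MYSQL), (impala, IMPALA)):
    sql = SqlWriter.create(dialect=dialect).write_query(query)
    print(connection.execute_and_fetchall(sql))   # result sets should match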
Requirements:
1) It's assumed that Impala is running locally.
2) Impyla -- an implementation of DB API 2 for Impala.
sudo pip install git+http://github.com/laserson/impyla.git#impyla
3) At least one python driver for a reference database.
sudo apt-get install python-mysqldb # MySQL
sudo apt-get install python-psycopg2 # Postgresql
Usage:
1) Generate test data
./data_generator.py --use-mysql
This will generate tables and data in MySQL and Impala.
2) Run the comparison
./discrepancy_searcher.py
This will generate queries using the test database and compare the results against
MySQL (the default).
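Postgresql can be used as the reference instead (option names as defined in the
scripts):
  ./data_generator.py --use-postgresql
  ./discrepancy_searcher.py --reference-db-type POSTGRESQL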
Known Issues:
1) Floats will produce false-positives. For example the results of a query that has
SELECT FLOOR(COUNT(...) * AVG(...)) AS col_1
will produce different results on Impala and MySQL if COUNT() == 3 and AVG() == 1/3.
One of the databases will FLOOR(1) while the other will FLOOR(0.999).
Maybe this could be worked around or reduced by replacing all uses of AVG() with
"AVG() + foo", where foo is some number that makes it unlikely that
"COUNT() * (AVG() + foo)" will result in an int.
I'd guess this issue comes up in 1 out of 10-20k queries. (A minimal reproduction
appears after this list.)
2) Impyla may fail with "Invalid query handle". Some queries will fail every time when run
through Impyla but run fine through the impala-shell. I need to research more and file
an issue with Impyla.
3) Impyla will fail with "Invalid session". I'm pretty sure this is also an Impyla issue
but also need to investigate more.
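A minimal reproduction of issue 1 in plain Python, with decimal and binary floating
point arithmetic standing in for the two database engines:
  from decimal import Decimal
  exact = Decimal(1) / Decimal(3)   # engine computing AVG() in decimal
  approx = 1 / 3.0                  # engine computing AVG() in binary floating point
  print(int(exact * 3))             # 0, because 3 * 0.333...3 = 0.999...9
  print(int(approx * 3))            # 1, because the float product rounds to exactly 1.0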
Things to Know:
1) A good number of queries to run seems to be about 5k. Ideally each test run would
discover the complete list of known issues. From experience a 1k query test run may
complete without finding any issues that were discovered in previous runs. 5k seems
to be about the magic number where most issues will be rediscovered. This can take 1-2
hours. However as of this writing it's rare to run 1k queries without finding at
least one discrepancy.
2) It's possible to provide a randomization seed so that the randomness is actually
reproducible. The data generation currently has a default seed, so it will always
produce the same tables. This also means that if a new data type is added, the
generated tables will change. (An example of setting the seed appears after this
list.)
3) There is a query log. It's possible that a sequence of queries is required to expose
a bug. If you come across a failure that can't be reproduced by rerunning the failed
query, try running the queries leading up to that query as well.
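For example, data generation can be pinned to a specific seed ("--randomization-seed"
as defined in data_generator.py):
  ./data_generator.py --use-mysql --randomization-seed 2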
Miscellaneous:
1) Instead of generating new random queries with each run, it may be better to reuse a
list of queries from a previous run that are known to produce results. As of this
writing only about 50% of queries produce results. So it may be better to trade high
randomness for higher quality queries. For example it would be possible to build up a
library of 100k queries that produce results then randomly select 2.5k of those.
Maybe that would provide testing equivalent to 5k totally random queries in less
time.
This would also be useful for eliminating queries that hit the known issues listed above.
Postgresql:
1) Supports basically all Impala language features
2) Does integer division: 1 / 2 == 0
3) Has strange string sorting: '-1' > '1'. This may be important if ORDER BY is ever
used. The databases being compared would need to have the same collation, which is
probably configurable.
4) This was the original reference database but I moved to MySQL while trying to add
support for floats and never moved back.
MySQL:
1) Supports basically all Impala language features, except the WITH clause, which
requires emulation with inline views (see the example after this list).
2) Has poor boolean support. It may be worth switching back to Postgresql for this.
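For example, a query using a WITH clause such as
  WITH t1 AS (SELECT * FROM table_1) SELECT * FROM t1
would be emulated on MySQL with an inline view:
  SELECT * FROM (SELECT * FROM table_1) AS t1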
Improvements:
1) Add support for simplifying buggy queries. When a random query fails the comparison
check, it is almost always far too complex to post directly in a bug report. It
is also time consuming to simplify the queries because of all the trial and
error involved in manually editing them.
2) Add more language features
a) SEMI JOIN
b) ORDER BY
c) LIMIT, OFFSET
d) CASE / WHEN
3) Add common built-in functions. Ex: CAST, IF, NVL, ...
4) Make randomization of the query generation configurable. As of this writing all the
probabilities are hard-coded. At a very minimum it should be easy to disable or force
the use of some language features such as CROSS JOIN, GROUP BY, etc.
5) Investigate further the use of the existing "functional" test datasets. A very quick
trial run wasn't successful, but another attempt with more effort should be made before
introducing a new dataset.
I suspect the problem with using the functional dataset was that I only imported a few
tables, maybe alltypes, alltypesagg, and something else. I don't think I imported the
tiny tables since the odds of them producing results from a random query would be
very low.
6) If the functional dataset cannot be used, someone should think more about what the
random data should be like. Only a few minutes of thought were put into selecting
random value ranges (including number of tables and columns), and it's not clear how
important those ranges are.
7) Add support for comparing results with codegen enabled and disabled. Uri recently added
support for query options in Impyla.
8) Consider adding Oracle or SQL Server support; these could be useful in the future for
analytic queries.
9) Try running with tables in various formats. Ex: parquet and/or avro.
10) Support for more data types. Only int types are known to give good results.
Floats may work but non-numeric types are not supported yet.

409
tests/comparison/data_generator.py Normal file

@@ -0,0 +1,409 @@
#!/usr/bin/env python
# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''This module provides random data generation and database population.
When this module is run directly for purposes of database population, the default is
to use a fixed seed for randomization. The result should be that the generated random
data is the same regardless of when or where the execution is done.
'''
from datetime import datetime, timedelta
from logging import basicConfig, getLogger
from random import choice, randint, random, seed, uniform
from tests.comparison.db_connector import (
DbConnector,
IMPALA,
MYSQL,
POSTGRESQL)
from tests.comparison.model import (
Boolean,
Column,
Float,
Int,
Number,
String,
Table,
Timestamp,
TYPES)
LOG = getLogger(__name__)
class RandomValGenerator(object):
'''This class will generate random data of various data types. Numeric, string,
timestamp, and boolean data types are supported.
'''
def __init__(self,
min_number=-1000,
max_number=1000,
min_date=datetime(1990, 1, 1),
max_date=datetime(2030, 1, 1),
null_val_percentage=0.1):
self.min_number = min_number
self.max_number = max_number
self.min_date = min_date
self.max_date = max_date
self.null_val_percentage = null_val_percentage
def generate_val(self, val_type):
'''Generate and return a single random val. Use the val_type parameter to
specify the type of val to generate. See model.DataType for valid val_type
options.
Ex:
generator = RandomValGenerator(min_number=1, max_number=5)
val = generator.generate_val(model.Int)
assert 1 <= val and val <= 5
'''
if issubclass(val_type, String):
val = self.generate_val(Int)
return None if val is None else str(val)
if random() < self.null_val_percentage:
return None
if issubclass(val_type, Int):
return randint(
max(self.min_number, val_type.MIN), min(val_type.MAX, self.max_number))
if issubclass(val_type, Number):
return uniform(self.min_number, self.max_number)
if issubclass(val_type, Timestamp):
delta = self.max_date - self.min_date
delta_in_seconds = delta.days * 24 * 60 * 60 + delta.seconds
offset_in_seconds = randint(0, delta_in_seconds)
val = self.min_date + timedelta(0, offset_in_seconds)
return datetime(val.year, val.month, val.day)
if issubclass(val_type, Boolean):
return randint(0, 1) == 1
raise Exception('Unsupported type %s' % val_type.__name__)
class DatabasePopulator(object):
'''This class will populate a database with randomly generated data. The population
includes table creation and data generation. Table names are hard coded as
table_<table number>.
'''
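# A minimal usage sketch (connector setup as in the __main__ block below; the db
# name is illustrative):
#   populator = DatabasePopulator()
#   populator.drop_and_create_database('randomness', [DbConnector(IMPALA)])
#   populator.populate_db_with_random_data('randomness', [DbConnector(IMPALA)])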
def __init__(self):
self.val_generator = RandomValGenerator()
def populate_db_with_random_data(self,
db_name,
db_connectors,
number_of_tables=10,
allowed_data_types=TYPES,
create_files=False):
'''Create tables with a random number of cols with data types chosen from
allowed_data_types, then fill the tables with data.
The given db_name must have already been created.
'''
connections = [connector.create_connection(db_name=db_name)
for connector in db_connectors]
for table_idx in xrange(number_of_tables):
table = self.create_random_table(
'table_%s' % (table_idx + 1),
allowed_data_types=allowed_data_types)
for connection in connections:
sql = self.make_create_table_sql(table, dialect=connection.db_type)
LOG.info('Creating %s table %s', connection.db_type, table.name)
if create_files:
with open('%s_%s.sql' % (table.name, connection.db_type.lower()), 'w') \
as f:
f.write(sql + '\n')
connection.execute(sql)
LOG.info('Inserting data into %s', table.name)
for _ in xrange(100): # each iteration will insert 100 rows
rows = self.generate_table_data(table)
for connection in connections:
sql = self.make_insert_sql_from_data(
table, rows, dialect=connection.db_type)
if create_files:
with open('%s_%s.sql' %
(table.name, connection.db_type.lower()), 'a') as f:
f.write(sql + '\n')
try:
connection.execute(sql)
except:
LOG.error('Error executing SQL: %s', sql)
raise
self.index_tables_in_database(connections)
for connection in connections:
connection.close()
def migrate_database(self,
db_name,
source_db_connector,
destination_db_connectors,
include_table_names=None):
'''Read table metadata and data from the source database and create a replica in
the destination databases. For example, the Impala functional test database could
be copied into Postgresql.
source_db_connector and items in destination_db_connectors should be
of type db_connector.DbConnector. destination_db_connectors and
include_table_names should be iterables.
'''
source_connection = source_db_connector.create_connection(db_name)
cursors = [connector.create_connection(db_name=db_name).create_cursor()
for connector in destination_db_connectors]
for table_name in source_connection.list_table_names():
if include_table_names and table_name not in include_table_names:
continue
try:
table = source_connection.describe_table(table_name)
except Exception as e:
LOG.warn('Error fetching metadata for %s: %s', table_name, e)
continue
for destination_cursor in cursors:
sql = self.make_create_table_sql(
table, dialect=destination_cursor.connection.db_type)
destination_cursor.execute(sql)
with source_connection.open_cursor() as source_cursor:
try:
source_cursor.execute('SELECT * FROM ' + table_name)
while True:
rows = source_cursor.fetchmany(size=100)
if not rows:
break
for destination_cursor in cursors:
sql = self.make_insert_sql_from_data(
table, rows, dialect=destination_cursor.connection.db_type)
destination_cursor.execute(sql)
except Exception as e:
LOG.error('Error fetching data for %s: %s', table_name, e)
continue
self.index_tables_in_database([cursor.connection for cursor in cursors])
for cursor in cursors:
cursor.close()
cursor.connection.close()
def create_random_table(self, table_name, allowed_data_types):
'''Create and return a Table with a random number of cols chosen from the
given allowed_data_types.
'''
data_type_count = len(allowed_data_types)
col_count = randint(data_type_count / 2, data_type_count * 2)
table = Table(table_name)
for col_idx in xrange(col_count):
col_type = choice(allowed_data_types)
col = Column(
table,
'%s_col_%s' % (col_type.__name__.lower(), col_idx + 1),
col_type)
table.cols.append(col)
return table
def make_create_table_sql(self, table, dialect=IMPALA):
sql = 'CREATE TABLE %s (%s)' % (
table.name,
', '.join('%s %s' %
(col.name, self.get_sql_for_data_type(col.type, dialect)) +
('' if dialect == IMPALA else ' NULL')
for col in table.cols))
if dialect == MYSQL:
sql += ' ENGINE = MYISAM'
return sql
def get_sql_for_data_type(self, data_type, dialect=IMPALA):
# Check to see if there is an alias and if so, use the first one
if hasattr(data_type, dialect):
return getattr(data_type, dialect)[0]
return data_type.__name__.upper()
def make_insert_sql_from_data(self, table, rows, dialect=IMPALA):
# TODO: Consider using parameterized inserts so the database connector handles
# formatting the data. For example the CAST to workaround IMPALA-803 can
# probably be removed. The vals were generated this way so a data file
# could be made and attached to jiras.
if not rows:
raise Exception('At least one row is required')
if not table.cols:
raise Exception('At least one col is required')
sql = 'INSERT INTO %s VALUES ' % table.name
for row_idx, row in enumerate(rows):
if row_idx > 0:
sql += ', '
sql += '('
for col_idx, col in enumerate(table.cols):
if col_idx > 0:
sql += ', '
val = row[col_idx]
if val is None:
sql += 'NULL'
elif issubclass(col.type, Timestamp):
if dialect != IMPALA:
sql += 'TIMESTAMP '
sql += "'%s'" % val
elif issubclass(col.type, String):
val = val.replace("'", "''")
if dialect == POSTGRESQL:
val = val.replace('\\', '\\\\')
sql += "'%s'" % val
elif dialect == IMPALA \
and issubclass(col.type, Float):
# https://issues.cloudera.org/browse/IMPALA-803
sql += 'CAST(%s AS FLOAT)' % val
else:
sql += str(val)
sql += ')'
return sql
def generate_table_data(self, table, number_of_rows=100):
rows = list()
for row_idx in xrange(number_of_rows):
row = list()
for col in table.cols:
row.append(self.val_generator.generate_val(col.type))
rows.append(row)
return rows
def drop_and_create_database(self, db_name, db_connectors):
for connector in db_connectors:
with connector.open_connection() as connection:
connection.drop_db_if_exists(db_name)
connection.execute('CREATE DATABASE ' + db_name)
def index_tables_in_database(self, connections):
for connection in connections:
if connection.supports_index_creation:
for table_name in connection.list_table_names():
LOG.info('Indexing %s on %s' % (table_name, connection.db_type))
connection.index_table(table_name)
if __name__ == '__main__':
from optparse import NO_DEFAULT, OptionGroup, OptionParser
parser = OptionParser(
usage='usage: \n'
' %prog [options] [populate]\n\n'
' Create and populate database(s). The Impala database will always be \n'
' included; the other database types are optional.\n\n'
' %prog [options] migrate\n\n'
' Migrate an Impala database to another database type. The destination \n'
' database will be dropped and recreated.')
parser.add_option('--log-level', default='INFO',
help='The log level to use.', choices=('DEBUG', 'INFO', 'WARN', 'ERROR'))
parser.add_option('--db-name', default='randomness',
help='The name of the database to use. Ex: functional.')
group = OptionGroup(parser, 'MySQL Options')
group.add_option('--use-mysql', action='store_true', default=False,
help='Use MySQL')
group.add_option('--mysql-host', default='localhost',
help='The name of the host running the MySQL database.')
group.add_option('--mysql-port', default=3306, type=int,
help='The port of the host running the MySQL database.')
group.add_option('--mysql-user', default='root',
help='The user name to use when connecting to the MySQL database.')
group.add_option('--mysql-password',
help='The password to use when connecting to the MySQL database.')
parser.add_option_group(group)
group = OptionGroup(parser, 'Postgresql Options')
group.add_option('--use-postgresql', action='store_true', default=False,
help='Use Postgresql')
group.add_option('--postgresql-host', default='localhost',
help='The name of the host running the Postgresql database.')
group.add_option('--postgresql-port', default=5432, type=int,
help='The port of the host running the Postgresql database.')
group.add_option('--postgresql-user', default='postgres',
help='The user name to use when connecting to the Postgresql database.')
group.add_option('--postgresql-password',
help='The password to use when connecting to the Postgresql database.')
parser.add_option_group(group)
group = OptionGroup(parser, 'Database Population Options')
group.add_option('--randomization-seed', default=1, type='int',
help='The randomization will be initialized with this seed. Using the same seed '
'will produce the same results across runs.')
group.add_option('--create-data-files', default=False, action='store_true',
help='Create files that can be used to repopulate the databases elsewhere.')
group.add_option('--table-count', default=10, type='int',
help='The number of tables to generate.')
parser.add_option_group(group)
group = OptionGroup(parser, 'Database Migration Options')
group.add_option('--migrate-table-names',
help='Table names should be separated with commas. The default is to migrate all '
'tables.')
parser.add_option_group(group)
for group in parser.option_groups + [parser]:
for option in group.option_list:
if option.default != NO_DEFAULT:
option.help += ' [default: %default]'
options, args = parser.parse_args()
command = args[0] if args else 'populate'
if len(args) > 1 or command not in ['populate', 'migrate']:
raise Exception('Command must either be "populate" or "migrate" but was "%s"' %
' '.join(args))
if command == 'migrate' and not any((options.use_mysql, options.use_postgresql)):
raise Exception('At least one destination database must be chosen with '
'--use-<database type>')
basicConfig(level=options.log_level)
seed(options.randomization_seed)
impala_connector = DbConnector(IMPALA)
db_connectors = []
if options.use_postgresql:
db_connectors.append(DbConnector(POSTGRESQL,
user_name=options.postgresql_user,
password=options.postgresql_password,
host_name=options.postgresql_host,
port=options.postgresql_port))
if options.use_mysql:
db_connectors.append(DbConnector(MYSQL,
user_name=options.mysql_user,
password=options.mysql_password,
host_name=options.mysql_host,
port=options.mysql_port))
populator = DatabasePopulator()
if command == 'populate':
db_connectors.append(impala_connector)
populator.drop_and_create_database(options.db_name, db_connectors)
populator.populate_db_with_random_data(
options.db_name,
db_connectors,
number_of_tables=options.table_count,
create_files=options.create_data_files)
else:
populator.drop_and_create_database(options.db_name, db_connectors)
if options.migrate_table_names:
table_names = options.migrate_table_names.split(',')
else:
table_names = None
populator.migrate_database(
options.db_name,
impala_connector,
db_connectors,
include_table_names=table_names)

411
tests/comparison/db_connector.py Normal file

@@ -0,0 +1,411 @@
# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''This module is intended to standardize workflows when working with various databases
such as Impala, Postgresql, etc. Even with pep-249 (DB API 2), workflows differ
slightly. For example, Postgresql does not allow changing databases from within a
connection; instead a new connection must be made. Impala, on the other hand, does not
allow specifying a database upon connection; instead a cursor must be created and a
USE command must be issued.
'''
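# A minimal usage sketch (assumes a local Impala daemon and an existing database;
# the db name is illustrative):
#   connector = DbConnector(IMPALA)
#   with connector.open_connection(db_name='functional') as connection:
#     print(connection.list_table_names())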
from contextlib import contextmanager
try:
from impala.dbapi import connect as impala_connect
except:
print('Error importing impyla. Please make sure it is installed. '
'See the README for details.')
raise
from itertools import izip
from logging import getLogger
from tests.comparison.model import Column, Table, TYPES, String
LOG = getLogger(__name__)
IMPALA = 'IMPALA'
POSTGRESQL = 'POSTGRESQL'
MYSQL = 'MYSQL'
DATABASES = [IMPALA, POSTGRESQL, MYSQL]
mysql_connect = None
postgresql_connect = None
class DbConnector(object):
'''Wraps a DB API 2 implementation to provide a standard way of obtaining a
connection and selecting a database.
Any database that supports transactions will have auto-commit enabled.
'''
def __init__(self, db_type, user_name=None, password=None, host_name=None, port=None):
self.db_type = db_type.upper()
if self.db_type not in DATABASES:
raise Exception('Unsupported database: %s' % db_type)
self.user_name = user_name
self.password = password
self.host_name = host_name or 'localhost'
self.port = port
def create_connection(self, db_name=None):
if self.db_type == IMPALA:
connection_class = ImpalaDbConnection
connection = impala_connect(host=self.host_name, port=self.port or 21050)
elif self.db_type == POSTGRESQL:
connection_class = PostgresqlDbConnection
connection_args = {'user': self.user_name or 'postgres'}
if self.password:
connection_args['password'] = self.password
if db_name:
connection_args['database'] = db_name
if self.host_name:
connection_args['host'] = self.host_name
if self.port:
connection_args['port'] = self.port
global postgresql_connect
if not postgresql_connect:
try:
from psycopg2 import connect as postgresql_connect
except:
print('Error importing psycopg2. Please make sure it is installed. '
'See the README for details.')
raise
connection = postgresql_connect(**connection_args)
connection.autocommit = True
elif self.db_type == MYSQL:
connection_class = MySQLDbConnection
connection_args = {'user': self.user_name or 'root'}
if self.password:
connection_args['passwd'] = self.password
if db_name:
connection_args['db'] = db_name
if self.host_name:
connection_args['host'] = self.host_name
if self.port:
connection_args['port'] = self.port
global mysql_connect
if not mysql_connect:
try:
from MySQLdb import connect as mysql_connect
except:
print('Error importing MySQLdb. Please make sure it is installed. '
'See the README for details.')
raise
connection = mysql_connect(**connection_args)
else:
raise Exception('Unexpected database type: %s' % self.db_type)
return connection_class(self, connection, db_name=db_name)
@contextmanager
def open_connection(self, db_name=None):
connection = None
try:
connection = self.create_connection(db_name=db_name)
yield connection
finally:
if connection:
try:
connection.close()
except Exception as e:
LOG.debug('Error closing connection: %s', e, exc_info=True)
class DbConnection(object):
'''Wraps a DB API 2 connection. Instances should only be obtained through the
DbConnector.create_connection(...) method.
'''
@staticmethod
def describe_common_tables(db_connections, filter_col_types=[]):
'''Find and return a list of Table objects that the given connections have in
common.
@param filter_col_types: Ignore any cols if they are of a data type contained
in this collection.
'''
common_table_names = None
for db_connection in db_connections:
table_names = set(db_connection.list_table_names())
if common_table_names is None:
common_table_names = table_names
else:
common_table_names &= table_names
common_table_names = sorted(common_table_names)
tables = list()
for table_name in common_table_names:
common_table = None
mismatch = False
for db_connection in db_connections:
table = db_connection.describe_table(table_name)
table.cols = [col for col in table.cols if col.type not in filter_col_types]
if common_table is None:
common_table = table
continue
if len(common_table.cols) != len(table.cols):
LOG.debug('Ignoring table %s.'
' It has a different number of columns across databases.', table_name)
mismatch = True
break
for left, right in izip(common_table.cols, table.cols):
if not (left.name == right.name and left.type == right.type):
LOG.debug('Ignoring table %s. It has different columns %s vs %s.' %
(table_name, left, right))
mismatch = True
break
if mismatch:
break
if not mismatch:
tables.append(common_table)
return tables
def __init__(self, connector, connection, db_name=None):
self.connector = connector
self.connection = connection
self.db_name = db_name
@property
def db_type(self):
return self.connector.db_type
def create_cursor(self):
return DatabaseCursor(self.connection.cursor(), self)
@contextmanager
def open_cursor(self):
'''Returns a new cursor for use in a "with" statement. When the "with" statement ends,
the cursor will be closed.
'''
cursor = None
try:
cursor = self.create_cursor()
yield cursor
finally:
self.close_cursor_quietly(cursor)
def close_cursor_quietly(self, cursor):
if cursor:
try:
cursor.close()
except Exception as e:
LOG.debug('Error closing cursor: %s', e, exc_info=True)
def list_db_names(self):
'''Return a list of database names always in lowercase.'''
rows = self.execute_and_fetchall(self.make_list_db_names_sql())
return [row[0].lower() for row in rows]
def make_list_db_names_sql(self):
return 'SHOW DATABASES'
def list_table_names(self):
'''Return a list of table names always in lowercase.'''
rows = self.execute_and_fetchall(self.make_list_table_names_sql())
return [row[0].lower() for row in rows]
def make_list_table_names_sql(self):
return 'SHOW TABLES'
def describe_table(self, table_name):
'''Return a Table with table and col names always in lowercase.'''
rows = self.execute_and_fetchall(self.make_describe_table_sql(table_name))
table = Table(table_name.lower())
for row in rows:
col_name, data_type = row[:2]
table.cols.append(Column(table, col_name.lower(), self.parse_data_type(data_type)))
return table
def make_describe_table_sql(self, table_name):
return 'DESCRIBE ' + table_name
def parse_data_type(self, sql):
sql = sql.upper()
# Types may have declared a database-specific alias
for type_ in TYPES:
if sql in getattr(type_, self.db_type, []):
return type_
for type_ in TYPES:
if type_.__name__.upper() == sql:
return type_
if 'CHAR' in sql:
return String
raise Exception('Unknown data type: ' + sql)
def create_database(self, db_name):
db_name = db_name.lower()
with self.open_cursor() as cursor:
cursor.execute('CREATE DATABASE ' + db_name)
def drop_db_if_exists(self, db_name):
'''This should not be called from a connection to the database being dropped.'''
db_name = db_name.lower()
if db_name not in self.list_db_names():
return
if self.db_name and self.db_name.lower() == db_name:
raise Exception('Cannot drop database while still connected to it')
self.drop_database(db_name)
def drop_database(self, db_name):
db_name = db_name.lower()
self.execute('DROP DATABASE ' + db_name)
@property
def supports_index_creation(self):
return True
def index_table(self, table_name):
table = self.describe_table(table_name)
with self.open_cursor() as cursor:
for col in table.cols:
index_name = '%s_%s' % (table_name, col.name)
if self.db_name:
index_name = '%s_%s' % (self.db_name, index_name)
cursor.execute('CREATE INDEX %s ON %s(%s)' % (index_name, table_name, col.name))
@property
def supports_kill_connection(self):
return False
def kill_connection(self):
'''Kill the current connection and any currently running queries associated with the
connection.
'''
raise Exception('Killing connection is not supported')
def materialize_query(self, query_as_text, table_name):
self.execute('CREATE TABLE %s AS %s' % (table_name.lower(), query_as_text))
def drop_table(self, table_name):
self.execute('DROP TABLE ' + table_name.lower())
def execute(self, sql):
with self.open_cursor() as cursor:
cursor.execute(sql)
def execute_and_fetchall(self, sql):
with self.open_cursor() as cursor:
cursor.execute(sql)
return cursor.fetchall()
def close(self):
'''Close the underlying connection.'''
self.connection.close()
def reconnect(self):
self.close()
other = self.connector.create_connection(db_name=self.db_name)
self.connection = other.connection
class DatabaseCursor(object):
'''Wraps a DB API 2 cursor to provide access to the related connection. This class
implements the DB API 2 interface by delegation.
'''
def __init__(self, cursor, connection):
self.cursor = cursor
self.connection = connection
def __getattr__(self, attr):
return getattr(self.cursor, attr)
class ImpalaDbConnection(DbConnection):
def create_cursor(self):
cursor = DbConnection.create_cursor(self)
if self.db_name:
cursor.execute('USE %s' % self.db_name)
return cursor
def drop_database(self, db_name):
'''This should not be called from a connection to the database being dropped.'''
db_name = db_name.lower()
with self.connector.open_connection(db_name) as list_tables_connection:
with list_tables_connection.open_cursor() as drop_table_cursor:
for table_name in list_tables_connection.list_table_names():
drop_table_cursor.execute('DROP TABLE ' + table_name)
self.execute('DROP DATABASE ' + db_name)
@property
def supports_index_creation(self):
return False
class PostgresqlDbConnection(DbConnection):
def make_list_db_names_sql(self):
return 'SELECT datname FROM pg_database'
def make_list_table_names_sql(self):
return '''
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public' '''
def make_describe_table_sql(self, table_name):
return '''
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_name = '%s'
ORDER BY ordinal_position''' % table_name
class MySQLDbConnection(DbConnection):
def __init__(self, connector, connection, db_name=None):
DbConnection.__init__(self, connector, connection, db_name=db_name)
self.session_id = self.execute_and_fetchall('SELECT connection_id()')[0][0]
def describe_table(self, table_name):
'''Return a Table with table and col names always in lowercase.'''
rows = self.execute_and_fetchall(self.make_describe_table_sql(table_name))
table = Table(table_name.lower())
for row in rows:
col_name, data_type = row[:2]
if data_type == 'tinyint(1)':
# Just assume this is a boolean...
data_type = 'boolean'
if '(' in data_type:
# Strip the size of the data type
data_type = data_type[:data_type.index('(')]
table.cols.append(Column(table, col_name.lower(), self.parse_data_type(data_type)))
return table
@property
def supports_kill_connection(self):
return True
def kill_connection(self):
with self.connector.open_connection(db_name=self.db_name) as connection:
connection.execute('KILL %s' % (self.session_id))
def index_table(self, table_name):
table = self.describe_table(table_name)
with self.open_cursor() as cursor:
for col in table.cols:
try:
cursor.execute('ALTER TABLE %s ADD INDEX (%s)' % (table_name, col.name))
except Exception as e:
if 'Incorrect index name' not in str(e):
raise
# Some sort of MySQL bug...
LOG.warn('Could not create index on %s.%s: %s' % (table_name, col.name, e))

529
tests/comparison/discrepancy_searcher.py Normal file

@@ -0,0 +1,529 @@
#!/usr/bin/env python
# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''This module will run random queries against existing databases and compare the
results.
'''
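# A minimal usage sketch (assumes both databases hold the data generated by
# data_generator.py under its default 'randomness' db name):
#   impala = DbConnector(IMPALA).create_connection('randomness')
#   mysql = DbConnector(MYSQL).create_connection('randomness')
#   searcher = QueryResultDiffSearcher(impala, mysql)
#   results = searcher.search(100, stop_on_result_mismatch=False, stop_on_crash=False)
#   print(results.get_summary_text())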
from contextlib import closing
from decimal import Decimal
from itertools import izip, izip_longest
from logging import basicConfig, getLogger
from math import isinf, isnan
from os import getenv, remove
from os.path import exists, join
from shelve import open as open_shelve
from subprocess import call
from threading import current_thread, Thread
from tempfile import gettempdir
from time import time
from tests.comparison.db_connector import (
DbConnection,
DbConnector,
IMPALA,
MYSQL,
POSTGRESQL)
from tests.comparison.model import BigInt, TYPES
from tests.comparison.query_generator import QueryGenerator
from tests.comparison.model_translator import SqlWriter
LOG = getLogger(__name__)
class QueryResultComparator(object):
# If the number of rows * cols is greater than this val, then the comparison will
# be aborted. Raising this value also raises the risk of Python being OOM killed. At
# 10M, Python would occasionally get OOM killed even on a physical machine with 32 GB
# of RAM.
TOO_MUCH_DATA = 1000 * 1000
# Used when comparing float vals
EPSILON = 0.1
# The decimal vals will be rounded before comparison
DECIMAL_PLACES = 2
def __init__(self, impala_connection, reference_connection):
self.reference_db_type = reference_connection.db_type
self.impala_cursor = impala_connection.create_cursor()
self.reference_cursor = reference_connection.create_cursor()
self.impala_sql_writer = SqlWriter.create(dialect=impala_connection.db_type)
self.reference_sql_writer = SqlWriter.create(dialect=reference_connection.db_type)
# If a query runs longer than this, the connection will be killed and the comparison
# result will be a timeout.
self.query_timeout_seconds = 3 * 60
def compare_query_results(self, query):
'''Execute the query, compare the data, and return a summary of the result.'''
comparison_result = ComparisonResult(query, self.reference_db_type)
reference_data_set = None
impala_data_set = None
# Impala doesn't support getting the row count without getting the rows too. So run
# the query on the other database first.
try:
for sql_writer, cursor in ((self.reference_sql_writer, self.reference_cursor),
(self.impala_sql_writer, self.impala_cursor)):
self.execute_query(cursor, sql_writer.write_query(query))
if (cursor.rowcount * len(query.select_clause.select_items)) > self.TOO_MUCH_DATA:
comparison_result.exception = Exception('Too much data to compare')
return comparison_result
if reference_data_set is None:
# MySQL returns a tuple of rows but a list is needed for sorting
reference_data_set = list(cursor.fetchall())
comparison_result.reference_row_count = len(reference_data_set)
else:
impala_data_set = cursor.fetchall()
comparison_result.impala_row_count = len(impala_data_set)
except Exception as e:
comparison_result.exception = e
LOG.debug('Error running query: %s', e, exc_info=True)
return comparison_result
comparison_result.query_resulted_in_data = (comparison_result.impala_row_count > 0
or comparison_result.reference_row_count > 0)
if comparison_result.impala_row_count != comparison_result.reference_row_count:
return comparison_result
for data_set in (reference_data_set, impala_data_set):
for row_idx, row in enumerate(data_set):
data_set[row_idx] = [self.standardize_data(data) for data in row]
data_set.sort(cmp=self.row_sort_cmp)
for impala_row, reference_row in \
izip_longest(impala_data_set, reference_data_set):
for col_idx, (impala_val, reference_val) \
in enumerate(izip_longest(impala_row, reference_row)):
if not self.vals_are_equal(impala_val, reference_val):
if isinstance(impala_val, int) \
and isinstance(reference_val, (int, float, Decimal)) \
and abs(reference_val) > BigInt.MAX:
# Impala will return incorrect results if the val is greater than max bigint
comparison_result.exception = KnownError(
'https://issues.cloudera.org/browse/IMPALA-865')
elif isinstance(impala_val, float) \
and (isinf(impala_val) or isnan(impala_val)):
# In some cases, Impala gives NaNs and Infs instead of NULLs
comparison_result.exception = KnownError(
'https://issues.cloudera.org/browse/IMPALA-724')
comparison_result.impala_row = impala_row
comparison_result.reference_row = reference_row
comparison_result.mismatch_at_row_number = row_idx + 1
comparison_result.mismatch_at_col_number = col_idx + 1
return comparison_result
if len(impala_data_set) == 1:
for val in impala_data_set[0]:
if val:
break
else:
comparison_result.query_resulted_in_data = False
return comparison_result
def execute_query(self, cursor, sql):
'''Execute the query and throw a timeout if needed.'''
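# The query runs in a daemon thread and the join below is given a timeout. If the
# thread is still alive after the timeout, the connection is killed (when the
# database supports it), a replacement cursor is created, and QueryTimeout is raised.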
def _execute_query():
try:
cursor.execute(sql)
except Exception as e:
current_thread().exception = e
query_thread = Thread(target=_execute_query, name='Query execution thread')
query_thread.daemon = True
query_thread.start()
query_thread.join(self.query_timeout_seconds)
if query_thread.is_alive():
if cursor.connection.supports_kill_connection:
LOG.debug('Attempting to kill connection')
cursor.connection.kill_connection()
LOG.debug('Kill connection')
cursor.close()
cursor.connection.close()
cursor = cursor\
.connection\
.connector\
.create_connection(db_name=cursor.connection.db_name)\
.create_cursor()
if cursor.connection.db_type == IMPALA:
self.impala_cursor = cursor
else:
self.reference_cursor = cursor
raise QueryTimeout('Query timed out after %s seconds' % self.query_timeout_seconds)
if hasattr(query_thread, 'exception'):
raise query_thread.exception
def standardize_data(self, data):
'''Return a val that is suitable for comparison.'''
# For float data we need to round otherwise differences in precision will cause errors
if isinstance(data, (float, Decimal)):
return round(data, self.DECIMAL_PLACES)
return data
def row_sort_cmp(self, left_row, right_row):
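# Sort comparator that orders None (NULL) before any non-None val so that both
# result sets sort identically regardless of each database's NULL ordering.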
for left, right in izip(left_row, right_row):
if left is None and right is not None:
return -1
if left is not None and right is None:
return 1
result = cmp(left, right)
if result:
return result
return 0
def vals_are_equal(self, left, right):
if left == right:
return True
if isinstance(left, (int, float, Decimal)) and \
isinstance(right, (int, float, Decimal)):
return self.floats_are_equal(left, right)
return False
def floats_are_equal(self, left, right):
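# After rounding, vals are compared by their difference relative to their combined
# magnitude; an absolute comparison is used when either val is zero.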
left = round(left, self.DECIMAL_PLACES)
right = round(right, self.DECIMAL_PLACES)
diff = abs(left - right)
if left * right == 0:
return diff < self.EPSILON
return diff / (abs(left) + abs(right)) < self.EPSILON
class ComparisonResult(object):
def __init__(self, query, reference_db_type):
self.query = query
self.reference_db_type = reference_db_type
self.query_resulted_in_data = False
self.impala_row_count = None
self.reference_row_count = None
self.mismatch_at_row_number = None
self.mismatch_at_col_number = None
self.impala_row = None
self.reference_row = None
self.exception = None
self._error_message = None
@property
def error(self):
if not self._error_message:
if self.exception:
self._error_message = str(self.exception)
elif self.impala_row_count and \
self.impala_row_count != self.reference_row_count:
self._error_message = 'Row counts do not match: %s Impala rows vs %s %s rows' \
% (self.impala_row_count,
self.reference_db_type,
self.reference_row_count)
elif self.mismatch_at_row_number is not None:
# Write a row like "[a, b, <<c>>, d]" where c is a bad value
impala_row = '[' + ', '.join(
'<<' + str(val) + '>>' if idx == self.mismatch_at_col_number - 1 else str(val)
for idx, val in enumerate(self.impala_row)
) + ']'
reference_row = '[' + ', '.join(
'<<' + str(val) + '>>' if idx == self.mismatch_at_col_number - 1 else str(val)
for idx, val in enumerate(self.reference_row)
) + ']'
self._error_message = \
'Column %s in row %s does not match: %s Impala row vs %s %s row' \
% (self.mismatch_at_col_number,
self.mismatch_at_row_number,
impala_row,
reference_row,
self.reference_db_type)
return self._error_message
@property
def is_known_error(self):
return isinstance(self.exception, KnownError)
@property
def query_timed_out(self):
return isinstance(self.exception, QueryTimeout)
class QueryTimeout(Exception):
pass
class KnownError(Exception):
def __init__(self, jira_url):
Exception.__init__(self, 'Known issue: ' + jira_url)
self.jira_url = jira_url
class QueryResultDiffSearcher(object):
# Sometimes things get into a bad state and the same error loops forever
ABORT_ON_REPEAT_ERROR_COUNT = 2
def __init__(self, impala_connection, reference_connection, filter_col_types=[]):
self.impala_connection = impala_connection
self.reference_connection = reference_connection
self.common_tables = DbConnection.describe_common_tables(
[impala_connection, reference_connection],
filter_col_types=filter_col_types)
# A file-backed dict of queries that produced a discrepancy, keyed by query number
# (in string form, as required by the dict).
self.query_shelve_path = gettempdir() + '/query.shelve'
# A list of all queries attempted
self.query_log_path = gettempdir() + '/impala_query_log.sql'
def search(self, number_of_test_queries, stop_on_result_mismatch, stop_on_crash):
if exists(self.query_shelve_path):
# Ensure a clean shelve will be created
remove(self.query_shelve_path)
start_time = time()
impala_sql_writer = SqlWriter.create(dialect=IMPALA)
reference_sql_writer = SqlWriter.create(
dialect=self.reference_connection.db_type)
query_result_comparator = QueryResultComparator(
self.impala_connection, self.reference_connection)
query_generator = QueryGenerator()
query_count = 0
queries_resulted_in_data_count = 0
mismatch_count = 0
query_timeout_count = 0
known_error_count = 0
impala_crash_count = 0
last_error = None
repeat_error_count = 0
with open(self.query_log_path, 'w') as impala_query_log:
impala_query_log.write(
'--\n'
'-- Starting new run\n'
'--\n')
while number_of_test_queries > query_count:
query = query_generator.create_query(self.common_tables)
impala_sql = impala_sql_writer.write_query(query)
if 'FULL OUTER JOIN' in impala_sql and self.reference_connection.db_type == MYSQL:
# Not supported by MySQL
continue
query_count += 1
LOG.info('Running query #%s', query_count)
impala_query_log.write(impala_sql + ';\n')
result = query_result_comparator.compare_query_results(query)
if result.query_resulted_in_data:
queries_resulted_in_data_count += 1
if result.error:
# TODO: These first two come from psycopg2, the postgres driver. Maybe we should
# try a different driver? Or maybe the usage of the driver isn't correct.
# Anyhow ignore these failures.
if 'division by zero' in result.error \
or 'out of range' in result.error \
or 'Too much data' in result.error:
LOG.debug('Ignoring error: %s', result.error)
query_count -= 1
continue
if result.is_known_error:
known_error_count += 1
elif result.query_timed_out:
query_timeout_count += 1
else:
mismatch_count += 1
with closing(open_shelve(self.query_shelve_path)) as query_shelve:
query_shelve[str(query_count)] = query
print('---Impala Query---\n')
print(impala_sql_writer.write_query(query, pretty=True) + '\n')
print('---Reference Query---\n')
print(reference_sql_writer.write_query(query, pretty=True) + '\n')
print('---Error---\n')
print(result.error + '\n')
print('------\n')
if 'Could not connect' in result.error \
or "Couldn't open transport for" in result.error:
# if stop_on_crash:
# break
# Assume Impala crashed and try restarting
impala_crash_count += 1
LOG.info('Restarting Impala')
call([join(getenv('IMPALA_HOME'), 'bin/start-impala-cluster.py'),
'--log_dir=%s' % getenv('LOG_DIR', "/tmp/")])
self.impala_connection.reconnect()
query_result_comparator.impala_cursor = self.impala_connection.create_cursor()
result = query_result_comparator.compare_query_results(query)
if result.error:
LOG.info('Restarting Impala')
call([join(getenv('IMPALA_HOME'), 'bin/start-impala-cluster.py'),
'--log_dir=%s' % getenv('LOG_DIR', "/tmp/")])
self.impala_connection.reconnect()
query_result_comparator.impala_cursor = self.impala_connection.create_cursor()
else:
break
if stop_on_result_mismatch and \
not (result.is_known_error or result.query_timed_out):
break
if last_error == result.error \
and not (result.is_known_error or result.query_timed_out):
repeat_error_count += 1
if repeat_error_count == self.ABORT_ON_REPEAT_ERROR_COUNT:
break
else:
last_error = result.error
repeat_error_count = 0
else:
if result.query_resulted_in_data:
LOG.info('Results matched (%s rows)', result.impala_row_count)
else:
LOG.info('Query did not produce meaningful data')
last_error = None
repeat_error_count = 0
return SearchResults(
query_count,
queries_resulted_in_data_count,
mismatch_count,
query_timeout_count,
known_error_count,
impala_crash_count,
time() - start_time)
class SearchResults(object):
'''This class holds information about the outcome of a search run.'''
def __init__(self,
query_count,
queries_resulted_in_data_count,
mismatch_count,
query_timeout_count,
known_error_count,
impala_crash_count,
run_time_in_seconds):
# Approx number of queries run, some queries may have been ignored
self.query_count = query_count
self.queries_resulted_in_data_count = queries_resulted_in_data_count
# Number of queries that had an error or result mismatch
self.mismatch_count = mismatch_count
self.query_timeout_count = query_timeout_count
self.known_error_count = known_error_count
self.impala_crash_count = impala_crash_count
self.run_time_in_seconds = run_time_in_seconds
def get_summary_text(self):
mins, secs = divmod(self.run_time_in_seconds, 60)
hours, mins = divmod(mins, 60)
hours = int(hours)
mins = int(mins)
if hours:
run_time = '%s hour and %s minutes' % (hours, mins)
else:
secs = int(secs)
run_time = '%s seconds' % secs
if mins:
run_time = '%s mins and ' % mins + run_time
summary_params = self.__dict__
summary_params['run_time'] = run_time
return (
'%(mismatch_count)s mismatches found after running %(query_count)s queries in '
'%(run_time)s.\n'
'%(queries_resulted_in_data_count)s of %(query_count)s queries produced results.'
'\n'
'%(impala_crash_count)s Impala crashes occurred.\n'
'%(known_error_count)s queries were excluded from the mismatch count because '
'they are known errors.\n'
'%(query_timeout_count)s queries timed out and were excluded from all counts.') \
% summary_params
if __name__ == '__main__':
import sys
from optparse import NO_DEFAULT, OptionGroup, OptionParser
parser = OptionParser()
parser.add_option('--log-level', default='INFO',
help='The log level to use.', choices=('DEBUG', 'INFO', 'WARN', 'ERROR'))
parser.add_option('--db-name', default='randomness',
help='The name of the database to use. Ex: functional.')
parser.add_option('--reference-db-type', default=MYSQL, choices=(MYSQL, POSTGRESQL),
help='The type of the reference database to use. Ex: MYSQL.')
parser.add_option('--stop-on-mismatch', default=False, action='store_true',
help='Exit immediately upon finding a discrepancy in a query result.')
parser.add_option('--stop-on-crash', default=False, action='store_true',
help='Exit immediately if Impala crashes.')
parser.add_option('--query-count', default=1000, type=int,
help='Exit after running the given number of queries.')
parser.add_option('--exclude-types', default='Double,Float,TinyInt',
help='A comma separated list of data types to exclude while generating queries.')
group = OptionGroup(parser, 'MySQL Options')
group.add_option('--mysql-host', default='localhost',
help='The name of the host running the MySQL database.')
group.add_option('--mysql-port', default=3306, type=int,
help='The port of the host running the MySQL database.')
group.add_option('--mysql-user', default='root',
help='The user name to use when connecting to the MySQL database.')
group.add_option('--mysql-password',
help='The password to use when connecting to the MySQL database.')
parser.add_option_group(group)
group = OptionGroup(parser, 'Postgresql Options')
group.add_option('--postgresql-host', default='localhost',
help='The name of the host running the Postgresql database.')
group.add_option('--postgresql-port', default=5432, type=int,
help='The port of the host running the Postgresql database.')
group.add_option('--postgresql-user', default='postgres',
help='The user name to use when connecting to the Postgresql database.')
group.add_option('--postgresql-password',
help='The password to use when connecting to the Postgresql database.')
parser.add_option_group(group)
for group in parser.option_groups + [parser]:
for option in group.option_list:
if option.default != NO_DEFAULT:
option.help += " [default: %default]"
options, args = parser.parse_args()
basicConfig(level=options.log_level)
impala_connection = DbConnector(IMPALA).create_connection(options.db_name)
db_connector_param_key = options.reference_db_type.lower()
reference_connection = DbConnector(options.reference_db_type,
user_name=getattr(options, db_connector_param_key + '_user'),
password=getattr(options, db_connector_param_key + '_password'),
host_name=getattr(options, db_connector_param_key + '_host'),
port=getattr(options, db_connector_param_key + '_port')) \
.create_connection(options.db_name)
if options.exclude_types:
exclude_types = set(type_name.lower() for type_name
in options.exclude_types.split(','))
filter_col_types = [type_ for type_ in TYPES
if type_.__name__.lower() in exclude_types]
else:
filter_col_types = []
diff_searcher = QueryResultDiffSearcher(
impala_connection, reference_connection, filter_col_types=filter_col_types)
search_results = diff_searcher.search(
options.query_count, options.stop_on_mismatch, options.stop_on_crash)
print(search_results.get_summary_text())
sys.exit(search_results.mismatch_count)

741
tests/comparison/model.py Normal file

@@ -0,0 +1,741 @@
# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
class Query(object):
'''A representation of the structure of a SQL query. Only the select_clause and
from_clause are required for a valid query.
'''
def __init__(self, select_clause, from_clause):
self.with_clause = None
self.select_clause = select_clause
self.from_clause = from_clause
self.where_clause = None
self.group_by_clause = None
self.having_clause = None
self.union_clause = None
@property
def table_exprs(self):
'''Provides a list of all table_exprs that are declared by this query. This
includes table_exprs in the WITH and FROM sections.
'''
table_exprs = self.from_clause.table_exprs
if self.with_clause:
table_exprs += self.with_clause.table_exprs
return table_exprs
class SelectClause(object):
'''This encapsulates the SELECT part of a query. It is convenient to separate
non-agg items from agg items so that it is simple to know if the query
is an agg query or not.
'''
def __init__(self, non_agg_items=None, agg_items=None):
self.non_agg_items = non_agg_items or list()
self.agg_items = agg_items or list()
self.distinct = False
@property
def select_items(self):
'''Provides a consolidated view of all select items.'''
return self.non_agg_items + self.agg_items
class SelectItem(object):
'''A representation of any possible expr that would be valid in
SELECT <SelectItem>[, <SelectItem>...] FROM ...
Each SelectItem contains a ValExpr which will either be an instance of a
DataType (representing a constant), a Column, or a Func.
Ex: "SELECT int_col + smallint_col FROM alltypes" would have a val_expr of
Plus(Column(<alltypes.int_col>), Column(<alltypes.smallint_col>)).
'''
def __init__(self, val_expr, alias=None):
self.val_expr = val_expr
self.alias = alias
@property
def type(self):
'''Returns the DataType of this item.'''
return self.val_expr.type
@property
def is_agg(self):
'''Evaluates to True if this item contains an aggregate expression.'''
return self.val_expr.is_agg
class ValExpr(object):
'''This is an abstract class that represents a generic expr that results in a
scalar. The abc module was not used because it caused problems for the pickle
module.
'''
@property
def type(self):
'''This is declared for documentation purposes; subclasses should override this to
return the DataType that this expr represents.
'''
pass
@property
def base_type(self):
'''Return the most fundamental data type that the expr evaluates to. Only
numeric types will result in a different val than would be returned by self.type.
Ex:
if self.type == BigInt:
assert self.base_type == Int
if self.type == Double:
assert self.base_type == Float
if self.type == String:
assert self.base_type == self.type
'''
if self.returns_int:
return Int
if self.returns_float:
return Float
return self.type
@property
def is_func(self):
return isinstance(self, Func)
@property
def is_agg(self):
'''Evaluates to True if this expression contains an aggregate function.'''
if isinstance(self, AggFunc):
return True
if self.is_func:
for arg in self.args:
if arg.is_agg:
return True
@property
def is_col(self):
return isinstance(self, Column)
@property
def is_constant(self):
return isinstance(self, DataType)
@property
def returns_boolean(self):
return issubclass(self.type, Boolean)
@property
def returns_number(self):
return issubclass(self.type, Number)
@property
def returns_int(self):
return issubclass(self.type, Int)
@property
def returns_float(self):
return issubclass(self.type, Float)
@property
def returns_string(self):
return issubclass(self.type, String)
@property
def returns_timestamp(self):
return issubclass(self.type, Timestamp)
class Column(ValExpr):
'''A representation of a col. All TableExprs will have Columns. So a Column
may belong to an InlineView as well as a standard Table.
This class is used in two ways:
1) As a piece of metadata in a table definition. In this usage the col isn't
intended to represent a val.
2) As an expr in a query, for example an item being selected or as part of
a join condition. In this usage the col is more like a val, which is why
it implements/extends ValExpr.
'''
def __init__(self, owner, name, type_):
self.owner = owner
self.name = name
self._type = type_
@property
def type(self):
return self._type
def __hash__(self):
return hash(self.name)
def __eq__(self, other):
if not isinstance(other, Column):
return False
if self is other:
return True
return self.name == other.name and self.owner.identifier == other.owner.identifier
def __repr__(self):
return '%s<name: %s, type: %s>' % (
type(self).__name__, self.name, self._type.__name__)
class FromClause(object):
'''A representation of a FROM clause. The member variable join_clauses may optionally
contain JoinClause items.
'''
def __init__(self, table_expr, join_clauses=None):
self.table_expr = table_expr
self.join_clauses = join_clauses or list()
@property
def table_exprs(self):
'''Provides a list of all table_exprs that are declared within this FROM
block.
'''
table_exprs = [join_clause.table_expr for join_clause in self.join_clauses]
table_exprs.append(self.table_expr)
return table_exprs
class TableExpr(object):
'''This is an abstract class that represents something that a query may use to select
from or join on. The abc module was not used because it caused problems for the
pickle module.
'''
@property
def identifier(self):
'''Returns either a table name or alias if one has been declared.'''
pass
@property
def cols(self):
'''Returns the list of Columns available from this table expr.'''
pass
@property
def cols_by_base_type(self):
'''Group cols by their basic data type and return a dict of the results.
As an example, a "BigInt" would be considered as an "Int".
'''
return DataType.group_by_base_type(self.cols)
@property
def is_table(self):
return isinstance(self, Table)
@property
def is_inline_view(self):
return isinstance(self, InlineView)
@property
def is_with_clause_inline_view(self):
return isinstance(self, WithClauseInlineView)
def __eq__(self, other):
if not isinstance(other, type(self)):
return False
return self.identifier == other.identifier
class Table(TableExpr):
'''Represents a standard database table.'''
def __init__(self, name):
self.name = name
self._cols = []
self.alias = None
@property
def identifier(self):
return self.alias or self.name
@property
def cols(self):
return self._cols
@cols.setter
def cols(self, cols):
self._cols = cols
class InlineView(TableExpr):
'''Represents an inline view.
Ex: In the query "SELECT * FROM (SELECT * FROM foo) AS bar",
"(SELECT * FROM foo) AS bar" would be an inline view.
'''
def __init__(self, query):
self.query = query
self.alias = None
@property
def identifier(self):
return self.alias
@property
def cols(self):
return [Column(self, item.alias, item.type) for item in
self.query.select_clause.non_agg_items + self.query.select_clause.agg_items]
class WithClause(object):
'''Represents a WITH clause.
Ex: In the query "WITH bar AS (SELECT * FROM foo) SELECT * FROM bar",
"WITH bar AS (SELECT * FROM foo)" would be the with clause.
'''
def __init__(self, with_clause_inline_views):
self.with_clause_inline_views = with_clause_inline_views
@property
def table_exprs(self):
return self.with_clause_inline_views
class WithClauseInlineView(InlineView):
'''Represents the entries in a WITH clause. These are very similar to InlineViews but
may have an additional alias.
Ex: WITH bar AS (SELECT * FROM foo)
SELECT *
FROM bar as r
JOIN (SELECT * FROM baz) AS z ON ...
The WithClauseInlineView has aliases "bar" and "r" while the InlineView has
only the alias "z".
'''
def __init__(self, query, with_clause_alias):
self.query = query
self.with_clause_alias = with_clause_alias
self.alias = None
@property
def identifier(self):
return self.alias or self.with_clause_alias
class JoinClause(object):
'''A representation of a JOIN clause.
Ex: SELECT * FROM foo <join_type> JOIN <table_expr> [ON <boolean_expr>]
The member variable boolean_expr will be an instance of a boolean func
defined below.
'''
JOIN_TYPES = ['INNER', 'LEFT', 'RIGHT', 'FULL OUTER', 'CROSS']
def __init__(self, join_type, table_expr, boolean_expr=None):
self.join_type = join_type
self.table_expr = table_expr
self.boolean_expr = boolean_expr
class WhereClause(object):
'''The member variable boolean_expr will be an instance of a boolean func
defined below.
'''
def __init__(self, boolean_expr):
self.boolean_expr = boolean_expr
class GroupByClause(object):
def __init__(self, select_items):
self.group_by_items = select_items
class HavingClause(object):
'''The member variable boolean_expr will be an instance of a boolean func
defined below.
'''
def __init__(self, boolean_expr):
self.boolean_expr = boolean_expr
class UnionClause(object):
'''A representation of a UNION clause.
If the member variable "all" is True, the instance represents a "UNION ALL".
'''
def __init__(self, query):
self.query = query
self.all = False
@property
def queries(self):
queries = list()
query = self.query
while True:
queries.append(query)
if not query.union_clause:
break
query = query.union_clause.query
return queries
class DataTypeMetaclass(type):
'''Provides sorting of classes used to determine upcasting.'''
def __cmp__(cls, other):
return cmp(
getattr(cls, 'CMP_VALUE', cls.__name__),
getattr(other, 'CMP_VALUE', other.__name__))
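# A sketch of how this ordering drives upcasting: with the CMP_VALUE
# attributes defined on the numeric classes below, max(TinyInt, Double)
# evaluates to Double, which is what UpcastingFunc.type relies on to pick the
# widest argument type.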
class DataType(ValExpr):
'''Base class for data types.
Data types are represented as classes so inheritence can be used.
'''
__metaclass__ = DataTypeMetaclass
def __init__(self, val):
self.val = val
@property
def type(self):
return type(self)
@staticmethod
def group_by_base_type(vals):
'''Group cols by their basic data type and return a dict of the results.
As an example, a "BigInt" would be considered as an "Int".
'''
vals_by_type = defaultdict(list)
for val in vals:
type_ = val.type
if issubclass(type_, Int):
type_ = Int
elif issubclass(type_, Float):
type_ = Float
vals_by_type[type_].append(val)
return vals_by_type
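# Usage sketch (assuming hypothetical Column instances "a", "b", "c" typed
# BigInt, TinyInt, and String): DataType.group_by_base_type([a, b, c]) returns
# {Int: [a, b], String: [c]}, since all int subclasses collapse to Int.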
class Boolean(DataType):
pass
class Number(DataType):
pass
class Int(Number):
# Used to compare with other numbers for determining upcasting
CMP_VALUE = 2
# Used during data generation to keep vals in range
MIN = -2 ** 31
MAX = -MIN - 1
# Aliases used when reading and writing table definitions
POSTGRESQL = ['INTEGER']
class TinyInt(Int):
CMP_VALUE = 0
MIN = -2 ** 7
MAX = -MIN - 1
POSTGRESQL = ['SMALLINT']
class SmallInt(Int):
CMP_VALUE = 1
MIN = -2 ** 15
MAX = -MIN - 1
class BigInt(Int):
CMP_VALUE = 3
MIN = -2 ** 63
MAX = -MIN - 1
class Float(Number):
CMP_VALUE = 4
POSTGRESQL = ['REAL']
class Double(Float):
CMP_VALUE = 5
MYSQL = ['DOUBLE', 'DECIMAL'] # Use double by default but add decimal synonym
POSTGRESQL = ['DOUBLE PRECISION']
class String(DataType):
MIN = 0
# The Impala limit is 32,767 but MySQL has a row size limit of 65,535. To allow 3+
# String cols per table, the limit will be lowered to 1,000. That should be fine
# for testing anyhow.
MAX = 1000
MYSQL = ['VARCHAR(%s)' % MAX]
POSTGRESQL = MYSQL + ['CHARACTER VARYING']
class Timestamp(DataType):
MYSQL = ['DATETIME']
POSTGRESQL = ['TIMESTAMP WITHOUT TIME ZONE']
NUMBER_TYPES = [Int, TinyInt, SmallInt, BigInt, Float, Double]
TYPES = NUMBER_TYPES + [Boolean, String, Timestamp]
class Func(ValExpr):
'''Base class for funcs'''
def __init__(self, *args):
self.args = list(args)
def __hash__(self):
return hash(type(self)) + hash(tuple(self.args))
def __eq__(self, other):
if not isinstance(other, type(self)):
return False
if self is other:
return True
return self.args == other.args
class UnaryFunc(Func):
def __init__(self, arg):
Func.__init__(self, arg)
class BinaryFunc(Func):
def __init__(self, left, right):
Func.__init__(self, left, right)
@property
def left(self):
return self.args[0]
@left.setter
def left(self, left):
self.args[0] = left
@property
def right(self):
return self.args[1]
@right.setter
def right(self, right):
self.args[1] = right
class BooleanFunc(Func):
@property
def type(self):
return Boolean
class IntFunc(Func):
@property
def type(self):
return Int
class DoubleFunc(Func):
@property
def type(self):
return Double
class StringFunc(Func):
@property
def type(self):
return String
class UpcastingFunc(Func):
@property
def type(self):
return max(arg.type for arg in self.args)
class AggFunc(Func):
# Avoid having a self.distinct because it would need to be __init__'d explicitly,
# which none of the AggFunc subclasses do (ex: Avg doesn't have its
# own __init__).
@property
def distinct(self):
return getattr(self, '_distinct', False)
@distinct.setter
def distinct(self, val):
setattr(self, '_distinct', val)
# The classes below diverge from above by including the SQL representation. It's a lot
# easier this way because there are a lot of funcs but they all have the same
# structure. Non-standard funcs, such as string concatenation, would need to have
# their representation information elsewhere (like classes above).
Parentheses = type('Parentheses', (UnaryFunc, UpcastingFunc), {'FORMAT': '({})'})
IsNull = type('IsNull', (UnaryFunc, BooleanFunc), {'FORMAT': '{} IS NULL'})
IsNotNull = type('IsNotNull', (UnaryFunc, BooleanFunc), {'FORMAT': '{} IS NOT NULL'})
And = type('And', (BinaryFunc, BooleanFunc), {'FORMAT': '{} AND {}'})
Or = type('Or', (BinaryFunc, BooleanFunc), {'FORMAT': '{} OR {}'})
Equals = type('Equals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} = {}'})
NotEquals = type('NotEquals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} != {}'})
GreaterThan = type('GreaterThan', (BinaryFunc, BooleanFunc), {'FORMAT': '{} > {}'})
LessThan = type('LessThan', (BinaryFunc, BooleanFunc), {'FORMAT': '{} < {}'})
GreaterThanOrEquals = type(
'GreaterThanOrEquals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} >= {}'})
LessThanOrEquals = type(
'LessThanOrEquals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} <= {}'})
Plus = type('Plus', (BinaryFunc, UpcastingFunc), {'FORMAT': '{} + {}'})
Minus = type('Minus', (BinaryFunc, UpcastingFunc), {'FORMAT': '{} - {}'})
Multiply = type('Multiply', (BinaryFunc, UpcastingFunc), {'FORMAT': '{} * {}'})
Divide = type('Divide', (BinaryFunc, DoubleFunc), {'FORMAT': '{} / {}'})
Floor = type('Floor', (UnaryFunc, IntFunc), {'FORMAT': 'FLOOR({})'})
Concat = type('Concat', (BinaryFunc, StringFunc), {'FORMAT': 'CONCAT({}, {})'})
Length = type('Length', (UnaryFunc, IntFunc), {'FORMAT': 'LENGTH({})'})
ExtractYear = type(
'ExtractYear', (UnaryFunc, IntFunc), {'FORMAT': "EXTRACT('YEAR' FROM {})"})
# Formatting of agg funcs is a little trickier since they may have a distinct
Avg = type('Avg', (UnaryFunc, DoubleFunc, AggFunc), {})
Count = type('Count', (UnaryFunc, IntFunc, AggFunc), {})
Max = type('Max', (UnaryFunc, UpcastingFunc, AggFunc), {})
Min = type('Min', (UnaryFunc, UpcastingFunc, AggFunc), {})
Sum = type('Sum', (UnaryFunc, UpcastingFunc, AggFunc), {})
UNARY_BOOLEAN_FUNCS = [IsNull, IsNotNull]
BINARY_BOOLEAN_FUNCS = [And, Or]
RELATIONAL_OPERATORS = [
Equals, NotEquals, GreaterThan, LessThan, GreaterThanOrEquals, LessThanOrEquals]
MATH_OPERATORS = [Plus, Minus, Multiply] # Leaving out Divide
BINARY_STRING_FUNCS = [Concat]
AGG_FUNCS = [Avg, Count, Max, Min, Sum]
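# A quick sketch of how the FORMAT attribute is consumed (the writer lives in
# model_translator.py): given hypothetical exprs "x" and "y", Plus(x, y)
# renders via Plus.FORMAT.format(sql_of_x, sql_of_y), producing "<x> + <y>".
# Only the agg funcs skip FORMAT because DISTINCT handling needs special
# casing in the writer.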
class If(Func):
FORMAT = 'CASE WHEN {} THEN {} ELSE {} END'
def __init__(self, boolean_expr, consequent_expr, alternative_expr):
Func.__init__(
self, boolean_expr, consequent_expr, alternative_expr)
@property
def boolean_expr(self):
return self.args[0]
@property
def consequent_expr(self):
return self.args[1]
@property
def alternative_expr(self):
return self.args[2]
@property
def type(self):
return max(self.consequent_expr.type, self.alternative_expr.type)
class Greatest(BinaryFunc, UpcastingFunc, If):
def __init__(self, left, right):
BinaryFunc.__init__(self, left, right)
If.__init__(self, GreaterThan(left, right), left, right)
@property
def type(self):
return UpcastingFunc.type.fget(self)
class Cast(Func):
FORMAT = 'CAST({} AS {})'
def __init__(self, val_expr, resulting_type):
if resulting_type not in TYPES:
raise Exception('Unexpected type: {}'.format(resulting_type))
Func.__init__(self, val_expr, resulting_type)
@property
def val_expr(self):
return self.args[0]
@property
def resulting_type(self):
return self.args[1]
@property
def type(self):
return self.resulting_type

tests/comparison/model_translator.py Normal file

@@ -0,0 +1,367 @@
# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from inspect import getmro
from logging import getLogger
from re import sub
from sqlparse import format
from tests.comparison.model import (
Boolean,
Float,
Int,
Number,
Query,
String,
Timestamp)
LOG = getLogger(__name__)
class SqlWriter(object):
'''Subclasses of SqlWriter take a Query and provide the SQL representation for a
specific database such as Impala or MySQL. The SqlWriter.create([dialect=])
factory method may be used instead of specifying the concrete class.
Another important function of this class is to ensure that CASTs produce the same
results across different databases. Sometimes the CASTs implemented here produce odd
results. For example, the result of CAST(date_col AS INT) in MySQL may be an int of
YYYYMMDD whereas in Impala it may be seconds since the epoch. For comparison purposes
the CAST could be transformed into EXTRACT(DAY from date_col).
'''
@staticmethod
def create(dialect='impala'):
'''Create and return a new SqlWriter appropriate for the given sql dialect. "dialect"
refers to database-specific deviations of SQL; the val should be one of
"IMPALA", "MYSQL", or "POSTGRESQL".
'''
dialect = dialect.upper()
if dialect == 'IMPALA':
return SqlWriter()
if dialect == 'POSTGRESQL':
return PostgresqlSqlWriter()
if dialect == 'MYSQL':
return MySQLSqlWriter()
raise Exception('Unknown dialect: %s' % dialect)
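# Typical usage (a sketch; "query" is assumed to come from the query
# generator):
#   writer = SqlWriter.create(dialect='mysql')
#   sql = writer.write_query(query, pretty=True)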
def write_query(self, query, pretty=False):
'''Return SQL as a string for the given query.'''
sql = list()
# Write out each section in the proper order
for clause in (
query.with_clause,
query.select_clause,
query.from_clause,
query.where_clause,
query.group_by_clause,
query.having_clause,
query.union_clause):
if clause:
sql.append(self._write(clause))
sql = '\n'.join(sql)
if pretty:
sql = self.make_pretty_sql(sql)
return sql
def make_pretty_sql(self, sql):
try:
sql = format(sql, reindent=True)
except Exception as e:
LOG.warn('Unable to format sql: %s', e)
return sql
def _write_with_clause(self, with_clause):
return 'WITH ' + ',\n'.join('%s AS (%s)' % (view.identifier, self._write(view.query))
for view in with_clause.with_clause_inline_views)
def _write_select_clause(self, select_clause):
items = select_clause.non_agg_items + select_clause.agg_items
sql = 'SELECT'
if select_clause.distinct:
sql += ' DISTINCT'
sql += '\n' + ',\n'.join(self._write(item) for item in items)
return sql
def _write_select_item(self, select_item):
# If the query is nested, the items will have aliases so that the outer query can
# easily reference them.
if not select_item.alias:
raise Exception('An alias is required')
return '%s AS %s' % (self._write(select_item.val_expr), select_item.alias)
def _write_column(self, col):
return '%s.%s' % (col.owner.identifier, col.name)
def _write_from_clause(self, from_clause):
sql = 'FROM %s' % self._write(from_clause.table_expr)
if from_clause.join_clauses:
sql += '\n' + '\n'.join(self._write(join) for join in from_clause.join_clauses)
return sql
def _write_table(self, table):
if table.alias:
return '%s AS %s' % (table.name, table.identifier)
return table.name
def _write_inline_view(self, inline_view):
if not inline_view.identifier:
raise Exception('An inline view requires an identifier')
return '(\n%s\n) AS %s' % (self._write(inline_view.query), inline_view.identifier)
def _write_with_clause_inline_view(self, with_clause_inline_view):
if not with_clause_inline_view.with_clause_alias:
raise Exception('A with clause entry requires an identifier')
sql = with_clause_inline_view.with_clause_alias
if with_clause_inline_view.alias:
sql += ' AS ' + with_clause_inline_view.alias
return sql
def _write_join_clause(self, join_clause):
sql = '%s JOIN %s' % (join_clause.join_type, self._write(join_clause.table_expr))
if join_clause.boolean_expr:
sql += ' ON ' + self._write(join_clause.boolean_expr)
return sql
def _write_where_clause(self, where_clause):
return 'WHERE\n' + self._write(where_clause.boolean_expr)
def _write_group_by_clause(self, group_by_clause):
return 'GROUP BY\n' + ',\n'.join(self._write(item.val_expr)
for item in group_by_clause.group_by_items)
def _write_having_clause(self, having_clause):
return 'HAVING\n' + self._write(having_clause.boolean_expr)
def _write_union_clause(self, union_clause):
sql = 'UNION'
if union_clause.all:
sql += ' ALL'
sql += '\n' + self._write(union_clause.query)
return sql
def _write_data_type(self, data_type):
'''Write a literal value.'''
if data_type.returns_string:
return "'{}'".format(data_type.val)
if data_type.returns_timestamp:
return "CAST('{}' AS TIMESTAMP)".format(data_type.val)
return str(data_type.val)
def _write_func(self, func):
return func.FORMAT.format(*[self._write(arg) for arg in func.args])
def _write_cast(self, cast):
# Handle casts that produce different results across database types or just don't
# make sense, such as casting a DATE as a BOOLEAN.
if cast.val_expr.returns_boolean:
if issubclass(cast.resulting_type, Timestamp):
return "CAST(CASE WHEN {} THEN '2000-01-01' ELSE '1999-01-01' END AS TIMESTAMP)"\
.format(self._write(cast.val_expr))
elif cast.val_expr.returns_number:
if issubclass(cast.resulting_type, Timestamp):
return ("CAST(CONCAT('2000-01-', "
"LPAD(CAST(ABS(FLOOR({})) % 31 + 1 AS STRING), 2, '0')) "
"AS TIMESTAMP)").format(self._write(cast.val_expr))
elif cast.val_expr.returns_string:
if issubclass(cast.resulting_type, Boolean):
return "(LENGTH({}) > 2)".format(self._write(cast.val_expr))
if issubclass(cast.resulting_type, Timestamp):
return ("CAST(CONCAT('2000-01-', LPAD(CAST(LENGTH({}) % 31 + 1 AS STRING), "
"2, '0')) AS TIMESTAMP)").format(self._write(cast.val_expr))
elif cast.val_expr.returns_timestamp:
if issubclass(cast.resulting_type, Boolean):
return '(DAY({0}) > MONTH({0}))'.format(self._write(cast.val_expr))
if issubclass(cast.resulting_type, Number):
return ('(DAY({0}) + 100 * MONTH({0}) + 100 * 100 * YEAR({0}))').format(
self._write(cast.val_expr))
return self._write_func(cast)
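# A concrete sketch of the normalization above: for a hypothetical int column
# t1.int_col_1, Cast(t1.int_col_1, Timestamp) is written as
#   CAST(CONCAT('2000-01-', LPAD(CAST(ABS(FLOOR(t1.int_col_1)) % 31 + 1
#     AS STRING), 2, '0')) AS TIMESTAMP)
# so every database maps the number onto the same deterministic date in
# January 2000.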
def _write_agg_func(self, agg_func):
sql = type(agg_func).__name__.upper() + '('
if agg_func.distinct:
sql += 'DISTINCT '
# All agg funcs only have a single arg
sql += self._write(agg_func.args[0]) + ')'
return sql
def _write_data_type_metaclass(self, data_type_class):
'''Write a data type class such as Int or Boolean.'''
return data_type_class.__name__.upper()
def _write(self, object_):
'''Return a sql string representation of the given object.'''
# What's below is effectively a giant switch statement. It works based on a func
# naming and signature convention. It should match the incoming object with the
# corresponding func defined, then call the func and return the result.
#
# Ex:
# a = model.And(...)
# _write(a) should call _write_func(a) because "And" is a subclass of "Func" and no
# other _write_<class name> methods have been defined higher up the method
# resolution order (MRO). If _write_and(...) were to be defined, it would be called
# instead.
for type_ in getmro(type(object_)):
writer_func_name = '_write' + sub('([A-Z])', r'_\1', type_.__name__).lower()
writer_func = getattr(self, writer_func_name, None)
if writer_func:
return writer_func(object_)
# Handle any remaining cases
if isinstance(object_, Query):
return self.write_query(object_)
raise Exception('Unsupported object: %s<%s>' % (type(object_).__name__, object_))
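# The name mangling above is plain camel-case splitting. For example,
#   sub('([A-Z])', r'_\1', 'GreaterThanOrEquals').lower()
# yields '_greater_than_or_equals', so an instance of GreaterThanOrEquals
# dispatches to _write_func via its Func base class unless a more specific
# _write_greater_than_or_equals method is defined.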
class PostgresqlSqlWriter(SqlWriter):
# TODO: This class is out of date since switching to MySQL. This is left here as is
# in case there is a desire to switch back in the future (it should be better than
# starting from nothing).
def _write_divide(self, divide):
# For ints, Postgresql does int division but Impala does float division.
return 'CAST({} AS REAL) / {}' \
.format(*[self._write(arg) for arg in divide.args])
def _write_data_type_metaclass(self, data_type_class):
'''Write a data type class such as Int or Boolean.'''
if hasattr(data_type_class, 'POSTGRESQL'):
return data_type_class.POSTGRESQL[0]
return data_type_class.__name__.upper()
def _write_cast(self, cast):
# Handle casts that produce different results across database types or just don't
# make sense, such as casting a DATE as a BOOLEAN.
if cast.val_expr.returns_boolean:
if issubclass(cast.resulting_type, Float):
return "CASE {} WHEN TRUE THEN 1.0 WHEN FALSE THEN 0.0 END".format(
self._write(cast.val_expr))
if issubclass(cast.resulting_type, Timestamp):
return "CASE WHEN {} THEN '2000-01-01' ELSE '1999-01-01' END".format(
self._write(cast.val_expr))
if issubclass(cast.resulting_type, String):
return "CASE {} WHEN TRUE THEN '1' WHEN FALSE THEN '0' END".format(
self._write(cast.val_expr))
elif cast.val_expr.returns_number:
if issubclass(cast.resulting_type, Boolean):
return 'CASE WHEN ({0}) != 0 THEN TRUE WHEN ({0}) = 0 THEN FALSE END'.format(
self._write(cast.val_expr))
if issubclass(cast.resulting_type, Timestamp):
return "CASE WHEN ({}) > 0 THEN '2000-01-01' ELSE '1999-01-01' END".format(
self._write(cast.val_expr))
elif cast.val_expr.returns_string:
if issubclass(cast.resulting_type, Boolean):
return "(LENGTH({}) > 2)".format(self._write(cast.val_expr))
elif cast.val_expr.returns_timestamp:
if issubclass(cast.resulting_type, Boolean):
return '(EXTRACT(DAY FROM {0}) > EXTRACT(MONTH FROM {0}))'.format(
self._write(cast.val_expr))
if issubclass(cast.resulting_type, Number):
return ('(EXTRACT(DAY FROM {0}) '
'+ 100 * EXTRACT(MONTH FROM {0}) '
'+ 100 * 100 * EXTRACT(YEAR FROM {0}))').format(
self._write(cast.val_expr))
return self._write_func(cast)
class MySQLSqlWriter(SqlWriter):
def write_query(self, query, pretty=False):
# MySQL doesn't support WITH clauses, so they need to be converted into inline views.
# We cheat by relying on the fact that the query generator creates with clause
# entries with unique aliases, even across nested queries.
sql = list()
for clause in (
query.select_clause,
query.from_clause,
query.where_clause,
query.group_by_clause,
query.having_clause,
query.union_clause):
if clause:
sql.append(self._write(clause))
sql = '\n'.join(sql)
if query.with_clause:
# Just replace the named references with inline views. Go in reverse order because
# entries at the bottom of the WITH clause definition may reference entries above.
for with_clause_inline_view in reversed(query.with_clause.with_clause_inline_views):
replacement_sql = '(' + self.write_query(with_clause_inline_view.query) + ')'
sql = sql.replace(with_clause_inline_view.identifier, replacement_sql)
if pretty:
sql = self.make_pretty_sql(sql)
return sql
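# Rewrite sketch (aliases are illustrative): given
#   WITH with_1_42 AS (SELECT ... FROM foo) SELECT * FROM with_1_42 AS t1
# the loop above drops the WITH clause and emits
#   SELECT * FROM (SELECT ... FROM foo) AS t1
# which only works because the generated aliases are globally unique.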
def _write_data_type_metaclass(self, data_type_class):
'''Write a data type class such as Int or Boolean.'''
if issubclass(data_type_class, Int):
return 'INTEGER'
if issubclass(data_type_class, Float):
return 'DECIMAL(65, 15)'
if issubclass(data_type_class, String):
return 'CHAR'
if hasattr(data_type_class, 'MYSQL'):
return data_type_class.MYSQL[0]
return data_type_class.__name__.upper()
def _write_data_type(self, data_type):
'''Write a literal value.'''
if data_type.returns_timestamp:
return "CAST('{}' AS DATETIME)".format(data_type.val)
if data_type.returns_boolean:
# MySQL will error if the literal "FALSE" is used as a GROUP BY field
return '(0 = 0)' if data_type.val else '(1 = 0)'
return SqlWriter._write_data_type(self, data_type)
def _write_cast(self, cast):
# Handle casts that produce different results across database types or just don't
# make sense, such as casting a DATE as a BOOLEAN.
if cast.val_expr.returns_boolean:
if issubclass(cast.resulting_type, Timestamp):
return "CAST(CASE WHEN {} THEN '2000-01-01' ELSE '1999-01-01' END AS DATETIME)"\
.format(self._write(cast.val_expr))
elif cast.val_expr.returns_number:
if issubclass(cast.resulting_type, Boolean):
return ("CASE WHEN ({0}) != 0 THEN TRUE WHEN ({0}) = 0 THEN FALSE END").format(
self._write(cast.val_expr))
if issubclass(cast.resulting_type, Timestamp):
return "CAST(CONCAT('2000-01-', ABS(FLOOR({})) % 31 + 1) AS DATETIME)"\
.format(self._write(cast.val_expr))
elif cast.val_expr.returns_string:
if issubclass(cast.resulting_type, Boolean):
return "(LENGTH({}) > 2)".format(self._write(cast.val_expr))
if issubclass(cast.resulting_type, Timestamp):
return ("CAST(CONCAT('2000-01-', LENGTH({}) % 31 + 1) AS DATETIME)").format(
self._write(cast.val_expr))
elif cast.val_expr.returns_timestamp:
if issubclass(cast.resulting_type, Number):
return ('(EXTRACT(DAY FROM {0}) '
'+ 100 * EXTRACT(MONTH FROM {0}) '
'+ 100 * 100 * EXTRACT(YEAR FROM {0}))').format(
self._write(cast.val_expr))
if issubclass(cast.resulting_type, Boolean):
return '(EXTRACT(DAY FROM {0}) > EXTRACT(MONTH FROM {0}))'.format(
self._write(cast.val_expr))
# MySQL uses different type names when casting...
if issubclass(cast.resulting_type, Boolean):
data_type = 'UNSIGNED'
elif issubclass(cast.resulting_type, Float):
data_type = 'DECIMAL(65, 15)'
elif issubclass(cast.resulting_type, Int):
data_type = 'SIGNED'
elif issubclass(cast.resulting_type, String):
data_type = 'CHAR'
elif issubclass(cast.resulting_type, Timestamp):
data_type = 'DATETIME'
return cast.FORMAT.format(self._write(cast.val_expr), data_type)

tests/comparison/query_generator.py Normal file

@@ -0,0 +1,556 @@
# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from copy import deepcopy
from itertools import chain
from random import choice, randint, shuffle
from tests.comparison.model import (
AGG_FUNCS,
AggFunc,
And,
BINARY_STRING_FUNCS,
BigInt,
Boolean,
Cast,
Column,
Count,
DataType,
Double,
Equals,
Float,
Floor,
FromClause,
Func,
Greatest,
GroupByClause,
HavingClause,
InlineView,
Int,
JoinClause,
Length,
MATH_OPERATORS,
Number,
Query,
RELATIONAL_OPERATORS,
SelectClause,
SelectItem,
String,
Table,
Timestamp,
TYPES,
UNARY_BOOLEAN_FUNCS,
UnionClause,
WhereClause,
WithClause,
WithClauseInlineView)
def random_boolean():
'''Return a val that evaluates to True 50% of the time'''
return randint(0, 1)
def zero_or_more():
'''The chance that the return val is n is 1 / 2 ^ (n + 1)'''
val = 0
while random_boolean():
val += 1
return val
def one_or_more():
return zero_or_more() + 1
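# Distribution sketch: zero_or_more() keeps flipping a fair coin, so it
# returns 0 with probability 1/2, 1 with probability 1/4, 2 with probability
# 1/8, and so on; one_or_more() is the same geometric distribution shifted up
# by one.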
def random_non_empty_split(iterable):
'''Return two non-empty lists'''
if len(iterable) < 2:
raise Exception('The iterable must contain at least two items')
split_index = randint(1, len(iterable) - 1)
left, right = list(), list()
for idx, item in enumerate(iterable):
if idx < split_index:
left.append(item)
else:
right.append(item)
return left, right
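# Ex: random_non_empty_split([1, 2, 3]) returns one of ([1], [2, 3]) or
# ([1, 2], [3]), each split point being equally likely.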
class QueryGenerator(object):
def create_query(self,
table_exprs,
allow_with_clause=True,
select_item_data_types=None):
'''Create a random query using various language features.
The initial call to this method should only use tables in the table_exprs
parameter, and not inline views or "with" definitions. The other types of
table exprs may be added as part of the query generation.
If select_item_data_types is specified it must be a sequence or iterable of
DataType. The generated query.select_clause.select_items will have data
types suitable for use in a UNION.
'''
# Make a copy so tables can be added if a "with" clause is used
table_exprs = list(table_exprs)
with_clause = None
if allow_with_clause and randint(1, 10) == 1:
with_clause = self._create_with_clause(table_exprs)
table_exprs.extend(with_clause.table_exprs)
from_clause = self._create_from_clause(table_exprs)
select_clause = self._create_select_clause(
from_clause.table_exprs,
select_item_data_types=select_item_data_types)
query = Query(select_clause, from_clause)
if with_clause:
query.with_clause = with_clause
if random_boolean():
query.where_clause = self._create_where_clause(from_clause.table_exprs)
if select_clause.agg_items and select_clause.non_agg_items:
query.group_by_clause = GroupByClause(list(select_clause.non_agg_items))
if randint(1, 10) == 1:
if select_clause.agg_items:
self._enable_distinct_on_random_agg_items(select_clause.agg_items)
else:
select_clause.distinct = True
if random_boolean() and (query.group_by_clause or select_clause.agg_items):
query.having_clause = self._create_having_clause(from_clause.table_exprs)
if randint(1, 10) == 1:
select_item_data_types = list()
for select_item in select_clause.select_items:
# For numbers, choose the largest possible data type in case a CAST is needed.
if select_item.val_expr.returns_float:
select_item_data_types.append(Double)
elif select_item.val_expr.returns_int:
select_item_data_types.append(BigInt)
else:
select_item_data_types.append(select_item.val_expr.type)
query.union_clause = UnionClause(self.create_query(
table_exprs,
allow_with_clause=False,
select_item_data_types=select_item_data_types))
query.union_clause.all = random_boolean()
return query
def _create_with_clause(self, table_exprs):
# Make a copy so newly created tables can be added and made available for use in
# future table definitions.
table_exprs = list(table_exprs)
with_clause_inline_views = list()
for with_clause_inline_view_idx in xrange(one_or_more()):
query = self.create_query(table_exprs)
# To help prevent nested WITH clauses from having entries with the same alias,
# choose a random alias. Of course it would be much better to know which aliases
# were already chosen but that information isn't easy to get from here.
with_clause_alias = 'with_%s_%s' % \
(with_clause_inline_view_idx + 1, randint(1, 1000))
with_clause_inline_view = WithClauseInlineView(query, with_clause_alias)
table_exprs.append(with_clause_inline_view)
with_clause_inline_views.append(with_clause_inline_view)
return WithClause(with_clause_inline_views)
def _create_select_clause(self, table_exprs, select_item_data_types=None):
while True:
non_agg_items = [self._create_non_agg_select_item(table_exprs)
for _ in xrange(zero_or_more())]
agg_items = [self._create_agg_select_item(table_exprs)
for _ in xrange(zero_or_more())]
if non_agg_items or agg_items:
if select_item_data_types:
if len(select_item_data_types) > len(non_agg_items) + len(agg_items):
# Not enough items generated, try again
continue
while len(select_item_data_types) < len(non_agg_items) + len(agg_items):
items = choice([non_agg_items, agg_items])
if items:
items.pop()
for data_type_idx, data_type in enumerate(select_item_data_types):
if data_type_idx < len(non_agg_items):
item = non_agg_items[data_type_idx]
else:
item = agg_items[data_type_idx - len(non_agg_items)]
if not issubclass(item.type, data_type):
item.val_expr = self.convert_val_expr_to_type(item.val_expr, data_type)
for idx, item in enumerate(chain(non_agg_items, agg_items)):
item.alias = '%s_col_%s' % (item.type.__name__.lower(), idx + 1)
return SelectClause(non_agg_items=non_agg_items, agg_items=agg_items)
def _choose_col(self, table_exprs):
table_expr = choice(table_exprs)
return choice(table_expr.cols)
def _create_non_agg_select_item(self, table_exprs):
return SelectItem(self._create_val_expr(table_exprs))
def _create_val_expr(self, table_exprs):
vals = [self._choose_col(table_exprs) for _ in xrange(one_or_more())]
return self._combine_val_exprs(vals)
def _create_agg_select_item(self, table_exprs):
vals = [self._create_agg_val_expr(table_exprs) for _ in xrange(one_or_more())]
return SelectItem(self._combine_val_exprs(vals))
def _create_agg_val_expr(self, table_exprs):
val = self._create_val_expr(table_exprs)
if issubclass(val.type, Number):
funcs = list(AGG_FUNCS)
else:
funcs = [Count]
return choice(funcs)(val)
def _create_from_clause(self, table_exprs):
table_expr = self._create_table_expr(table_exprs)
table_expr_count = 1
table_expr.alias = 't%s' % table_expr_count
from_clause = FromClause(table_expr)
for join_idx in xrange(zero_or_more()):
join_clause = self._create_join_clause(from_clause, table_exprs)
table_expr_count += 1
join_clause.table_expr.alias = 't%s' % table_expr_count
from_clause.join_clauses.append(join_clause)
return from_clause
def _create_table_expr(self, table_exprs):
if randint(1, 10) == 1:
return self._create_inline_view(table_exprs)
return self._choose_table(table_exprs)
def _choose_table(self, table_exprs):
return deepcopy(choice(table_exprs))
def _create_inline_view(self, table_exprs):
return InlineView(self.create_query(table_exprs))
def _create_join_clause(self, from_clause, table_exprs):
table_expr = self._create_table_expr(table_exprs)
# Increase the chance of using the first join type, which is INNER
join_type_idx = (zero_or_more() / 2) % len(JoinClause.JOIN_TYPES)
join_type = JoinClause.JOIN_TYPES[join_type_idx]
join_clause = JoinClause(join_type, table_expr)
# Prefer non-boolean cols for the first condition. Boolean cols produce too
# many results so it's unlikely that someone would want to join tables only using
# boolean cols.
non_boolean_types = set(type_ for type_ in TYPES if not issubclass(type_, Boolean))
if join_type != 'CROSS':
join_clause.boolean_expr = self._combine_val_exprs(
[self._create_relational_join_condition(
table_expr,
choice(from_clause.table_exprs),
preferred_data_types=(non_boolean_types if idx == 0 else set()))
for idx in xrange(one_or_more())],
resulting_type=Boolean)
return join_clause
def _create_relational_join_condition(self,
left_table_expr,
right_table_expr,
preferred_data_types):
# "base type" means condense all int types into just int, same for floats
left_cols_by_base_type = left_table_expr.cols_by_base_type
right_cols_by_base_type = right_table_expr.cols_by_base_type
common_col_types = set(left_cols_by_base_type) & set(right_cols_by_base_type)
if preferred_data_types:
common_col_types &= preferred_data_types
if common_col_types:
col_type = choice(list(common_col_types))
left = choice(left_cols_by_base_type[col_type])
right = choice(right_cols_by_base_type[col_type])
else:
col_type = None
if preferred_data_types:
for available_col_types in (left_cols_by_base_type, right_cols_by_base_type):
preferred_available_col_types = set(available_col_types) & preferred_data_types
if preferred_available_col_types:
col_type = choice(list(preferred_available_col_types))
break
if not col_type:
col_type = choice(left_cols_by_base_type.keys())
if col_type in left_cols_by_base_type:
left = choice(left_cols_by_base_type[col_type])
else:
left = choice(choice(left_cols_by_base_type.values()))
left = self.convert_val_expr_to_type(left, col_type)
if col_type in right_cols_by_base_type:
right = choice(right_cols_by_base_type[col_type])
else:
right = choice(choice(right_cols_by_base_type.values()))
right = self.convert_val_expr_to_type(right, col_type)
return Equals(left, right)
def _create_where_clause(self, table_exprs):
boolean_exprs = list()
# Create one boolean expr per iteration...
for _ in xrange(one_or_more()):
col_type = None
cols = list()
# ...using one or more cols...
for _ in xrange(one_or_more()):
# ...from any random table, inline view, etc.
table_expr = choice(table_exprs)
if not col_type:
col_type = choice(list(table_expr.cols_by_base_type))
if col_type in table_expr.cols_by_base_type:
col = choice(table_expr.cols_by_base_type[col_type])
else:
col = choice(table_expr.cols)
cols.append(col)
boolean_exprs.append(self._combine_val_exprs(cols, resulting_type=Boolean))
return WhereClause(self._combine_val_exprs(boolean_exprs))
def _combine_val_exprs(self, vals, resulting_type=None):
'''Combine the given vals into a single val.
If resulting_type is specified, the returned val will be of that type. If
the resulting data type was not specified, it will be randomly chosen from the
types of the input vals.
'''
if not vals:
raise Exception('At least one val is required')
types_to_vals = DataType.group_by_base_type(vals)
if not resulting_type:
resulting_type = choice(types_to_vals.keys())
vals_of_resulting_type = list()
for val_type, vals in types_to_vals.iteritems():
if issubclass(val_type, resulting_type):
vals_of_resulting_type.extend(vals)
elif resulting_type == Boolean:
# To produce other result types, the vals will be combined into a single val
# then converted into the desired type. However, to make a boolean, relational
# operators can be used on the vals to make a more realistic query.
val = self._create_boolean_expr_from_vals_of_same_type(vals)
vals_of_resulting_type.append(val)
else:
val = self._combine_vals_of_same_type(vals)
if not (issubclass(val.type, Number) and issubclass(resulting_type, Number)):
val = self.convert_val_expr_to_type(val, resulting_type)
vals_of_resulting_type.append(val)
return self._combine_vals_of_same_type(vals_of_resulting_type)
def _create_boolean_expr_from_vals_of_same_type(self, vals):
if not vals:
raise Exception('At least one val is required')
if len(vals) == 1:
val = vals[0]
if Boolean == val.type:
return val
# Convert a single non-boolean val into a boolean using a func like
# IsNull or IsNotNull.
return choice(UNARY_BOOLEAN_FUNCS)(val)
if len(vals) == 2:
left, right = vals
if left.type == right.type:
if left.type == String:
# Databases may vary in how string comparisons are done. Results may differ
# when using operators like > or <, so just always use =.
return Equals(left, right)
if left.type == Boolean:
# TODO: Enable "OR" at some frequency, using OR at 50% will probably produce
# too many slow queries.
return And(left, right)
# At this point we've got two data points of the same type so any relational
# operator is valid and will produce a boolean.
return choice(RELATIONAL_OPERATORS)(left, right)
elif issubclass(left.type, Number) and issubclass(right.type, Number):
# Numbers need not be of the same type. SmallInt, BigInt, etc can all be compared.
# Note: For now ints are the only numbers enabled and division is disabled
# though AVG() is in use. If floats are enabled this will likely need to be
# updated to do some rounding based comparison.
return choice(RELATIONAL_OPERATORS)(left, right)
raise Exception('Vals are not of the same type: %s<%s> vs %s<%s>'
% (left, left.type, right, right.type))
# Reduce the number of inputs and try again...
left_subset, right_subset = random_non_empty_split(vals)
return self._create_boolean_expr_from_vals_of_same_type([
self._combine_vals_of_same_type(left_subset),
self._combine_vals_of_same_type(right_subset)])
def _combine_vals_of_same_type(self, vals):
'''Combine the given vals into a single expr of the same type. The input
vals must be of the same base data type. For example Ints must not be mixed
with Strings.
'''
if not vals:
raise Exception('At least one val is required')
val_type = None
for val in vals:
if not val_type:
if issubclass(val.type, Number):
val_type = Number
else:
val_type = val.type
elif not issubclass(val.type, val_type):
raise Exception('Incompatible types %s and %s' % (val_type, val.type))
if len(vals) == 1:
return vals[0]
if val_type == Number:
funcs = MATH_OPERATORS
elif val_type == Boolean:
# TODO: Enable "OR" at some frequency
funcs = [And]
elif val_type == String:
# TODO: Combining strings (Concat is available in BINARY_STRING_FUNCS) is
#       disabled for now; just use the first val.
return vals[0]
elif val_type == Timestamp:
funcs = [Greatest]
vals = list(vals)
shuffle(vals)
left = vals.pop()
right = vals.pop()
while True:
func = choice(funcs)
left = func(left, right)
if not vals:
return left
right = vals.pop()
def convert_val_expr_to_type(self, val_expr, resulting_type):
if resulting_type not in TYPES:
raise Exception('Unexpected type: {}'.format(resulting_type))
val_type = val_expr.type
if issubclass(val_type, resulting_type):
return val_expr
if issubclass(resulting_type, Int):
if val_expr.returns_float:
# Impala will FLOOR while Postgresql will ROUND. Use FLOOR to be consistent.
return Floor(val_expr)
if issubclass(resulting_type, Number):
if val_expr.returns_string:
return Length(val_expr)
if issubclass(resulting_type, String):
if val_expr.returns_float:
# Different databases may use different precision.
return Cast(Floor(val_expr), resulting_type)
return Cast(val_expr, resulting_type)
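# Conversion sketch (col names are illustrative): a Double col becomes an Int
# via Floor(double_col); a String col becomes a Number via Length(string_col);
# anything else falls through to an explicit Cast, which the SqlWriters
# rewrite per dialect whenever a plain CAST would diverge across databases.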
def _create_having_clause(self, table_exprs):
boolean_exprs = list()
# Create one boolean expr per iteration...
for _ in xrange(one_or_more()):
agg_items = list()
# ...using one or more agg exprs...
for _ in xrange(one_or_more()):
vals = [self._create_agg_val_expr(table_exprs) for _ in xrange(one_or_more())]
agg_items.append(self._combine_val_exprs(vals))
boolean_exprs.append(self._combine_val_exprs(agg_items, resulting_type=Boolean))
return HavingClause(self._combine_val_exprs(boolean_exprs))
def _enable_distinct_on_random_agg_items(self, agg_items):
'''Randomly choose an agg func and set it to use DISTINCT'''
# Impala has a limitation where 'DISTINCT' may only be applied to one agg
# expr. If an agg expr is used more than once, each usage may
# or may not include DISTINCT.
#
# Examples:
# OK: SELECT COUNT(DISTINCT a) + SUM(DISTINCT a) + MAX(a)...
# Not OK: SELECT COUNT(DISTINCT a) + COUNT(DISTINCT b)...
#
# Given a select list like:
# COUNT(a), SUM(a), MAX(b)
#
# We want to output one of:
# COUNT(DISTINCT a), SUM(DISTINCT a), MAX(b)
# COUNT(DISTINCT a), SUM(a), MAX(b)
# COUNT(a), SUM(a), MAX(DISTINCT b)
#
# This will be done by first grouping all agg funcs by their inner
# expr:
# {a: [COUNT(a), SUM(a)],
# b: [MAX(b)]}
#
# then choosing a random val (which is a list of aggs) in the above dict, and
# finally randomly adding DISTINCT to items in the list.
exprs_to_funcs = defaultdict(list)
for item in agg_items:
for expr, funcs in self._group_agg_funcs_by_expr(item.val_expr).iteritems():
exprs_to_funcs[expr].extend(funcs)
funcs = choice(exprs_to_funcs.values())
for func in funcs:
if random_boolean():
func.distinct = True
def _group_agg_funcs_by_expr(self, val_expr):
'''Group exprs and return a dict mapping the expr to the agg items
it is used in.
Example: COUNT(a) * SUM(a) - MAX(b) + MIN(c) -> {a: [COUNT(a), SUM(a)],
b: [MAX(b)],
c: [MIN(c)]}
'''
exprs_to_funcs = defaultdict(list)
if isinstance(val_expr, AggFunc):
exprs_to_funcs[tuple(val_expr.args)].append(val_expr)
elif isinstance(val_expr, Func):
for arg in val_expr.args:
for expr, funcs in self._group_agg_funcs_by_expr(arg).iteritems():
exprs_to_funcs[expr].extend(funcs)
# else: The remaining case could happen if the original expr was something like
# "SUM(a) + b + 1" where b is a GROUP BY field.
return exprs_to_funcs
if __name__ == '__main__':
'''Generate some queries for manual inspection. The queries won't run anywhere
because the tables used are fake. To make real queries, we'd need to connect to a
database and read the table metadata and such.
'''
tables = list()
data_types = list(TYPES)  # copy so the module-level TYPES list isn't mutated
data_types.remove(Float)
data_types.remove(Double)
for table_idx in xrange(5):
table = Table('table_%s' % table_idx)
tables.append(table)
for col_idx in xrange(3):
col_type = choice(data_types)
col = Column(table, '%s_col_%s' % (col_type.__name__.lower(), col_idx), col_type)
table.cols.append(col)
query_generator = QueryGenerator()
from model_translator import SqlWriter
sql_writer = SqlWriter.create()
for _ in range(3000):
query = query_generator.create_query(tables)
print(sql_writer.write_query(query) + '\n')