Testing: Generate queries and compare results against other databases
This is the initial commit and is a work in progress. See the README for a
list of possible improvements.
As an overview of how the files are related:
model.py: This is the base upon which the other files are built. It
contains something like a grammar for queries.
query_generator.py: Generates random permutations of the model.
model_translator.py: Produces SQL based on the model.
discrepancy_searcher.py: Uses the above to generate, run, and compare
query results.
Change-Id: Iaca6277766f5a86568eaa3f05b99c832942ab38b
Reviewed-on: http://gerrit.ent.cloudera.com:8080/1648
Reviewed-by: Casey Ching <casey@cloudera.com>
Tested-by: Casey Ching <casey@cloudera.com>
@@ -14,4 +14,4 @@ DATASRC="a2226.halxg.cloudera.com:/data/1/workspace/impala-data"
 DATADST=$IMPALA_HOME/testdata/impala-data
 mkdir -p $DATADST
 
-scp -i $IMPALA_HOME/ssh_keys/id_rsa_impala -o "StrictHostKeyChecking=no" -r $DATASRC/* $DATADST
+scp -i $HOME/.ssh/id_rsa_jenkins -o "StrictHostKeyChecking=no" -r systest@$DATASRC/* $DATADST
164
tests/comparison/README
Normal file
@@ -0,0 +1,164 @@
Purpose:

This package is intended to augment the standard test suite. The standard tests are
more efficient with regard to features tested versus execution time, but their
coverage still leaves gaps in query coverage. This package provides a random query
generator to compare the results of a wide range of queries against a reference
database engine. The queries range from very simple single-table selects to extremely
complicated queries with multiple levels of nesting. This method of testing is slower
but has a larger coverage area.


Requirements:

1) It's assumed that Impala is running locally.

2) Impyla -- an implementation of DB API 2 for Impala.

   sudo pip install git+http://github.com/laserson/impyla.git#impyla

3) At least one python driver for a reference database.

   sudo apt-get install python-mysqldb    # MySQL
   sudo apt-get install python-psycopg2   # Postgresql


Usage:

1) Generate test data

   ./data_generator.py --use-mysql

   This will generate tables and data in MySQL and Impala.

2) Run the comparison

   ./discrepancy_searcher.py

   This will generate queries using the test database and compare the results against
   MySQL (the default).


Known Issues:

1) Floats will produce false positives. For example, a query that has

   SELECT FLOOR(COUNT(...) * AVG(...)) AS col_1

   will produce different results on Impala and MySQL if COUNT() == 3 and AVG() == 1/3.
   One of the databases will FLOOR(1) while the other will FLOOR(0.999...).

   Maybe this could be worked around or reduced by replacing all uses of AVG() with
   "AVG() + foo", where foo is some number that makes it unlikely that
   "COUNT() * (AVG() + foo)" will result in an int.

   I'd guess this issue comes up in 1 out of 10-20k queries.
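
   To make the failure mode concrete, here is a minimal Python illustration (the
   specific values are hypothetical, not taken from an actual run): two engines
   computing COUNT(...) * AVG(...) can land on either side of an integer boundary,
   and FLOOR() amplifies a ~1e-16 difference into an off-by-one result.

      from math import floor
      print(floor(0.9999999999999999))  # floors to 0
      print(floor(1.0))                 # floors to 1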

2) Impyla may fail with "Invalid query handle". Some queries will fail every time when
   run through Impyla but run fine through the impala-shell. I need to research more
   and file an issue with Impyla.

3) Impyla will fail with "Invalid session". I'm pretty sure this is also an Impyla
   issue but also need to investigate more.


Things to Know:

1) A good number of queries to run seems to be about 5k. Ideally each test run would
   discover the complete list of known issues. From experience, a 1k query test run
   may complete without finding any issues that were discovered in previous runs. 5k
   seems to be about the magic number where most issues will be rediscovered. This can
   take 1-2 hours. However, as of this writing it's rare to run 1k queries without
   finding at least one discrepancy.

2) It's possible to provide a randomization seed so that the randomness is actually
   reproducible. The data generation currently has a default seed, so it will always
   produce the same tables. This also means that if a new data type is added, the
   generated tables will change.
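
   For example, using the flag documented below in data_generator.py, a run can be
   pinned to a specific seed:

      ./data_generator.py --use-mysql --randomization-seed 42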

3) There is a query log. It's possible that a sequence of queries is required to
   expose a bug. If you come across a failure that can't be reproduced by rerunning
   the failed query, try running the queries leading up to that query as well.


Miscellaneous:

1) Instead of generating new random queries with each run, it may be better to reuse a
   list of queries from a previous run that are known to produce results. As of this
   writing, only about 50% of queries produce results, so it may be better to trade
   high randomness for higher quality queries. For example, it would be possible to
   build up a library of 100k queries that produce results, then randomly select 2.5k
   of those. Maybe that would provide testing equivalent to 5k totally random queries
   in less time.

   This would also be useful in eliminating queries that hit the known issues above.


Postgresql:

1) Supports basically all Impala language features.

2) Does int division, 1 / 2 == 0.

3) Has strange sorting of strings, '-1' > '1'. This may be important if ORDER BY is
   ever used. The databases being compared would need to have the same collation,
   which is probably configurable.

4) This was the original reference database, but I moved to MySQL while trying to add
   support for floats and never moved back.


MySQL:

1) Supports basically all Impala language features, except that the WITH clause
   requires emulation with inline views.

2) Has poor boolean support. It may be worth switching back to Postgresql for this.


Improvements:

1) Add support for simplifying buggy queries. When a random query fails the comparison
   check, it is basically always much too complex for directly posting a bug report.
   It is also time consuming to simplify the queries because there is a lot of trial
   and error and manual editing of queries.

2) Add more language features

   a) SEMI JOIN
   b) ORDER BY
   c) LIMIT, OFFSET
   d) CASE / WHEN

3) Add common built-in functions. Ex: CAST, IF, NVL, ...

4) Make randomization of the query generation configurable. As of this writing, all
   the probabilities are hard-coded. At a very minimum it should be easy to disable or
   force the use of some language features such as CROSS JOIN, GROUP BY, etc.

5) More investigation of using the existing "functional" test datasets. A very quick
   trial run wasn't successful, but another attempt with more effort should be made
   before introducing a new dataset.

   I suspect the problem with using the functional dataset was that I only imported a
   few tables, maybe alltypes, alltypesagg, and something else. I don't think I
   imported the tiny tables, since the odds of them producing results from a random
   query would be very low.

6) If the functional dataset cannot be used, someone should think more about what the
   random data should be like. Only a few minutes of thought were put into selecting
   random value ranges (including number of tables and columns), and it's not clear
   how important those ranges are.

7) Add support for comparing results with codegen enabled and disabled. Uri recently
   added support for query options in Impyla.

8) Consider adding Oracle or SQL Server support; these could be useful in the future
   for analytic queries.

9) Try running with tables in various formats. Ex: parquet and/or avro.

10) Support more data types. Only int types are known to give good results. Floats
    may work, but non-numeric types are not supported yet.
0
tests/comparison/__init__.py
Normal file
409
tests/comparison/data_generator.py
Executable file
@@ -0,0 +1,409 @@
#!/usr/bin/env python

# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''This module provides random data generation and database population.

   When this module is run directly for purposes of database population, the default
   is to use a fixed seed for randomization. The result should be that the generated
   random data is the same regardless of when or where the execution is done.

'''

from datetime import datetime, timedelta
from logging import basicConfig, getLogger
from random import choice, randint, random, seed, uniform

from tests.comparison.db_connector import (
    DbConnector,
    IMPALA,
    MYSQL,
    POSTGRESQL)
from tests.comparison.model import (
    Boolean,
    Column,
    Float,
    Int,
    Number,
    String,
    Table,
    Timestamp,
    TYPES)

LOG = getLogger(__name__)


class RandomValGenerator(object):

  '''This class will generate random data of various data types. Currently only
     numeric and string data types are supported.

  '''

  def __init__(self,
      min_number=-1000,
      max_number=1000,
      min_date=datetime(1990, 1, 1),
      max_date=datetime(2030, 1, 1),
      null_val_percentage=0.1):
    self.min_number = min_number
    self.max_number = max_number
    self.min_date = min_date
    self.max_date = max_date
    self.null_val_percentage = null_val_percentage

  def generate_val(self, val_type):
    '''Generate and return a single random val. Use the val_type parameter to
       specify the type of val to generate. See model.DataType for valid val_type
       options.

       Ex:
         generator = RandomValGenerator(min_number=1, max_number=5)
         val = generator.generate_val(model.Int)
         assert 1 <= val and val <= 5
    '''
    if issubclass(val_type, String):
      val = self.generate_val(Int)
      return None if val is None else str(val)
    if random() < self.null_val_percentage:
      return None
    if issubclass(val_type, Int):
      return randint(
          max(self.min_number, val_type.MIN), min(val_type.MAX, self.max_number))
    if issubclass(val_type, Number):
      return uniform(self.min_number, self.max_number)
    if issubclass(val_type, Timestamp):
      delta = self.max_date - self.min_date
      delta_in_seconds = delta.days * 24 * 60 * 60 + delta.seconds
      offset_in_seconds = randint(0, delta_in_seconds)
      val = self.min_date + timedelta(0, offset_in_seconds)
      return datetime(val.year, val.month, val.day)
    if issubclass(val_type, Boolean):
      return randint(0, 1) == 1
    raise Exception('Unsupported type %s' % val_type.__name__)


class DatabasePopulator(object):

  '''This class will populate a database with randomly generated data. The population
     includes table creation and data generation. Table names are hard coded as
     table_<table number>.

  '''

  def __init__(self):
    self.val_generator = RandomValGenerator()

  def populate_db_with_random_data(self,
      db_name,
      db_connectors,
      number_of_tables=10,
      allowed_data_types=TYPES,
      create_files=False):
    '''Create tables with a random number of cols with data types chosen from
       allowed_data_types, then fill the tables with data.

       The given db_name must have already been created.

    '''
    connections = [connector.create_connection(db_name=db_name)
                   for connector in db_connectors]
    for table_idx in xrange(number_of_tables):
      table = self.create_random_table(
          'table_%s' % (table_idx + 1),
          allowed_data_types=allowed_data_types)
      for connection in connections:
        sql = self.make_create_table_sql(table, dialect=connection.db_type)
        LOG.info('Creating %s table %s', connection.db_type, table.name)
        if create_files:
          with open('%s_%s.sql' % (table.name, connection.db_type.lower()), 'w') \
              as f:
            f.write(sql + '\n')
        connection.execute(sql)
      LOG.info('Inserting data into %s', table.name)
      for _ in xrange(100):   # each iteration will insert 100 rows
        rows = self.generate_table_data(table)
        for connection in connections:
          sql = self.make_insert_sql_from_data(
              table, rows, dialect=connection.db_type)
          if create_files:
            with open('%s_%s.sql' %
                (table.name, connection.db_type.lower()), 'a') as f:
              f.write(sql + '\n')
          try:
            connection.execute(sql)
          except:
            LOG.error('Error executing SQL: %s', sql)
            raise

    self.index_tables_in_database(connections)

    for connection in connections:
      connection.close()

  def migrate_database(self,
      db_name,
      source_db_connector,
      destination_db_connectors,
      include_table_names=None):
    '''Read table metadata and data from the source database and create a replica in
       the destination databases. For example, the Impala functional test database
       could be copied into Postgresql.

       source_db_connector and items in destination_db_connectors should be
       of type db_connector.DbConnector. destination_db_connectors and
       include_table_names should be iterables.
    '''
    source_connection = source_db_connector.create_connection(db_name)

    cursors = [connector.create_connection(db_name=db_name).create_cursor()
               for connector in destination_db_connectors]

    for table_name in source_connection.list_table_names():
      if include_table_names and table_name not in include_table_names:
        continue
      try:
        table = source_connection.describe_table(table_name)
      except Exception as e:
        LOG.warn('Error fetching metadata for %s: %s', table_name, e)
        continue
      for destination_cursor in cursors:
        sql = self.make_create_table_sql(
            table, dialect=destination_cursor.connection.db_type)
        destination_cursor.execute(sql)
      with source_connection.open_cursor() as source_cursor:
        try:
          source_cursor.execute('SELECT * FROM ' + table_name)
          while True:
            rows = source_cursor.fetchmany(size=100)
            if not rows:
              break
            for destination_cursor in cursors:
              sql = self.make_insert_sql_from_data(
                  table, rows, dialect=destination_cursor.connection.db_type)
              destination_cursor.execute(sql)
        except Exception as e:
          LOG.error('Error fetching data for %s: %s', table_name, e)
          continue

    self.index_tables_in_database([cursor.connection for cursor in cursors])

    for cursor in cursors:
      cursor.close()
      cursor.connection.close()

  def create_random_table(self, table_name, allowed_data_types):
    '''Create and return a Table with a random number of cols chosen from the
       given allowed_data_types.

    '''
    data_type_count = len(allowed_data_types)
    col_count = randint(data_type_count / 2, data_type_count * 2)
    table = Table(table_name)
    for col_idx in xrange(col_count):
      col_type = choice(allowed_data_types)
      col = Column(
          table,
          '%s_col_%s' % (col_type.__name__.lower(), col_idx + 1),
          col_type)
      table.cols.append(col)
    return table

  def make_create_table_sql(self, table, dialect=IMPALA):
    sql = 'CREATE TABLE %s (%s)' % (
        table.name,
        ', '.join('%s %s' %
            (col.name, self.get_sql_for_data_type(col.type, dialect)) +
            ('' if dialect == IMPALA else ' NULL')
            for col in table.cols))
    if dialect == MYSQL:
      sql += ' ENGINE = MYISAM'
    return sql
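
  # For illustration (this comment is not part of the original commit): with a table
  # of one Int col and one String col, the method above produces SQL along the lines
  # of
  #   CREATE TABLE table_1 (int_col_1 INT, string_col_2 STRING)            -- Impala
  #   CREATE TABLE table_1 (int_col_1 INT NULL, ...) ENGINE = MYISAM       -- MySQL
  # where the exact type names depend on the per-dialect aliases defined in model.py.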

  def get_sql_for_data_type(self, data_type, dialect=IMPALA):
    # Check to see if there is an alias and if so, use the first one
    if hasattr(data_type, dialect):
      return getattr(data_type, dialect)[0]
    return data_type.__name__.upper()

  def make_insert_sql_from_data(self, table, rows, dialect=IMPALA):
    # TODO: Consider using parameterized inserts so the database connector handles
    #       formatting the data. For example the CAST to work around IMPALA-803 can
    #       probably be removed. The vals were generated this way so a data file
    #       could be made and attached to jiras.
    if not rows:
      raise Exception('At least one row is required')
    if not table.cols:
      raise Exception('At least one col is required')

    sql = 'INSERT INTO %s VALUES ' % table.name
    for row_idx, row in enumerate(rows):
      if row_idx > 0:
        sql += ', '
      sql += '('
      for col_idx, col in enumerate(table.cols):
        if col_idx > 0:
          sql += ', '
        val = row[col_idx]
        if val is None:
          sql += 'NULL'
        elif issubclass(col.type, Timestamp):
          if dialect != IMPALA:
            sql += 'TIMESTAMP '
          sql += "'%s'" % val
        elif issubclass(col.type, String):
          val = val.replace("'", "''")
          if dialect == POSTGRESQL:
            val = val.replace('\\', '\\\\')
          sql += "'%s'" % val
        elif dialect == IMPALA \
            and issubclass(col.type, Float):
          # https://issues.cloudera.org/browse/IMPALA-803
          sql += 'CAST(%s AS FLOAT)' % val
        else:
          sql += str(val)
      sql += ')'
    return sql
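
  # For illustration (this comment is not part of the original commit): a generated
  # statement looks roughly like
  #   INSERT INTO table_1 VALUES (1, '7', NULL), (-3, NULL, CAST(0.5 AS FLOAT))
  # with the TIMESTAMP keyword prefix and backslash doubling applied only for the
  # non-Impala dialects, and the CAST only for Impala, as handled above; the values
  # shown here are made up.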

  def generate_table_data(self, table, number_of_rows=100):
    rows = list()
    for row_idx in xrange(number_of_rows):
      row = list()
      for col in table.cols:
        row.append(self.val_generator.generate_val(col.type))
      rows.append(row)
    return rows

  def drop_and_create_database(self, db_name, db_connectors):
    for connector in db_connectors:
      with connector.open_connection() as connection:
        connection.drop_db_if_exists(db_name)
        connection.execute('CREATE DATABASE ' + db_name)

  def index_tables_in_database(self, connections):
    for connection in connections:
      if connection.supports_index_creation:
        for table_name in connection.list_table_names():
          LOG.info('Indexing %s on %s' % (table_name, connection.db_type))
          connection.index_table(table_name)


if __name__ == '__main__':
  from optparse import NO_DEFAULT, OptionGroup, OptionParser

  parser = OptionParser(
      usage='usage: \n'
            '  %prog [options] [populate]\n\n'
            '     Create and populate database(s). The Impala database will always be \n'
            '     included, the other database types are optional.\n\n'
            '  %prog [options] migrate\n\n'
            '     Migrate an Impala database to another database type. The destination \n'
            '     database will be dropped and recreated.')
  parser.add_option('--log-level', default='INFO',
      help='The log level to use.', choices=('DEBUG', 'INFO', 'WARN', 'ERROR'))
  parser.add_option('--db-name', default='randomness',
      help='The name of the database to use. Ex: functional.')

  group = OptionGroup(parser, 'MySQL Options')
  group.add_option('--use-mysql', action='store_true', default=False,
      help='Use MySQL')
  group.add_option('--mysql-host', default='localhost',
      help='The name of the host running the MySQL database.')
  group.add_option('--mysql-port', default=3306, type=int,
      help='The port of the host running the MySQL database.')
  group.add_option('--mysql-user', default='root',
      help='The user name to use when connecting to the MySQL database.')
  group.add_option('--mysql-password',
      help='The password to use when connecting to the MySQL database.')
  parser.add_option_group(group)

  group = OptionGroup(parser, 'Postgresql Options')
  group.add_option('--use-postgresql', action='store_true', default=False,
      help='Use Postgresql')
  group.add_option('--postgresql-host', default='localhost',
      help='The name of the host running the Postgresql database.')
  group.add_option('--postgresql-port', default=5432, type=int,
      help='The port of the host running the Postgresql database.')
  group.add_option('--postgresql-user', default='postgres',
      help='The user name to use when connecting to the Postgresql database.')
  group.add_option('--postgresql-password',
      help='The password to use when connecting to the Postgresql database.')
  parser.add_option_group(group)

  group = OptionGroup(parser, 'Database Population Options')
  group.add_option('--randomization-seed', default=1, type='int',
      help='The randomization will be initialized with this seed. Using the same seed '
          'will produce the same results across runs.')
  group.add_option('--create-data-files', default=False, action='store_true',
      help='Create files that can be used to repopulate the databases elsewhere.')
  group.add_option('--table-count', default=10, type='int',
      help='The number of tables to generate.')
  parser.add_option_group(group)

  group = OptionGroup(parser, 'Database Migration Options')
  group.add_option('--migrate-table-names',
      help='Table names should be separated with commas. The default is to migrate all '
          'tables.')
  parser.add_option_group(group)

  for group in parser.option_groups + [parser]:
    for option in group.option_list:
      if option.default != NO_DEFAULT:
        option.help += ' [default: %default]'

  options, args = parser.parse_args()
  command = args[0] if args else 'populate'
  if len(args) > 1 or command not in ['populate', 'migrate']:
    raise Exception('Command must either be "populate" or "migrate" but was "%s"' %
        ' '.join(args))
  if command == 'migrate' and not any((options.use_mysql, options.use_postgresql)):
    raise Exception('At least one destination database must be chosen with '
        '--use-<database type>')

  basicConfig(level=options.log_level)

  seed(options.randomization_seed)

  impala_connector = DbConnector(IMPALA)
  db_connectors = []
  if options.use_postgresql:
    db_connectors.append(DbConnector(POSTGRESQL,
        user_name=options.postgresql_user,
        password=options.postgresql_password,
        host_name=options.postgresql_host,
        port=options.postgresql_port))
  if options.use_mysql:
    db_connectors.append(DbConnector(MYSQL,
        user_name=options.mysql_user,
        password=options.mysql_password,
        host_name=options.mysql_host,
        port=options.mysql_port))

  populator = DatabasePopulator()
  if command == 'populate':
    db_connectors.append(impala_connector)
    populator.drop_and_create_database(options.db_name, db_connectors)
    populator.populate_db_with_random_data(
        options.db_name,
        db_connectors,
        number_of_tables=options.table_count,
        create_files=options.create_data_files)
  else:
    populator.drop_and_create_database(options.db_name, db_connectors)
    if options.migrate_table_names:
      table_names = options.migrate_table_names.split(',')
    else:
      table_names = None
    populator.migrate_database(
        options.db_name,
        impala_connector,
        db_connectors,
        include_table_names=table_names)
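
For reference, the same population can also be driven programmatically instead of
through the command line. A minimal sketch (not part of this commit; assumes a local
Impala and MySQL set up as described in the README):

   from tests.comparison.data_generator import DatabasePopulator
   from tests.comparison.db_connector import DbConnector, IMPALA, MYSQL

   connectors = [DbConnector(IMPALA), DbConnector(MYSQL, user_name='root')]
   populator = DatabasePopulator()
   populator.drop_and_create_database('randomness', connectors)
   populator.populate_db_with_random_data('randomness', connectors, number_of_tables=5)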
411
tests/comparison/db_connector.py
Normal file
@@ -0,0 +1,411 @@
# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''This module is intended to standardize workflows when working with various databases
   such as Impala, Postgresql, etc. Even with pep-249 (DB API 2), workflows differ
   slightly. For example, Postgresql does not allow changing databases from within a
   connection; instead a new connection must be made. However, Impala does not allow
   specifying a database upon connection; instead a cursor must be created and a USE
   command must be issued.

'''

from contextlib import contextmanager
try:
  from impala.dbapi import connect as impala_connect
except:
  print('Error importing impyla. Please make sure it is installed. '
      'See the README for details.')
  raise
from itertools import izip
from logging import getLogger

from tests.comparison.model import Column, Table, TYPES, String

LOG = getLogger(__name__)

IMPALA = 'IMPALA'
POSTGRESQL = 'POSTGRESQL'
MYSQL = 'MYSQL'

DATABASES = [IMPALA, POSTGRESQL, MYSQL]

mysql_connect = None
postgresql_connect = None


class DbConnector(object):

  '''Wraps a DB API 2 implementation to provide a standard way of obtaining a
     connection and selecting a database.

     Any database that supports transactions will have auto-commit enabled.

  '''

  def __init__(self, db_type, user_name=None, password=None, host_name=None, port=None):
    self.db_type = db_type.upper()
    if self.db_type not in DATABASES:
      raise Exception('Unsupported database: %s' % db_type)
    self.user_name = user_name
    self.password = password
    self.host_name = host_name or 'localhost'
    self.port = port

  def create_connection(self, db_name=None):
    if self.db_type == IMPALA:
      connection_class = ImpalaDbConnection
      connection = impala_connect(host=self.host_name, port=self.port or 21050)
    elif self.db_type == POSTGRESQL:
      connection_class = PostgresqlDbConnection
      connection_args = {'user': self.user_name or 'postgres'}
      if self.password:
        connection_args['password'] = self.password
      if db_name:
        connection_args['database'] = db_name
      if self.host_name:
        connection_args['host'] = self.host_name
      if self.port:
        connection_args['port'] = self.port
      global postgresql_connect
      if not postgresql_connect:
        try:
          from psycopg2 import connect as postgresql_connect
        except:
          print('Error importing psycopg2. Please make sure it is installed. '
              'See the README for details.')
          raise
      connection = postgresql_connect(**connection_args)
      connection.autocommit = True
    elif self.db_type == MYSQL:
      connection_class = MySQLDbConnection
      connection_args = {'user': self.user_name or 'root'}
      if self.password:
        connection_args['passwd'] = self.password
      if db_name:
        connection_args['db'] = db_name
      if self.host_name:
        connection_args['host'] = self.host_name
      if self.port:
        connection_args['port'] = self.port
      global mysql_connect
      if not mysql_connect:
        try:
          from MySQLdb import connect as mysql_connect
        except:
          print('Error importing MySQLdb. Please make sure it is installed. '
              'See the README for details.')
          raise
      connection = mysql_connect(**connection_args)
    else:
      raise Exception('Unexpected database type: %s' % self.db_type)
    return connection_class(self, connection, db_name=db_name)

  @contextmanager
  def open_connection(self, db_name=None):
    connection = None
    try:
      connection = self.create_connection(db_name=db_name)
      yield connection
    finally:
      if connection:
        try:
          connection.close()
        except Exception as e:
          LOG.debug('Error closing connection: %s', e, exc_info=True)


class DbConnection(object):

  '''Wraps a DB API 2 connection. Instances should only be obtained through the
     DbConnector.create_connection(...) method.

  '''

  @staticmethod
  def describe_common_tables(db_connections, filter_col_types=[]):
    '''Find and return a list of Table objects that the given connections have in
       common.

       @param filter_col_types: Ignore any cols if they are of a data type contained
           in this collection.

    '''
    common_table_names = None
    for db_connection in db_connections:
      table_names = set(db_connection.list_table_names())
      if common_table_names is None:
        common_table_names = table_names
      else:
        common_table_names &= table_names
    common_table_names = sorted(common_table_names)

    tables = list()
    for table_name in common_table_names:
      common_table = None
      mismatch = False
      for db_connection in db_connections:
        table = db_connection.describe_table(table_name)
        table.cols = [col for col in table.cols if col.type not in filter_col_types]
        if common_table is None:
          common_table = table
          continue
        if len(common_table.cols) != len(table.cols):
          LOG.debug('Ignoring table %s.'
              ' It has a different number of columns across databases.', table_name)
          mismatch = True
          break
        for left, right in izip(common_table.cols, table.cols):
          if not (left.name == right.name and left.type == right.type):
            LOG.debug('Ignoring table %s. It has different columns %s vs %s.' %
                (table_name, left, right))
            mismatch = True
            break
        if mismatch:
          break
      if not mismatch:
        tables.append(common_table)

    return tables

  def __init__(self, connector, connection, db_name=None):
    self.connector = connector
    self.connection = connection
    self.db_name = db_name

  @property
  def db_type(self):
    return self.connector.db_type

  def create_cursor(self):
    return DatabaseCursor(self.connection.cursor(), self)

  @contextmanager
  def open_cursor(self):
    '''Returns a new cursor for use in a "with" statement. When the "with" statement
       ends, the cursor will be closed.

    '''
    cursor = None
    try:
      cursor = self.create_cursor()
      yield cursor
    finally:
      self.close_cursor_quietly(cursor)

  def close_cursor_quietly(self, cursor):
    if cursor:
      try:
        cursor.close()
      except Exception as e:
        LOG.debug('Error closing cursor: %s', e, exc_info=True)

  def list_db_names(self):
    '''Return a list of database names always in lowercase.'''
    rows = self.execute_and_fetchall(self.make_list_db_names_sql())
    return [row[0].lower() for row in rows]

  def make_list_db_names_sql(self):
    return 'SHOW DATABASES'

  def list_table_names(self):
    '''Return a list of table names always in lowercase.'''
    rows = self.execute_and_fetchall(self.make_list_table_names_sql())
    return [row[0].lower() for row in rows]

  def make_list_table_names_sql(self):
    return 'SHOW TABLES'

  def describe_table(self, table_name):
    '''Return a Table with table and col names always in lowercase.'''
    rows = self.execute_and_fetchall(self.make_describe_table_sql(table_name))
    table = Table(table_name.lower())
    for row in rows:
      col_name, data_type = row[:2]
      table.cols.append(Column(table, col_name.lower(), self.parse_data_type(data_type)))
    return table

  def make_describe_table_sql(self, table_name):
    return 'DESCRIBE ' + table_name

  def parse_data_type(self, sql):
    sql = sql.upper()
    # Types may have declared a database specific alias
    for type_ in TYPES:
      if sql in getattr(type_, self.db_type, []):
        return type_
    for type_ in TYPES:
      if type_.__name__.upper() == sql:
        return type_
    if 'CHAR' in sql:
      return String
    raise Exception('Unknown data type: ' + sql)

  def create_database(self, db_name):
    db_name = db_name.lower()
    with self.open_cursor() as cursor:
      cursor.execute('CREATE DATABASE ' + db_name)

  def drop_db_if_exists(self, db_name):
    '''This should not be called from a connection to the database being dropped.'''
    db_name = db_name.lower()
    if db_name not in self.list_db_names():
      return
    if self.db_name and self.db_name.lower() == db_name:
      raise Exception('Cannot drop database while still connected to it')
    self.drop_database(db_name)

  def drop_database(self, db_name):
    db_name = db_name.lower()
    self.execute('DROP DATABASE ' + db_name)

  @property
  def supports_index_creation(self):
    return True

  def index_table(self, table_name):
    table = self.describe_table(table_name)
    with self.open_cursor() as cursor:
      for col in table.cols:
        index_name = '%s_%s' % (table_name, col.name)
        if self.db_name:
          index_name = '%s_%s' % (self.db_name, index_name)
        cursor.execute('CREATE INDEX %s ON %s(%s)' % (index_name, table_name, col.name))

  @property
  def supports_kill_connection(self):
    return False

  def kill_connection(self):
    '''Kill the current connection and any currently running queries associated with
       the connection.
    '''
    raise Exception('Killing connection is not supported')

  def materialize_query(self, query_as_text, table_name):
    self.execute('CREATE TABLE %s AS %s' % (table_name.lower(), query_as_text))

  def drop_table(self, table_name):
    self.execute('DROP TABLE ' + table_name.lower())

  def execute(self, sql):
    with self.open_cursor() as cursor:
      cursor.execute(sql)

  def execute_and_fetchall(self, sql):
    with self.open_cursor() as cursor:
      cursor.execute(sql)
      return cursor.fetchall()

  def close(self):
    '''Close the underlying connection.'''
    self.connection.close()

  def reconnect(self):
    self.close()
    other = self.connector.create_connection(db_name=self.db_name)
    self.connection = other.connection


class DatabaseCursor(object):

  '''Wraps a DB API 2 cursor to provide access to the related connection. This class
     implements the DB API 2 interface by delegation.

  '''

  def __init__(self, cursor, connection):
    self.cursor = cursor
    self.connection = connection

  def __getattr__(self, attr):
    return getattr(self.cursor, attr)


class ImpalaDbConnection(DbConnection):

  def create_cursor(self):
    cursor = DbConnection.create_cursor(self)
    if self.db_name:
      cursor.execute('USE %s' % self.db_name)
    return cursor

  def drop_database(self, db_name):
    '''This should not be called from a connection to the database being dropped.'''
    db_name = db_name.lower()
    with self.connector.open_connection(db_name) as list_tables_connection:
      with list_tables_connection.open_cursor() as drop_table_cursor:
        for table_name in list_tables_connection.list_table_names():
          drop_table_cursor.execute('DROP TABLE ' + table_name)
    self.execute('DROP DATABASE ' + db_name)

  @property
  def supports_index_creation(self):
    return False


class PostgresqlDbConnection(DbConnection):

  def make_list_db_names_sql(self):
    return 'SELECT datname FROM pg_database'

  def make_list_table_names_sql(self):
    return '''
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'public' '''

  def make_describe_table_sql(self, table_name):
    return '''
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_name = '%s'
        ORDER BY ordinal_position''' % table_name


class MySQLDbConnection(DbConnection):

  def __init__(self, connector, connection, db_name=None):
    DbConnection.__init__(self, connector, connection, db_name=db_name)
    self.session_id = self.execute_and_fetchall('SELECT connection_id()')[0][0]

  def describe_table(self, table_name):
    '''Return a Table with table and col names always in lowercase.'''
    rows = self.execute_and_fetchall(self.make_describe_table_sql(table_name))
    table = Table(table_name.lower())
    for row in rows:
      col_name, data_type = row[:2]
      if data_type == 'tinyint(1)':
        # Just assume this is a boolean...
        data_type = 'boolean'
      if '(' in data_type:
        # Strip the size of the data type
        data_type = data_type[:data_type.index('(')]
      table.cols.append(Column(table, col_name.lower(), self.parse_data_type(data_type)))
    return table

  @property
  def supports_kill_connection(self):
    return True

  def kill_connection(self):
    with self.connector.open_connection(db_name=self.db_name) as connection:
      connection.execute('KILL %s' % (self.session_id))

  def index_table(self, table_name):
    table = self.describe_table(table_name)
    with self.open_cursor() as cursor:
      for col in table.cols:
        try:
          cursor.execute('ALTER TABLE %s ADD INDEX (%s)' % (table_name, col.name))
        except Exception as e:
          if 'Incorrect index name' not in str(e):
            raise
          # Some sort of MySQL bug...
          LOG.warn('Could not create index on %s.%s: %s' % (table_name, col.name, e))
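
For reference, the workflow this module standardizes looks like the following; a
minimal sketch (not part of this commit), assuming a local Impala with the
'randomness' database created by data_generator.py:

   from tests.comparison.db_connector import DbConnector, IMPALA

   connector = DbConnector(IMPALA)
   with connector.open_connection(db_name='randomness') as connection:
     print(connection.list_table_names())
     print(connection.describe_table('table_1').cols)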
529
tests/comparison/discrepancy_searcher.py
Executable file
@@ -0,0 +1,529 @@
#!/usr/bin/env python

# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''This module will run random queries against existing databases and compare the
   results.

'''

from contextlib import closing
from decimal import Decimal
from itertools import izip, izip_longest
from logging import basicConfig, getLogger
from math import isinf, isnan
from os import getenv, remove
from os.path import exists, join
from shelve import open as open_shelve
from subprocess import call
from threading import current_thread, Thread
from tempfile import gettempdir
from time import time

from tests.comparison.db_connector import (
    DbConnection,
    DbConnector,
    IMPALA,
    MYSQL,
    POSTGRESQL)
from tests.comparison.model import BigInt, TYPES
from tests.comparison.query_generator import QueryGenerator
from tests.comparison.model_translator import SqlWriter

LOG = getLogger(__name__)


class QueryResultComparator(object):

  # If the number of rows * cols is greater than this val, then the comparison will
  # be aborted. Raising this value also raises the risk of python being OOM killed. At
  # 10M python would get OOM killed occasionally even on a physical machine with 32GB
  # ram.
  TOO_MUCH_DATA = 1000 * 1000

  # Used when comparing float vals
  EPSILON = 0.1

  # The decimal vals will be rounded before comparison
  DECIMAL_PLACES = 2

  def __init__(self, impala_connection, reference_connection):
    self.reference_db_type = reference_connection.db_type

    self.impala_cursor = impala_connection.create_cursor()
    self.reference_cursor = reference_connection.create_cursor()

    self.impala_sql_writer = SqlWriter.create(dialect=impala_connection.db_type)
    self.reference_sql_writer = SqlWriter.create(dialect=reference_connection.db_type)

    # After this amount of time the connection will be killed and the comparison
    # result will be a timeout.
    self.query_timeout_seconds = 3 * 60

  def compare_query_results(self, query):
    '''Execute the query, compare the data, and return a summary of the result.'''
    comparison_result = ComparisonResult(query, self.reference_db_type)

    reference_data_set = None
    impala_data_set = None
    # Impala doesn't support getting the row count without getting the rows too. So run
    # the query on the other database first.
    try:
      for sql_writer, cursor in ((self.reference_sql_writer, self.reference_cursor),
                                 (self.impala_sql_writer, self.impala_cursor)):
        self.execute_query(cursor, sql_writer.write_query(query))
        if (cursor.rowcount * len(query.select_clause.select_items)) \
            > self.TOO_MUCH_DATA:
          comparison_result.exception = Exception('Too much data to compare')
          return comparison_result
        if reference_data_set is None:
          # MySQL returns a tuple of rows but a list is needed for sorting
          reference_data_set = list(cursor.fetchall())
          comparison_result.reference_row_count = len(reference_data_set)
        else:
          impala_data_set = cursor.fetchall()
          comparison_result.impala_row_count = len(impala_data_set)
    except Exception as e:
      comparison_result.exception = e
      LOG.debug('Error running query: %s', e, exc_info=True)
      return comparison_result

    comparison_result.query_resulted_in_data = (comparison_result.impala_row_count > 0
        or comparison_result.reference_row_count > 0)

    if comparison_result.impala_row_count != comparison_result.reference_row_count:
      return comparison_result

    for data_set in (reference_data_set, impala_data_set):
      for row_idx, row in enumerate(data_set):
        data_set[row_idx] = [self.standardize_data(data) for data in row]
      data_set.sort(cmp=self.row_sort_cmp)

    for row_idx, (impala_row, reference_row) in \
        enumerate(izip_longest(impala_data_set, reference_data_set)):
      for col_idx, (impala_val, reference_val) \
          in enumerate(izip_longest(impala_row, reference_row)):
        if not self.vals_are_equal(impala_val, reference_val):
          if isinstance(impala_val, int) \
              and isinstance(reference_val, (int, float, Decimal)) \
              and abs(reference_val) > BigInt.MAX:
            # Impala will return incorrect results if the val is greater than max bigint
            comparison_result.exception = KnownError(
                'https://issues.cloudera.org/browse/IMPALA-865')
          elif isinstance(impala_val, float) \
              and (isinf(impala_val) or isnan(impala_val)):
            # In some cases, Impala gives NaNs and Infs instead of NULLs
            comparison_result.exception = KnownError(
                'https://issues.cloudera.org/browse/IMPALA-724')
          comparison_result.impala_row = impala_row
          comparison_result.reference_row = reference_row
          comparison_result.mismatch_at_row_number = row_idx + 1
          comparison_result.mismatch_at_col_number = col_idx + 1
          return comparison_result

    if len(impala_data_set) == 1:
      for val in impala_data_set[0]:
        if val:
          break
      else:
        comparison_result.query_resulted_in_data = False

    return comparison_result

  def execute_query(self, cursor, sql):
    '''Execute the query and throw a timeout if needed.'''
    def _execute_query():
      try:
        cursor.execute(sql)
      except Exception as e:
        current_thread().exception = e
    query_thread = Thread(target=_execute_query, name='Query execution thread')
    query_thread.daemon = True
    query_thread.start()
    query_thread.join(self.query_timeout_seconds)
    if query_thread.is_alive():
      if cursor.connection.supports_kill_connection:
        LOG.debug('Attempting to kill connection')
        cursor.connection.kill_connection()
        LOG.debug('Killed connection')
      cursor.close()
      cursor.connection.close()
      cursor = cursor\
          .connection\
          .connector\
          .create_connection(db_name=cursor.connection.db_name)\
          .create_cursor()
      if cursor.connection.db_type == IMPALA:
        self.impala_cursor = cursor
      else:
        self.reference_cursor = cursor
      raise QueryTimeout('Query timed out after %s seconds' % self.query_timeout_seconds)
    if hasattr(query_thread, 'exception'):
      raise query_thread.exception

  def standardize_data(self, data):
    '''Return a val that is suitable for comparison.'''
    # For float data we need to round, otherwise differences in precision will cause
    # errors
    if isinstance(data, (float, Decimal)):
      return round(data, self.DECIMAL_PLACES)
    return data

  def row_sort_cmp(self, left_row, right_row):
    for left, right in izip(left_row, right_row):
      if left is None and right is not None:
        return -1
      if left is not None and right is None:
        return 1
      result = cmp(left, right)
      if result:
        return result
    return 0

  def vals_are_equal(self, left, right):
    if left == right:
      return True
    if isinstance(left, (int, float, Decimal)) and \
        isinstance(right, (int, float, Decimal)):
      return self.floats_are_equal(left, right)
    return False

  def floats_are_equal(self, left, right):
    left = round(left, self.DECIMAL_PLACES)
    right = round(right, self.DECIMAL_PLACES)
    diff = abs(left - right)
    if left * right == 0:
      return diff < self.EPSILON
    return diff / (abs(left) + abs(right)) < self.EPSILON
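
  # To illustrate the comparison above with made-up numbers: 100.004 and 100.006
  # round to 100.0 and 100.01, giving a relative difference of 0.01 / 200.01, far
  # below EPSILON, so they compare equal; 1.0 vs 1.3 gives 0.3 / 2.3 ~= 0.13, which
  # exceeds EPSILON and is flagged as a mismatch.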
|
||||
|
||||
|
||||
class ComparisonResult(object):
|
||||
|
||||
def __init__(self, query, reference_db_type):
|
||||
self.query = query
|
||||
self.reference_db_type = reference_db_type
|
||||
self.query_resulted_in_data = False
|
||||
self.impala_row_count = None
|
||||
self.reference_row_count = None
|
||||
self.mismatch_at_row_number = None
|
||||
self.mismatch_at_col_number = None
|
||||
self.impala_row = None
|
||||
self.reference_row = None
|
||||
self.exception = None
|
||||
self._error_message = None
|
||||
|
||||
@property
|
||||
def error(self):
|
||||
if not self._error_message:
|
||||
if self.exception:
|
||||
self._error_message = str(self.exception)
|
||||
elif self.impala_row_count and \
|
||||
self.impala_row_count != self.reference_row_count:
|
||||
self._error_message = 'Row counts do not match: %s Impala rows vs %s %s rows' \
|
||||
% (self.impala_row_count,
|
||||
self.reference_db_type,
|
||||
self.reference_row_count)
|
||||
elif self.mismatch_at_row_number is not None:
|
||||
# Write a row like "[a, b, <<c>>, d]" where c is a bad value
|
||||
impala_row = '[' + ', '.join(
|
||||
'<<' + str(val) + '>>' if idx == self.mismatch_at_col_number - 1 else str(val)
|
||||
for idx, val in enumerate(self.impala_row)
|
||||
) + ']'
|
||||
reference_row = '[' + ', '.join(
|
||||
'<<' + str(val) + '>>' if idx == self.mismatch_at_col_number - 1 else str(val)
|
||||
for idx, val in enumerate(self.reference_row)
|
||||
) + ']'
|
||||
self._error_message = \
|
||||
'Column %s in row %s does not match: %s Impala row vs %s %s row' \
|
||||
% (self.mismatch_at_col_number,
|
||||
self.mismatch_at_row_number,
|
||||
impala_row,
|
||||
reference_row,
|
||||
self.reference_db_type)
|
||||
return self._error_message
|
||||
|
||||
@property
|
||||
def is_known_error(self):
|
||||
return isinstance(self.exception, KnownError)
|
||||
|
||||
@property
|
||||
def query_timed_out(self):
|
||||
return isinstance(self.exception, QueryTimeout)
|
||||
|
||||
|
||||
class QueryTimeout(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class KnownError(Exception):
|
||||
|
||||
def __init__(self, jira_url):
|
||||
Exception.__init__(self, 'Known issue: ' + jira_url)
|
||||
self.jira_url = jira_url
|
||||
|
||||
|
||||
class QueryResultDiffSearcher(object):
|
||||
|
||||
# Sometimes things get into a bad state and the same error loops forever
|
||||
ABORT_ON_REPEAT_ERROR_COUNT = 2
|
||||
|
||||
def __init__(self, impala_connection, reference_connection, filter_col_types=[]):
|
||||
self.impala_connection = impala_connection
|
||||
self.reference_connection = reference_connection
|
||||
self.common_tables = DbConnection.describe_common_tables(
|
||||
[impala_connection, reference_connection],
|
||||
filter_col_types=filter_col_types)
|
||||
|
||||
# A file-backed dict of queries that produced a discrepancy, keyed by query number
|
||||
# (in string form, as required by the dict).
|
||||
self.query_shelve_path = gettempdir() + '/query.shelve'
|
||||
|
||||
# A list of all queries attempted
|
||||
self.query_log_path = gettempdir() + '/impala_query_log.sql'
|
||||
|
||||
def search(self, number_of_test_queries, stop_on_result_mismatch, stop_on_crash):
|
||||
if exists(self.query_shelve_path):
|
||||
# Ensure a clean shelve will be created
|
||||
remove(self.query_shelve_path)
|
||||
|
||||
start_time = time()
|
||||
impala_sql_writer = SqlWriter.create(dialect=IMPALA)
|
||||
reference_sql_writer = SqlWriter.create(
|
||||
dialect=self.reference_connection.db_type)
|
||||
query_result_comparator = QueryResultComparator(
|
||||
self.impala_connection, self.reference_connection)
|
||||
query_generator = QueryGenerator()
|
||||
query_count = 0
|
||||
queries_resulted_in_data_count = 0
|
||||
mismatch_count = 0
|
||||
query_timeout_count = 0
|
||||
known_error_count = 0
|
||||
impala_crash_count = 0
|
||||
last_error = None
|
||||
repeat_error_count = 0
|
||||
with open(self.query_log_path, 'w') as impala_query_log:
|
||||
impala_query_log.write(
|
||||
'--\n'
|
||||
'-- Stating new run\n'
|
||||
'--\n')
|
||||
while number_of_test_queries > query_count:
|
||||
query = query_generator.create_query(self.common_tables)
|
||||
impala_sql = impala_sql_writer.write_query(query)
|
||||
if 'FULL OUTER JOIN' in impala_sql and self.reference_connection.db_type == MYSQL:
|
||||
# Not supported by MySQL
|
||||
continue
|
||||
|
||||
query_count += 1
|
||||
LOG.info('Running query #%s', query_count)
|
||||
impala_query_log.write(impala_sql + ';\n')
|
||||
result = query_result_comparator.compare_query_results(query)
|
||||
if result.query_resulted_in_data:
|
||||
queries_resulted_in_data_count += 1
|
||||
if result.error:
|
||||
# TODO: These first two come from psycopg2, the postgres driver. Maybe we should
|
||||
# try a different driver? Or maybe the usage of the driver isn't correct.
|
||||
# Anyhow ignore these failures.
|
||||
if 'division by zero' in result.error \
|
||||
or 'out of range' in result.error \
|
||||
or 'Too much data' in result.error:
|
||||
LOG.debug('Ignoring error: %s', result.error)
|
||||
query_count -= 1
|
||||
continue
|
||||
|
||||
if result.is_known_error:
|
||||
known_error_count += 1
|
||||
elif result.query_timed_out:
|
||||
query_timeout_count += 1
|
||||
else:
|
||||
mismatch_count += 1
|
||||
with closing(open_shelve(self.query_shelve_path)) as query_shelve:
|
||||
query_shelve[str(query_count)] = query
|
||||
|
||||
print('---Impala Query---\n')
|
||||
print(impala_sql_writer.write_query(query, pretty=True) + '\n')
|
||||
print('---Reference Query---\n')
|
||||
print(reference_sql_writer.write_query(query, pretty=True) + '\n')
|
||||
print('---Error---\n')
|
||||
print(result.error + '\n')
|
||||
print('------\n')
|
||||
|
||||
if 'Could not connect' in result.error \
|
||||
or "Couldn't open transport for" in result.error:
|
||||
# if stop_on_crash:
|
||||
# break
|
||||
# Assume Impala crashed and try restarting
|
||||
impala_crash_count += 1
|
||||
LOG.info('Restarting Impala')
|
||||
call([join(getenv('IMPALA_HOME'), 'bin/start-impala-cluster.py'),
|
||||
'--log_dir=%s' % getenv('LOG_DIR', "/tmp/")])
|
||||
self.impala_connection.reconnect()
|
||||
query_result_comparator.impala_cursor = self.impala_connection.create_cursor()
|
||||
result = query_result_comparator.compare_query_results(query)
|
||||
if result.error:
|
||||
LOG.info('Restarting Impala')
|
||||
call([join(getenv('IMPALA_HOME'), 'bin/start-impala-cluster.py'),
|
||||
'--log_dir=%s' % getenv('LOG_DIR', "/tmp/")])
|
||||
self.impala_connection.reconnect()
|
||||
query_result_comparator.impala_cursor = self.impala_connection.create_cursor()
|
||||
else:
|
||||
break
|
||||
|
||||
if stop_on_result_mismatch and \
|
||||
not (result.is_known_error or result.query_timed_out):
|
||||
break
|
||||
|
||||
if last_error == result.error \
|
||||
and not (result.is_known_error or result.query_timed_out):
|
||||
repeat_error_count += 1
|
||||
if repeat_error_count == self.ABORT_ON_REPEAT_ERROR_COUNT:
|
||||
break
|
||||
else:
|
||||
last_error = result.error
|
||||
repeat_error_count = 0
|
||||
else:
|
||||
if result.query_resulted_in_data:
|
||||
LOG.info('Results matched (%s rows)', result.impala_row_count)
|
||||
else:
|
||||
LOG.info('Query did not produce meaningful data')
|
||||
last_error = None
|
||||
repeat_error_count = 0
|
||||
|
||||
return SearchResults(
|
||||
query_count,
|
||||
queries_resulted_in_data_count,
|
||||
mismatch_count,
|
||||
query_timeout_count,
|
||||
known_error_count,
|
||||
impala_crash_count,
|
||||
time() - start_time)
|
||||
|
||||
|
||||
class SearchResults(object):
  '''This class holds information about the outcome of a search run.'''

  def __init__(self,
      query_count,
      queries_resulted_in_data_count,
      mismatch_count,
      query_timeout_count,
      known_error_count,
      impala_crash_count,
      run_time_in_seconds):
    # Approx number of queries run, some queries may have been ignored
    self.query_count = query_count
    self.queries_resulted_in_data_count = queries_resulted_in_data_count
    # Number of queries that had an error or result mismatch
    self.mismatch_count = mismatch_count
    self.query_timeout_count = query_timeout_count
    self.known_error_count = known_error_count
    self.impala_crash_count = impala_crash_count
    self.run_time_in_seconds = run_time_in_seconds

  def get_summary_text(self):
    mins, secs = divmod(self.run_time_in_seconds, 60)
    hours, mins = divmod(mins, 60)
    hours = int(hours)
    mins = int(mins)
    if hours:
      run_time = '%s hour and %s minutes' % (hours, mins)
    else:
      secs = int(secs)
      run_time = '%s seconds' % secs
      if mins:
        run_time = '%s mins and ' % mins + run_time
    summary_params = self.__dict__
    summary_params['run_time'] = run_time
    return (
        '%(mismatch_count)s mismatches found after running %(query_count)s queries in '
        '%(run_time)s.\n'
        '%(queries_resulted_in_data_count)s of %(query_count)s queries produced results.'
        '\n'
        '%(impala_crash_count)s Impala crashes occurred.\n'
        '%(known_error_count)s queries were excluded from the mismatch count because '
        'they are known errors.\n'
        '%(query_timeout_count)s queries timed out and were excluded from all counts.') \
        % summary_params


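# An illustration of what get_summary_text produces (the numbers below are
# made up for the example, not real results):
#
#   2 mismatches found after running 1000 queries in 1 hour and 5 minutes.
#   940 of 1000 queries produced results.
#   1 Impala crashes occurred.
#   12 queries were excluded from the mismatch count because they are known errors.
#   3 queries timed out and were excluded from all counts.
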
if __name__ == '__main__':
  import sys
  from optparse import NO_DEFAULT, OptionGroup, OptionParser

  parser = OptionParser()
  parser.add_option('--log-level', default='INFO',
      help='The log level to use.', choices=('DEBUG', 'INFO', 'WARN', 'ERROR'))
  parser.add_option('--db-name', default='randomness',
      help='The name of the database to use. Ex: funcal.')

  parser.add_option('--reference-db-type', default=MYSQL, choices=(MYSQL, POSTGRESQL),
      help='The type of the reference database to use. Ex: MYSQL.')
  parser.add_option('--stop-on-mismatch', default=False, action='store_true',
      help='Exit immediately upon finding a discrepancy in a query result.')
  parser.add_option('--stop-on-crash', default=False, action='store_true',
      help='Exit immediately if Impala crashes.')
  parser.add_option('--query-count', default=1000, type=int,
      help='Exit after running the given number of queries.')
  parser.add_option('--exclude-types', default='Double,Float,TinyInt',
      help='A comma separated list of data types to exclude while generating queries.')

  group = OptionGroup(parser, 'MySQL Options')
  group.add_option('--mysql-host', default='localhost',
      help='The name of the host running the MySQL database.')
  group.add_option('--mysql-port', default=3306, type=int,
      help='The port of the host running the MySQL database.')
  group.add_option('--mysql-user', default='root',
      help='The user name to use when connecting to the MySQL database.')
  group.add_option('--mysql-password',
      help='The password to use when connecting to the MySQL database.')
  parser.add_option_group(group)

  group = OptionGroup(parser, 'Postgresql Options')
  group.add_option('--postgresql-host', default='localhost',
      help='The name of the host running the Postgresql database.')
  group.add_option('--postgresql-port', default=5432, type=int,
      help='The port of the host running the Postgresql database.')
  group.add_option('--postgresql-user', default='postgres',
      help='The user name to use when connecting to the Postgresql database.')
  group.add_option('--postgresql-password',
      help='The password to use when connecting to the Postgresql database.')
  parser.add_option_group(group)

  for group in parser.option_groups + [parser]:
    for option in group.option_list:
      if option.default != NO_DEFAULT:
        option.help += " [default: %default]"

  options, args = parser.parse_args()

  basicConfig(level=options.log_level)

  impala_connection = DbConnector(IMPALA).create_connection(options.db_name)
  db_connector_param_key = options.reference_db_type.lower()
  reference_connection = DbConnector(options.reference_db_type,
      user_name=getattr(options, db_connector_param_key + '_user'),
      password=getattr(options, db_connector_param_key + '_password'),
      host_name=getattr(options, db_connector_param_key + '_host'),
      port=getattr(options, db_connector_param_key + '_port')) \
      .create_connection(options.db_name)
  if options.exclude_types:
    exclude_types = set(type_name.lower() for type_name
        in options.exclude_types.split(','))
    filter_col_types = [type_ for type_ in TYPES
        if type_.__name__.lower() in exclude_types]
  else:
    filter_col_types = []
  diff_searcher = QueryResultDiffSearcher(
      impala_connection, reference_connection, filter_col_types=filter_col_types)
  search_results = diff_searcher.search(
      options.query_count, options.stop_on_mismatch, options.stop_on_crash)
  print(search_results.get_summary_text())
  sys.exit(search_results.mismatch_count)
741
tests/comparison/model.py
Normal file
@@ -0,0 +1,741 @@
# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict


class Query(object):
  '''A representation of the structure of a SQL query. Only the select_clause and
     from_clause are required for a valid query.

  '''

  def __init__(self, select_clause, from_clause):
    self.with_clause = None
    self.select_clause = select_clause
    self.from_clause = from_clause
    self.where_clause = None
    self.group_by_clause = None
    self.having_clause = None
    self.union_clause = None

  @property
  def table_exprs(self):
    '''Provides a list of all table_exprs that are declared by this query. This
       includes table_exprs in the WITH and FROM sections.
    '''
    table_exprs = self.from_clause.table_exprs
    if self.with_clause:
      table_exprs += self.with_clause.table_exprs
    return table_exprs


class SelectClause(object):
  '''This encapsulates the SELECT part of a query. It is convenient to separate
     non-agg items from agg items so that it is simple to know if the query
     is an agg query or not.

  '''

  def __init__(self, non_agg_items=None, agg_items=None):
    self.non_agg_items = non_agg_items or list()
    self.agg_items = agg_items or list()
    self.distinct = False

  @property
  def select_items(self):
    '''Provides a consolidated view of all select items.'''
    return self.non_agg_items + self.agg_items


class SelectItem(object):
  '''A representation of any possible expr that would be valid in

     SELECT <SelectItem>[, <SelectItem>...] FROM ...

     Each SelectItem contains a ValExpr which will either be an instance of a
     DataType (representing a constant), a Column, or a Func.

     Ex: "SELECT int_col + smallint_col FROM alltypes" would have a val_expr of
         Plus(Column(<alltypes.int_col>), Column(<alltypes.smallint_col>)).

  '''

  def __init__(self, val_expr, alias=None):
    self.val_expr = val_expr
    self.alias = alias

  @property
  def type(self):
    '''Returns the DataType of this item.'''
    return self.val_expr.type

  @property
  def is_agg(self):
    '''Evaluates to True if this item contains an aggregate expression.'''
    return self.val_expr.is_agg


class ValExpr(object):
  '''This is an AbstractClass that represents a generic expr that results in a
     scalar. The abc module was not used because it caused problems for the pickle
     module.

  '''

  @property
  def type(self):
    '''This is declared for documentation purposes; subclasses should override this to
       return the DataType that this expr represents.
    '''
    pass

  @property
  def base_type(self):
    '''Return the most fundamental data type that the expr evaluates to. Only
       numeric types will result in a different val than would be returned by self.type.

       Ex:
         if self.type == BigInt:
           assert self.base_type == Int
         if self.type == Double:
           assert self.base_type == Float
         if self.type == String:
           assert self.base_type == self.type
    '''
    if self.returns_int:
      return Int
    if self.returns_float:
      return Float
    return self.type

  @property
  def is_func(self):
    return isinstance(self, Func)

  @property
  def is_agg(self):
    '''Evaluates to True if this expression contains an aggregate function.'''
    if isinstance(self, AggFunc):
      return True
    if self.is_func:
      for arg in self.args:
        if arg.is_agg:
          return True

  @property
  def is_col(self):
    return isinstance(self, Column)

  @property
  def is_constant(self):
    return isinstance(self, DataType)

  @property
  def returns_boolean(self):
    return issubclass(self.type, Boolean)

  @property
  def returns_number(self):
    return issubclass(self.type, Number)

  @property
  def returns_int(self):
    return issubclass(self.type, Int)

  @property
  def returns_float(self):
    return issubclass(self.type, Float)

  @property
  def returns_string(self):
    return issubclass(self.type, String)

  @property
  def returns_timestamp(self):
    return issubclass(self.type, Timestamp)


class Column(ValExpr):
  '''A representation of a col. All TableExprs will have Columns. So a Column
     may belong to an InlineView as well as a standard Table.

     This class is used in two ways:

     1) As a piece of metadata in a table definition. In this usage the col isn't
        intended to represent a val.

     2) As an expr in a query, for example an item being selected or as part of
        a join condition. In this usage the col is more like a val, which is why
        it implements/extends ValExpr.

  '''

  def __init__(self, owner, name, type_):
    self.owner = owner
    self.name = name
    self._type = type_

  @property
  def type(self):
    return self._type

  def __hash__(self):
    return hash(self.name)

  def __eq__(self, other):
    if not isinstance(other, Column):
      return False
    if self is other:
      return True
    return self.name == other.name and self.owner.identifier == other.owner.identifier

  def __repr__(self):
    return '%s<name: %s, type: %s>' % (
        type(self).__name__, self.name, self._type.__name__)


class FromClause(object):
  '''A representation of a FROM clause. The member variable join_clauses may optionally
     contain JoinClause items.

  '''

  def __init__(self, table_expr, join_clauses=None):
    self.table_expr = table_expr
    self.join_clauses = join_clauses or list()

  @property
  def table_exprs(self):
    '''Provides a list of all table_exprs that are declared within this FROM
       block.
    '''
    table_exprs = [join_clause.table_expr for join_clause in self.join_clauses]
    table_exprs.append(self.table_expr)
    return table_exprs


class TableExpr(object):
  '''This is an AbstractClass that represents something that a query may use to select
     from or join on. The abc module was not used because it caused problems for the
     pickle module.

  '''

  def identifier(self):
    '''Returns either a table name or alias if one has been declared.'''
    pass

  def cols(self):
    pass

  @property
  def cols_by_base_type(self):
    '''Group cols by their basic data type and return a dict of the results.

       As an example, a "BigInt" would be considered an "Int".
    '''
    return DataType.group_by_base_type(self.cols)

  @property
  def is_table(self):
    return isinstance(self, Table)

  @property
  def is_inline_view(self):
    return isinstance(self, InlineView)

  @property
  def is_with_clause_inline_view(self):
    return isinstance(self, WithClauseInlineView)

  def __eq__(self, other):
    if not isinstance(other, type(self)):
      return False
    return self.identifier == other.identifier


class Table(TableExpr):
  '''Represents a standard database table.'''

  def __init__(self, name):
    self.name = name
    self._cols = []
    self.alias = None

  @property
  def identifier(self):
    return self.alias or self.name

  @property
  def cols(self):
    return self._cols

  @cols.setter
  def cols(self, cols):
    self._cols = cols


class InlineView(TableExpr):
  '''Represents an inline view.

     Ex: In the query "SELECT * FROM (SELECT * FROM foo) AS bar",
         "(SELECT * FROM foo) AS bar" would be an inline view.

  '''

  def __init__(self, query):
    self.query = query
    self.alias = None

  @property
  def identifier(self):
    return self.alias

  @property
  def cols(self):
    return [Column(self, item.alias, item.type) for item in
        self.query.select_clause.non_agg_items + self.query.select_clause.agg_items]


class WithClause(object):
  '''Represents a WITH clause.

     Ex: In the query "WITH bar AS (SELECT * FROM foo) SELECT * FROM bar",
         "WITH bar AS (SELECT * FROM foo)" would be the with clause.

  '''

  def __init__(self, with_clause_inline_views):
    self.with_clause_inline_views = with_clause_inline_views

  @property
  def table_exprs(self):
    return self.with_clause_inline_views


class WithClauseInlineView(InlineView):
  '''Represents the entries in a WITH clause. These are very similar to InlineViews but
     may have an additional alias.

     Ex: WITH bar AS (SELECT * FROM foo)
         SELECT *
         FROM bar as r
         JOIN (SELECT * FROM baz) AS z ON ...

     The WithClauseInlineView has aliases "bar" and "r" while the InlineView has
     only the alias "z".

  '''

  def __init__(self, query, with_clause_alias):
    self.query = query
    self.with_clause_alias = with_clause_alias
    self.alias = None

  @property
  def identifier(self):
    return self.alias or self.with_clause_alias


class JoinClause(object):
  '''A representation of a JOIN clause.

     Ex: SELECT * FROM foo <join_type> JOIN <table_expr> [ON <boolean_expr>]

     The member variable boolean_expr will be an instance of a boolean func
     defined below.

  '''

  JOINS_TYPES = ['INNER', 'LEFT', 'RIGHT', 'FULL OUTER', 'CROSS']

  def __init__(self, join_type, table_expr, boolean_expr=None):
    self.join_type = join_type
    self.table_expr = table_expr
    self.boolean_expr = boolean_expr


class WhereClause(object):
  '''The member variable boolean_expr will be an instance of a boolean func
     defined below.

  '''

  def __init__(self, boolean_expr):
    self.boolean_expr = boolean_expr


class GroupByClause(object):

  def __init__(self, select_items):
    self.group_by_items = select_items


class HavingClause(object):
  '''The member variable boolean_expr will be an instance of a boolean func
     defined below.

  '''

  def __init__(self, boolean_expr):
    self.boolean_expr = boolean_expr


class UnionClause(object):
  '''A representation of a UNION clause.

     If the member variable "all" is True, the instance represents a "UNION ALL".

  '''

  def __init__(self, query):
    self.query = query
    self.all = False

  @property
  def queries(self):
    queries = list()
    query = self.query
    while True:
      queries.append(query)
      if not query.union_clause:
        break
      query = query.union_clause.query
    return queries


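# A short illustration (editorial sketch, not part of the model): queries form
# a linked list through union_clause, and UnionClause.queries flattens the
# chain starting at the clause's own query. Assuming hypothetical Query
# instances q1, q2, q3:
#
#   >>> q1.union_clause = UnionClause(q2)
#   >>> q2.union_clause = UnionClause(q3)
#   >>> q1.union_clause.queries == [q2, q3]
#   True
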
class DataTypeMetaclass(type):
  '''Provides sorting of classes used to determine upcasting.'''

  def __cmp__(cls, other):
    return cmp(
        getattr(cls, 'CMP_VALUE', cls.__name__),
        getattr(other, 'CMP_VALUE', other.__name__))


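# A quick illustration of the ordering (editorial sketch): because the data
# type classes below are instances of DataTypeMetaclass, Python 2's cmp-based
# comparison works on the classes themselves, so upcasting reduces to max():
#
#   >>> max(TinyInt, BigInt) is BigInt    # CMP_VALUE 0 vs 3
#   True
#   >>> max(Int, Double) is Double        # CMP_VALUE 2 vs 5
#   True
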
class DataType(ValExpr):
  '''Base class for data types.

     Data types are represented as classes so inheritance can be used.

  '''

  __metaclass__ = DataTypeMetaclass

  def __init__(self, val):
    self.val = val

  @property
  def type(self):
    return type(self)

  @staticmethod
  def group_by_base_type(vals):
    '''Group cols by their basic data type and return a dict of the results.

       As an example, a "BigInt" would be considered an "Int".
    '''
    vals_by_type = defaultdict(list)
    for val in vals:
      type_ = val.type
      if issubclass(type_, Int):
        type_ = Int
      elif issubclass(type_, Float):
        type_ = Float
      vals_by_type[type_].append(val)
    return vals_by_type


class Boolean(DataType):
  pass


class Number(DataType):
  pass


class Int(Number):

  # Used to compare with other numbers for determining upcasting
  CMP_VALUE = 2

  # Used during data generation to keep vals in range
  MIN = -2 ** 31
  MAX = -MIN - 1

  # Aliases used when reading and writing table definitions
  POSTGRESQL = ['INTEGER']


class TinyInt(Int):

  CMP_VALUE = 0

  MIN = -2 ** 7
  MAX = -MIN - 1

  POSTGRESQL = ['SMALLINT']


class SmallInt(Int):

  CMP_VALUE = 1

  MIN = -2 ** 15
  MAX = -MIN - 1


class BigInt(Int):

  CMP_VALUE = 3

  MIN = -2 ** 63
  MAX = -MIN - 1


class Float(Number):

  CMP_VALUE = 4

  POSTGRESQL = ['REAL']


class Double(Float):

  CMP_VALUE = 5

  MYSQL = ['DOUBLE', 'DECIMAL']  # Use double by default but add decimal synonym
  POSTGRESQL = ['DOUBLE PRECISION']


class String(DataType):

  MIN = 0
  # The Impala limit is 32,767 but MySQL has a row size limit of 65,535. To allow 3+
  # String cols per table, the limit will be lowered to 1,000. That should be fine
  # for testing anyhow.
  MAX = 1000

  MYSQL = ['VARCHAR(%s)' % MAX]
  POSTGRESQL = MYSQL + ['CHARACTER VARYING']


class Timestamp(DataType):

  MYSQL = ['DATETIME']
  POSTGRESQL = ['TIMESTAMP WITHOUT TIME ZONE']


NUMBER_TYPES = [Int, TinyInt, SmallInt, BigInt, Float, Double]
TYPES = NUMBER_TYPES + [Boolean, String, Timestamp]


class Func(ValExpr):
  '''Base class for funcs'''

  def __init__(self, *args):
    self.args = list(args)

  def __hash__(self):
    return hash(type(self)) + hash(tuple(self.args))

  def __eq__(self, other):
    if not isinstance(other, type(self)):
      return False
    if self is other:
      return True
    return self.args == other.args


class UnaryFunc(Func):

  def __init__(self, arg):
    Func.__init__(self, arg)


class BinaryFunc(Func):

  def __init__(self, left, right):
    Func.__init__(self, left, right)

  @property
  def left(self):
    return self.args[0]

  @left.setter
  def left(self, left):
    self.args[0] = left

  @property
  def right(self):
    return self.args[1]

  @right.setter
  def right(self, right):
    self.args[1] = right


class BooleanFunc(Func):

  @property
  def type(self):
    return Boolean


class IntFunc(Func):

  @property
  def type(self):
    return Int


class DoubleFunc(Func):

  @property
  def type(self):
    return Double


class StringFunc(Func):

  @property
  def type(self):
    return String


class UpcastingFunc(Func):

  @property
  def type(self):
    return max(arg.type for arg in self.args)


class AggFunc(Func):

  # Avoid having a self.distinct because it would need to be __init__'d explicitly,
  # which none of the AggFunc subclasses do (ex: Avg doesn't have its
  # own __init__).

  @property
  def distinct(self):
    return getattr(self, '_distinct', False)

  @distinct.setter
  def distinct(self, val):
    return setattr(self, '_distinct', val)


# The classes below diverge from above by including the SQL representation. It's a lot
# easier this way because there are a lot of funcs but they all have the same
# structure. Non-standard funcs, such as string concatenation, would need to have
# their representation information elsewhere (like classes above).

Parentheses = type('Parentheses', (UnaryFunc, UpcastingFunc), {'FORMAT': '({})'})

IsNull = type('IsNull', (UnaryFunc, BooleanFunc), {'FORMAT': '{} IS NULL'})
IsNotNull = type('IsNotNull', (UnaryFunc, BooleanFunc), {'FORMAT': '{} IS NOT NULL'})
And = type('And', (BinaryFunc, BooleanFunc), {'FORMAT': '{} AND {}'})
Or = type('Or', (BinaryFunc, BooleanFunc), {'FORMAT': '{} OR {}'})
Equals = type('Equals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} = {}'})
NotEquals = type('NotEquals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} != {}'})
GreaterThan = type('GreaterThan', (BinaryFunc, BooleanFunc), {'FORMAT': '{} > {}'})
LessThan = type('LessThan', (BinaryFunc, BooleanFunc), {'FORMAT': '{} < {}'})
GreaterThanOrEquals = type(
    'GreaterThanOrEquals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} >= {}'})
LessThanOrEquals = type(
    'LessThanOrEquals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} <= {}'})

Plus = type('Plus', (BinaryFunc, UpcastingFunc), {'FORMAT': '{} + {}'})
Minus = type('Minus', (BinaryFunc, UpcastingFunc), {'FORMAT': '{} - {}'})
Multiply = type('Multiply', (BinaryFunc, UpcastingFunc), {'FORMAT': '{} * {}'})
Divide = type('Divide', (BinaryFunc, DoubleFunc), {'FORMAT': '{} / {}'})

Floor = type('Floor', (UnaryFunc, IntFunc), {'FORMAT': 'FLOOR({})'})

Concat = type('Concat', (BinaryFunc, StringFunc), {'FORMAT': 'CONCAT({}, {})'})
Length = type('Length', (UnaryFunc, IntFunc), {'FORMAT': 'LENGTH({})'})

ExtractYear = type(
    'ExtractYear', (UnaryFunc, IntFunc), {'FORMAT': "EXTRACT('YEAR' FROM {})"})

# Formatting of agg funcs is a little trickier since they may have a distinct
Avg = type('Avg', (UnaryFunc, DoubleFunc, AggFunc), {})
Count = type('Count', (UnaryFunc, IntFunc, AggFunc), {})
Max = type('Max', (UnaryFunc, UpcastingFunc, AggFunc), {})
Min = type('Min', (UnaryFunc, UpcastingFunc, AggFunc), {})
Sum = type('Sum', (UnaryFunc, UpcastingFunc, AggFunc), {})

UNARY_BOOLEAN_FUNCS = [IsNull, IsNotNull]
BINARY_BOOLEAN_FUNCS = [And, Or]
RELATIONAL_OPERATORS = [
    Equals, NotEquals, GreaterThan, LessThan, GreaterThanOrEquals, LessThanOrEquals]
MATH_OPERATORS = [Plus, Minus, Multiply]  # Leaving out Divide
BINARY_STRING_FUNCS = [Concat]
AGG_FUNCS = [Avg, Count, Max, Min, Sum]


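# A short illustration (editorial sketch): the FORMAT strings above are
# consumed by SqlWriter._write_func in model_translator.py via str.format,
# with each argument already rendered to SQL text:
#
#   >>> Plus.FORMAT.format('t1.int_col', '1')
#   't1.int_col + 1'
#   >>> IsNull.FORMAT.format('t1.int_col')
#   't1.int_col IS NULL'
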
class If(Func):

  FORMAT = 'CASE WHEN {} THEN {} ELSE {} END'

  def __init__(self, boolean_expr, consquent_expr, alternative_expr):
    Func.__init__(
        self, boolean_expr, consquent_expr, alternative_expr)

  @property
  def boolean_expr(self):
    return self.args[0]

  @property
  def consquent_expr(self):
    return self.args[1]

  @property
  def alternative_expr(self):
    return self.args[2]

  @property
  def type(self):
    return max((self.consquent_expr.type, self.alternative_expr.type))


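# For example (editorial sketch), If's FORMAT renders as:
#
#   >>> If.FORMAT.format('a > b', 'a', 'b')
#   'CASE WHEN a > b THEN a ELSE b END'
#
# which is exactly the shape Greatest below builds on.
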
class Greatest(BinaryFunc, UpcastingFunc, If):

  def __init__(self, left, right):
    BinaryFunc.__init__(self, left, right)
    If.__init__(self, GreaterThan(left, right), left, right)

  @property
  def type(self):
    return UpcastingFunc.type.fget(self)


class Cast(Func):

  FORMAT = 'CAST({} AS {})'

  def __init__(self, val_expr, resulting_type):
    if resulting_type not in TYPES:
      raise Exception('Unexpected type: {}'.format(resulting_type))
    Func.__init__(self, val_expr, resulting_type)

  @property
  def val_expr(self):
    return self.args[0]

  @property
  def resulting_type(self):
    return self.args[1]

  @property
  def type(self):
    return self.resulting_type
367
tests/comparison/model_translator.py
Normal file
@@ -0,0 +1,367 @@
# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from inspect import getmro
from logging import getLogger
from re import sub
from sqlparse import format

from tests.comparison.model import (
    Boolean,
    Float,
    Int,
    Number,
    Query,
    String,
    Timestamp)

LOG = getLogger(__name__)


class SqlWriter(object):
  '''Subclasses of SqlWriter will take a Query and provide the SQL representation for a
     specific database such as Impala or MySQL. The SqlWriter.create([dialect=])
     factory method may be used instead of specifying the concrete class.

     Another important function of this class is to ensure that CASTs produce the same
     results across different databases. Sometimes the CASTs implemented here produce odd
     results. For example, the result of CAST(date_col AS INT) in MySQL may be an int of
     YYYYMMDD whereas in Impala it may be seconds since the epoch. For comparison purposes
     the CAST could be transformed into EXTRACT(DAY FROM date_col).
  '''

  @staticmethod
  def create(dialect='impala'):
    '''Create and return a new SqlWriter appropriate for the given sql dialect. "dialect"
       refers to database specific deviations of sql, and the val should be one of
       "IMPALA", "MYSQL", or "POSTGRESQL".
    '''
    dialect = dialect.upper()
    if dialect == 'IMPALA':
      return SqlWriter()
    if dialect == 'POSTGRESQL':
      return PostgresqlSqlWriter()
    if dialect == 'MYSQL':
      return MySQLSqlWriter()
    raise Exception('Unknown dialect: %s' % dialect)

  def write_query(self, query, pretty=False):
    '''Return SQL as a string for the given query.'''
    sql = list()
    # Write out each section in the proper order
    for clause in (
        query.with_clause,
        query.select_clause,
        query.from_clause,
        query.where_clause,
        query.group_by_clause,
        query.having_clause,
        query.union_clause):
      if clause:
        sql.append(self._write(clause))
    sql = '\n'.join(sql)
    if pretty:
      sql = self.make_pretty_sql(sql)
    return sql

  def make_pretty_sql(self, sql):
    try:
      sql = format(sql, reindent=True)
    except Exception as e:
      LOG.warn('Unable to format sql: %s', e)
    return sql

  def _write_with_clause(self, with_clause):
    return 'WITH ' + ',\n'.join('%s AS (%s)' % (view.identifier, self._write(view.query))
        for view in with_clause.with_clause_inline_views)

  def _write_select_clause(self, select_clause):
    items = select_clause.non_agg_items + select_clause.agg_items
    sql = 'SELECT'
    if select_clause.distinct:
      sql += ' DISTINCT'
    sql += '\n' + ',\n'.join(self._write(item) for item in items)
    return sql

  def _write_select_item(self, select_item):
    # If the query is nested, the items will have aliases so that the outer query can
    # easily reference them.
    if not select_item.alias:
      raise Exception('An alias is required')
    return '%s AS %s' % (self._write(select_item.val_expr), select_item.alias)

  def _write_column(self, col):
    return '%s.%s' % (col.owner.identifier, col.name)

  def _write_from_clause(self, from_clause):
    sql = 'FROM %s' % self._write(from_clause.table_expr)
    if from_clause.join_clauses:
      sql += '\n' + '\n'.join(self._write(join) for join in from_clause.join_clauses)
    return sql

  def _write_table(self, table):
    if table.alias:
      return '%s AS %s' % (table.name, table.identifier)
    return table.name

  def _write_inline_view(self, inline_view):
    if not inline_view.identifier:
      raise Exception('An inline view requires an identifier')
    return '(\n%s\n) AS %s' % (self._write(inline_view.query), inline_view.identifier)

  def _write_with_clause_inline_view(self, with_clause_inline_view):
    if not with_clause_inline_view.with_clause_alias:
      raise Exception('A with clause entry requires an identifier')
    sql = with_clause_inline_view.with_clause_alias
    if with_clause_inline_view.alias:
      sql += ' AS ' + with_clause_inline_view.alias
    return sql

  def _write_join_clause(self, join_clause):
    sql = '%s JOIN %s' % (join_clause.join_type, self._write(join_clause.table_expr))
    if join_clause.boolean_expr:
      sql += ' ON ' + self._write(join_clause.boolean_expr)
    return sql

  def _write_where_clause(self, where_clause):
    return 'WHERE\n' + self._write(where_clause.boolean_expr)

  def _write_group_by_clause(self, group_by_clause):
    return 'GROUP BY\n' + ',\n'.join(self._write(item.val_expr)
        for item in group_by_clause.group_by_items)

  def _write_having_clause(self, having_clause):
    return 'HAVING\n' + self._write(having_clause.boolean_expr)

  def _write_union_clause(self, union_clause):
    sql = 'UNION'
    if union_clause.all:
      sql += ' ALL'
    sql += '\n' + self._write(union_clause.query)
    return sql

  def _write_data_type(self, data_type):
    '''Write a literal value.'''
    if data_type.returns_string:
      return "'{}'".format(data_type.val)
    if data_type.returns_timestamp:
      return "CAST('{}' AS TIMESTAMP)".format(data_type.val)
    return str(data_type.val)

  def _write_func(self, func):
    return func.FORMAT.format(*[self._write(arg) for arg in func.args])

  def _write_cast(self, cast):
    # Handle casts that produce different results across database types or just don't
    # make sense, like casting a DATE as a BOOLEAN.
    if cast.val_expr.returns_boolean:
      if issubclass(cast.resulting_type, Timestamp):
        return "CAST(CASE WHEN {} THEN '2000-01-01' ELSE '1999-01-01' END AS TIMESTAMP)"\
            .format(self._write(cast.val_expr))
    elif cast.val_expr.returns_number:
      if issubclass(cast.resulting_type, Timestamp):
        return ("CAST(CONCAT('2000-01-', "
            "LPAD(CAST(ABS(FLOOR({})) % 31 + 1 AS STRING), 2, '0')) "
            "AS TIMESTAMP)").format(self._write(cast.val_expr))
    elif cast.val_expr.returns_string:
      if issubclass(cast.resulting_type, Boolean):
        return "(LENGTH({}) > 2)".format(self._write(cast.val_expr))
      if issubclass(cast.resulting_type, Timestamp):
        return ("CAST(CONCAT('2000-01-', LPAD(CAST(LENGTH({}) % 31 + 1 AS STRING), "
            "2, '0')) AS TIMESTAMP)").format(self._write(cast.val_expr))
    elif cast.val_expr.returns_timestamp:
      if issubclass(cast.resulting_type, Boolean):
        return '(DAY({0}) > MONTH({0}))'.format(self._write(cast.val_expr))
      if issubclass(cast.resulting_type, Number):
        return ('(DAY({0}) + 100 * MONTH({0}) + 100 * 100 * YEAR({0}))').format(
            self._write(cast.val_expr))
    return self._write_func(cast)

  def _write_agg_func(self, agg_func):
    sql = type(agg_func).__name__.upper() + '('
    if agg_func.distinct:
      sql += 'DISTINCT '
    # All agg funcs only have a single arg
    sql += self._write(agg_func.args[0]) + ')'
    return sql

  def _write_data_type_metaclass(self, data_type_class):
    '''Write a data type class such as Int or Boolean.'''
    return data_type_class.__name__.upper()

  def _write(self, object_):
    '''Return a sql string representation of the given object.'''
    # What's below is effectively a giant switch statement. It works based on a func
    # naming and signature convention. It should match the incoming object with the
    # corresponding func defined, then call the func and return the result.
    #
    # Ex:
    #   a = model.And(...)
    #   _write(a) should call _write_func(a) because "And" is a subclass of "Func" and no
    #   other _write_<class name> methods have been defined higher up the method
    #   resolution order (MRO). If _write_and(...) were to be defined, it would be called
    #   instead.
    for type_ in getmro(type(object_)):
      writer_func_name = '_write' + sub('([A-Z])', r'_\1', type_.__name__).lower()
      writer_func = getattr(self, writer_func_name, None)
      if writer_func:
        return writer_func(object_)

    # Handle any remaining cases
    if isinstance(object_, Query):
      return self.write_query(object_)

    raise Exception('Unsupported object: %s<%s>' % (type(object_).__name__, object_))


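# A small sketch of the dispatch convention used by _write above (editorial
# illustration): the regex turns a CamelCase class name into the snake_case
# writer-method suffix, and the MRO walk picks the most specific match.
#
#   >>> '_write' + sub('([A-Z])', r'_\1', 'AggFunc').lower()
#   '_write_agg_func'
#   >>> '_write' + sub('([A-Z])', r'_\1', 'Column').lower()
#   '_write_column'
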
class PostgresqlSqlWriter(SqlWriter):
  # TODO: This class is out of date since switching to MySQL. This is left here as-is
  #       in case there is a desire to switch back in the future (it should be better
  #       than starting from nothing).

  def _write_divide(self, divide):
    # For ints, Postgresql does int division but Impala does float division.
    return 'CAST({} AS REAL) / {}' \
        .format(*[self._write(arg) for arg in divide.args])

  def _write_data_type_metaclass(self, data_type_class):
    '''Write a data type class such as Int or Boolean.'''
    if hasattr(data_type_class, 'POSTGRESQL'):
      return data_type_class.POSTGRESQL[0]
    return data_type_class.__name__.upper()

  def _write_cast(self, cast):
    # Handle casts that produce different results across database types or just don't
    # make sense, like casting a DATE as a BOOLEAN.
    if cast.val_expr.returns_boolean:
      if issubclass(cast.resulting_type, Float):
        return "CASE {} WHEN TRUE THEN 1.0 WHEN FALSE THEN 0.0 END".format(
            self._write(cast.val_expr))
      if issubclass(cast.resulting_type, Timestamp):
        return "CASE WHEN {} THEN '2000-01-01' ELSE '1999-01-01' END".format(
            self._write(cast.val_expr))
      if issubclass(cast.resulting_type, String):
        return "CASE {} WHEN TRUE THEN '1' WHEN FALSE THEN '0' END".format(
            self._write(cast.val_expr))
    elif cast.val_expr.returns_number:
      if issubclass(cast.resulting_type, Boolean):
        return 'CASE WHEN ({0}) != 0 THEN TRUE WHEN ({0}) = 0 THEN FALSE END'.format(
            self._write(cast.val_expr))
      if issubclass(cast.resulting_type, Timestamp):
        return "CASE WHEN ({}) > 0 THEN '2000-01-01' ELSE '1999-01-01' END".format(
            self._write(cast.val_expr))
    elif cast.val_expr.returns_string:
      if issubclass(cast.resulting_type, Boolean):
        return "(LENGTH({}) > 2)".format(self._write(cast.val_expr))
    elif cast.val_expr.returns_timestamp:
      if issubclass(cast.resulting_type, Boolean):
        return '(EXTRACT(DAY FROM {0}) > EXTRACT(MONTH FROM {0}))'.format(
            self._write(cast.val_expr))
      if issubclass(cast.resulting_type, Number):
        return ('(EXTRACT(DAY FROM {0}) '
            '+ 100 * EXTRACT(MONTH FROM {0}) '
            '+ 100 * 100 * EXTRACT(YEAR FROM {0}))').format(
            self._write(cast.val_expr))
    return self._write_func(cast)


class MySQLSqlWriter(SqlWriter):

  def write_query(self, query, pretty=False):
    # MySQL doesn't support WITH clauses so they need to be converted into inline views.
    # We are going to cheat by making use of the fact that the query generator creates
    # with clause entries with unique aliases even considering nested queries.
    sql = list()
    for clause in (
        query.select_clause,
        query.from_clause,
        query.where_clause,
        query.group_by_clause,
        query.having_clause,
        query.union_clause):
      if clause:
        sql.append(self._write(clause))
    sql = '\n'.join(sql)
    if query.with_clause:
      # Just replace the named references with inline views. Go in reverse order because
      # entries at the bottom of the WITH clause definition may reference entries above.
      for with_clause_inline_view in reversed(query.with_clause.with_clause_inline_views):
        replacement_sql = '(' + self.write_query(with_clause_inline_view.query) + ')'
        sql = sql.replace(with_clause_inline_view.identifier, replacement_sql)
    if pretty:
      sql = self.make_pretty_sql(sql)
    return sql

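  # Illustrative sketch of the rewrite above (hypothetical aliases): a query
  # like "WITH with_1_42 AS (SELECT ...) SELECT * FROM with_1_42 AS t1" is
  # first written without its WITH clause, then the plain string replace
  # substitutes the parenthesized subquery for the name, giving roughly
  # "SELECT * FROM (SELECT ...) AS t1". This leans on the generator producing
  # with-clause aliases that are unique substrings of the SQL.
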
  def _write_data_type_metaclass(self, data_type_class):
    '''Write a data type class such as Int or Boolean.'''
    if issubclass(data_type_class, Int):
      return 'INTEGER'
    if issubclass(data_type_class, Float):
      return 'DECIMAL(65, 15)'
    if issubclass(data_type_class, String):
      return 'CHAR'
    if hasattr(data_type_class, 'MYSQL'):
      return data_type_class.MYSQL[0]
    return data_type_class.__name__.upper()

  def _write_data_type(self, data_type):
    '''Write a literal value.'''
    if data_type.returns_timestamp:
      return "CAST('{}' AS DATETIME)".format(data_type.val)
    if data_type.returns_boolean:
      # MySQL will error if a data_type "FALSE" is used as a GROUP BY field
      return '(0 = 0)' if data_type.val else '(1 = 0)'
    return SqlWriter._write_data_type(self, data_type)

  def _write_cast(self, cast):
    # Handle casts that produce different results across database types or just don't
    # make sense, like casting a DATE as a BOOLEAN.
    if cast.val_expr.returns_boolean:
      if issubclass(cast.resulting_type, Timestamp):
        return "CAST(CASE WHEN {} THEN '2000-01-01' ELSE '1999-01-01' END AS DATETIME)"\
            .format(self._write(cast.val_expr))
    elif cast.val_expr.returns_number:
      if issubclass(cast.resulting_type, Boolean):
        return ("CASE WHEN ({0}) != 0 THEN TRUE WHEN ({0}) = 0 THEN FALSE END").format(
            self._write(cast.val_expr))
      if issubclass(cast.resulting_type, Timestamp):
        return "CAST(CONCAT('2000-01-', ABS(FLOOR({})) % 31 + 1) AS DATETIME)"\
            .format(self._write(cast.val_expr))
    elif cast.val_expr.returns_string:
      if issubclass(cast.resulting_type, Boolean):
        return "(LENGTH({}) > 2)".format(self._write(cast.val_expr))
      if issubclass(cast.resulting_type, Timestamp):
        return ("CAST(CONCAT('2000-01-', LENGTH({}) % 31 + 1) AS DATETIME)").format(
            self._write(cast.val_expr))
    elif cast.val_expr.returns_timestamp:
      if issubclass(cast.resulting_type, Number):
        return ('(EXTRACT(DAY FROM {0}) '
            '+ 100 * EXTRACT(MONTH FROM {0}) '
            '+ 100 * 100 * EXTRACT(YEAR FROM {0}))').format(
            self._write(cast.val_expr))
      if issubclass(cast.resulting_type, Boolean):
        return '(EXTRACT(DAY FROM {0}) > EXTRACT(MONTH FROM {0}))'.format(
            self._write(cast.val_expr))

    # MySQL uses different type names when casting...
    if issubclass(cast.resulting_type, Boolean):
      data_type = 'UNSIGNED'
    elif issubclass(cast.resulting_type, Float):
      data_type = 'DECIMAL(65, 15)'
    elif issubclass(cast.resulting_type, Int):
      data_type = 'SIGNED'
    elif issubclass(cast.resulting_type, String):
      data_type = 'CHAR'
    elif issubclass(cast.resulting_type, Timestamp):
      data_type = 'DATETIME'
    return cast.FORMAT.format(self._write(cast.val_expr), data_type)
556
tests/comparison/query_generator.py
Normal file
@@ -0,0 +1,556 @@
# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from copy import deepcopy
from itertools import chain
from random import choice, randint, shuffle

from tests.comparison.model import (
    AGG_FUNCS,
    AggFunc,
    And,
    BINARY_STRING_FUNCS,
    BigInt,
    Boolean,
    Cast,
    Column,
    Count,
    DataType,
    Double,
    Equals,
    Float,
    Floor,
    FromClause,
    Func,
    Greatest,
    GroupByClause,
    HavingClause,
    InlineView,
    Int,
    JoinClause,
    Length,
    MATH_OPERATORS,
    Number,
    Query,
    RELATIONAL_OPERATORS,
    SelectClause,
    SelectItem,
    String,
    Table,
    Timestamp,
    TYPES,
    UNARY_BOOLEAN_FUNCS,
    UnionClause,
    WhereClause,
    WithClause,
    WithClauseInlineView)


def random_boolean():
  '''Return a val that evaluates to True 50% of the time'''
  return randint(0, 1)


def zero_or_more():
  '''The chance of a return val of n is 1 / 2 ^ (n + 1)'''
  val = 0
  while random_boolean():
    val += 1
  return val


def one_or_more():
  return zero_or_more() + 1


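# Side note (editorial): zero_or_more() samples a geometric distribution,
# P(n) = 1 / 2 ** (n + 1), whose mean is 1, so one_or_more() averages 2.
# A rough empirical check:
#
#   >>> samples = [zero_or_more() for _ in xrange(100000)]
#   >>> 0.9 < sum(samples) / float(len(samples)) < 1.1  # true with high probability
#   True
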
def random_non_empty_split(iterable):
  '''Return two non-empty lists'''
  if len(iterable) < 2:
    raise Exception('The iterable must contain at least two items')
  split_index = randint(1, len(iterable) - 1)
  left, right = list(), list()
  for idx, item in enumerate(iterable):
    if idx < split_index:
      left.append(item)
    else:
      right.append(item)
  return left, right


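# Example behavior (editorial): random_non_empty_split([1, 2, 3]) returns
# either ([1], [2, 3]) or ([1, 2], [3]), each with equal probability; a
# single-item input raises since both halves must be non-empty.
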
class QueryGenerator(object):

  def create_query(self,
      table_exprs,
      allow_with_clause=True,
      select_item_data_types=None):
    '''Create a random query using various language features.

       The initial call to this method should only use tables in the table_exprs
       parameter, and not inline views or "with" definitions. The other types of
       table exprs may be added as part of the query generation.

       If select_item_data_types is specified it must be a sequence or iterable of
       DataType. The generated query.select_clause.select_items will have data
       types suitable for use in a UNION.

    '''
    # Make a copy so tables can be added if a "with" clause is used
    table_exprs = list(table_exprs)

    with_clause = None
    if allow_with_clause and randint(1, 10) == 1:
      with_clause = self._create_with_clause(table_exprs)
      table_exprs.extend(with_clause.table_exprs)

    from_clause = self._create_from_clause(table_exprs)

    select_clause = self._create_select_clause(
        from_clause.table_exprs,
        select_item_data_types=select_item_data_types)

    query = Query(select_clause, from_clause)

    if with_clause:
      query.with_clause = with_clause

    if random_boolean():
      query.where_clause = self._create_where_clause(from_clause.table_exprs)

    if select_clause.agg_items and select_clause.non_agg_items:
      query.group_by_clause = GroupByClause(list(select_clause.non_agg_items))

    if randint(1, 10) == 1:
      if select_clause.agg_items:
        self._enable_distinct_on_random_agg_items(select_clause.agg_items)
      else:
        select_clause.distinct = True

    if random_boolean() and (query.group_by_clause or select_clause.agg_items):
      query.having_clause = self._create_having_clause(from_clause.table_exprs)

    if randint(1, 10) == 1:
      select_item_data_types = list()
      for select_item in select_clause.select_items:
        # For numbers, choose the largest possible data type in case a CAST is needed.
        if select_item.val_expr.returns_float:
          select_item_data_types.append(Double)
        elif select_item.val_expr.returns_int:
          select_item_data_types.append(BigInt)
        else:
          select_item_data_types.append(select_item.val_expr.type)
      query.union_clause = UnionClause(self.create_query(
          table_exprs,
          allow_with_clause=False,
          select_item_data_types=select_item_data_types))
      query.union_clause.all = random_boolean()

    return query

  def _create_with_clause(self, table_exprs):
    # Make a copy so newly created tables can be added and made available for use in
    # future table definitions.
    table_exprs = list(table_exprs)
    with_clause_inline_views = list()
    for with_clause_inline_view_idx in xrange(one_or_more()):
      query = self.create_query(table_exprs)
      # To help prevent nested WITH clauses from having entries with the same alias,
      # choose a random alias. Of course it would be much better to know which aliases
      # were already chosen but that information isn't easy to get from here.
      with_clause_alias = 'with_%s_%s' % \
          (with_clause_inline_view_idx + 1, randint(1, 1000))
      with_clause_inline_view = WithClauseInlineView(query, with_clause_alias)
      table_exprs.append(with_clause_inline_view)
      with_clause_inline_views.append(with_clause_inline_view)
    return WithClause(with_clause_inline_views)

  def _create_select_clause(self, table_exprs, select_item_data_types=None):
    while True:
      non_agg_items = [self._create_non_agg_select_item(table_exprs)
          for _ in xrange(zero_or_more())]
      agg_items = [self._create_agg_select_item(table_exprs)
          for _ in xrange(zero_or_more())]
      if non_agg_items or agg_items:
        if select_item_data_types:
          if len(select_item_data_types) > len(non_agg_items) + len(agg_items):
            # Not enough items generated, try again
            continue
          while len(select_item_data_types) < len(non_agg_items) + len(agg_items):
            items = choice([non_agg_items, agg_items])
            if items:
              items.pop()
          for data_type_idx, data_type in enumerate(select_item_data_types):
            if data_type_idx < len(non_agg_items):
              item = non_agg_items[data_type_idx]
            else:
              item = agg_items[data_type_idx - len(non_agg_items)]
            if not issubclass(item.type, data_type):
              item.val_expr = self.convert_val_expr_to_type(item.val_expr, data_type)
        for idx, item in enumerate(chain(non_agg_items, agg_items)):
          item.alias = '%s_col_%s' % (item.type.__name__.lower(), idx + 1)
        return SelectClause(non_agg_items=non_agg_items, agg_items=agg_items)

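  # For example (editorial note): a select list built from an Int-typed expr
  # followed by a COUNT() would get the aliases 'int_col_1' and 'int_col_2'
  # from the aliasing loop above, since Count's type is Int.
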
  def _choose_col(self, table_exprs):
    table_expr = choice(table_exprs)
    return choice(table_expr.cols)

  def _create_non_agg_select_item(self, table_exprs):
    return SelectItem(self._create_val_expr(table_exprs))

  def _create_val_expr(self, table_exprs):
    vals = [self._choose_col(table_exprs) for _ in xrange(one_or_more())]
    return self._combine_val_exprs(vals)

  def _create_agg_select_item(self, table_exprs):
    vals = [self._create_agg_val_expr(table_exprs) for _ in xrange(one_or_more())]
    return SelectItem(self._combine_val_exprs(vals))

  def _create_agg_val_expr(self, table_exprs):
    val = self._create_val_expr(table_exprs)
    if issubclass(val.type, Number):
      funcs = list(AGG_FUNCS)
    else:
      funcs = [Count]
    return choice(funcs)(val)

  def _create_from_clause(self, table_exprs):
    table_expr = self._create_table_expr(table_exprs)
    table_expr_count = 1
    table_expr.alias = 't%s' % table_expr_count
    from_clause = FromClause(table_expr)
    for join_idx in xrange(zero_or_more()):
      join_clause = self._create_join_clause(from_clause, table_exprs)
      table_expr_count += 1
      join_clause.table_expr.alias = 't%s' % table_expr_count
      from_clause.join_clauses.append(join_clause)
    return from_clause

  def _create_table_expr(self, table_exprs):
    if randint(1, 10) == 1:
      return self._create_inline_view(table_exprs)
    return self._choose_table(table_exprs)

  def _choose_table(self, table_exprs):
    return deepcopy(choice(table_exprs))

  def _create_inline_view(self, table_exprs):
    return InlineView(self.create_query(table_exprs))

  def _create_join_clause(self, from_clause, table_exprs):
    table_expr = self._create_table_expr(table_exprs)
    # Increase the chance of using the first join type, which is INNER
    join_type_idx = (zero_or_more() / 2) % len(JoinClause.JOINS_TYPES)
    join_type = JoinClause.JOINS_TYPES[join_type_idx]
    join_clause = JoinClause(join_type, table_expr)

    # Prefer non-boolean cols for the first condition. Boolean cols produce too
    # many results so it's unlikely that someone would want to join tables only using
    # boolean cols.
    non_boolean_types = set(type_ for type_ in TYPES if not issubclass(type_, Boolean))

    if join_type != 'CROSS':
      join_clause.boolean_expr = self._combine_val_exprs(
          [self._create_relational_join_condition(
              table_expr,
              choice(from_clause.table_exprs),
              prefered_data_types=(non_boolean_types if idx == 0 else set()))
          for idx in xrange(one_or_more())],
          resulting_type=Boolean)
    return join_clause

  def _create_relational_join_condition(self,
      left_table_expr,
      right_table_expr,
      prefered_data_types):
    # "base type" means condense all int types into just int, same for floats
    left_cols_by_base_type = left_table_expr.cols_by_base_type
    right_cols_by_base_type = right_table_expr.cols_by_base_type
    common_col_types = set(left_cols_by_base_type) & set(right_cols_by_base_type)
    if prefered_data_types:
      common_col_types &= prefered_data_types
    if common_col_types:
      col_type = choice(list(common_col_types))
      left = choice(left_cols_by_base_type[col_type])
      right = choice(right_cols_by_base_type[col_type])
    else:
      col_type = None
      if prefered_data_types:
        for available_col_types in (left_cols_by_base_type, right_cols_by_base_type):
          prefered_available_col_types = set(available_col_types) & prefered_data_types
          if prefered_available_col_types:
            col_type = choice(list(prefered_available_col_types))
            break
      if not col_type:
        col_type = choice(left_cols_by_base_type.keys())

      if col_type in left_cols_by_base_type:
        left = choice(left_cols_by_base_type[col_type])
      else:
        left = choice(choice(left_cols_by_base_type.values()))
        left = self.convert_val_expr_to_type(left, col_type)
      if col_type in right_cols_by_base_type:
        right = choice(right_cols_by_base_type[col_type])
      else:
        right = choice(choice(right_cols_by_base_type.values()))
        right = self.convert_val_expr_to_type(right, col_type)
    return Equals(left, right)

  def _create_where_clause(self, table_exprs):
    boolean_exprs = list()
    # Create one boolean expr per iteration...
    for _ in xrange(one_or_more()):
      col_type = None
      cols = list()
      # ...using one or more cols...
      for _ in xrange(one_or_more()):
        # ...from any random table, inline view, etc.
        table_expr = choice(table_exprs)
        if not col_type:
          col_type = choice(list(table_expr.cols_by_base_type))
        if col_type in table_expr.cols_by_base_type:
          col = choice(table_expr.cols_by_base_type[col_type])
        else:
          col = choice(table_expr.cols)
        cols.append(col)
      boolean_exprs.append(self._combine_val_exprs(cols, resulting_type=Boolean))
    return WhereClause(self._combine_val_exprs(boolean_exprs))

  def _combine_val_exprs(self, vals, resulting_type=None):
    '''Combine the given vals into a single val.

       If resulting_type is specified, the returned val will be of that type. If
       the resulting data type was not specified, it will be randomly chosen from the
       types of the input vals.

    '''
    if not vals:
      raise Exception('At least one val is required')

    types_to_vals = DataType.group_by_base_type(vals)

    if not resulting_type:
      resulting_type = choice(types_to_vals.keys())

    vals_of_resulting_type = list()

    for val_type, vals in types_to_vals.iteritems():
      if issubclass(val_type, resulting_type):
        vals_of_resulting_type.extend(vals)
      elif resulting_type == Boolean:
        # To produce other result types, the vals will be aggregated into a single val
        # then converted into the desired type. However to make a boolean, relational
        # operators can be used on the vals to make a more realistic query.
        val = self._create_boolean_expr_from_vals_of_same_type(vals)
        vals_of_resulting_type.append(val)
      else:
        val = self._combine_vals_of_same_type(vals)
        if not (issubclass(val.type, Number) and issubclass(resulting_type, Number)):
          val = self.convert_val_expr_to_type(val, resulting_type)
        vals_of_resulting_type.append(val)

    return self._combine_vals_of_same_type(vals_of_resulting_type)

  def _create_boolean_expr_from_vals_of_same_type(self, vals):
    if not vals:
      raise Exception('At least one val is required')

    if len(vals) == 1:
      val = vals[0]
      if val.type == Boolean:
        return val
      # Convert a single non-boolean val into a boolean using a func like
      # IsNull or IsNotNull.
      return choice(UNARY_BOOLEAN_FUNCS)(val)
    if len(vals) == 2:
      left, right = vals
      if left.type == right.type:
        if left.type == String:
          # Databases may vary in how string comparisons are done. Results may
          # differ when using operators like > or <, so just always use =.
          return Equals(left, right)
        if left.type == Boolean:
          # TODO: Enable "OR" at some frequency, using OR at 50% will probably
          #       produce too many slow queries.
          return And(left, right)
        # At this point we've got two data points of the same type, so any
        # valid relational operator will produce a boolean.
        return choice(RELATIONAL_OPERATORS)(left, right)
      elif issubclass(left.type, Number) and issubclass(right.type, Number):
        # Numbers need not be of the same type. SmallInt, BigInt, etc. can all
        # be compared.
        # Note: For now ints are the only numbers enabled and division is
        #       disabled though AVG() is in use. If floats are enabled this
        #       will likely need to be updated to do some rounding based
        #       comparison.
        return choice(RELATIONAL_OPERATORS)(left, right)
      raise Exception('Vals are not of the same type: %s<%s> vs %s<%s>'
                      % (left, left.type, right, right.type))
    # Reduce the number of inputs and try again...
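    # E.g. (hypothetically) [a, b, c, d] might be split into [a, b] and
    # [c, d], collapsed to the pair (a + b, c - d), and then compared
    # relationally on the recursive call.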
    left_subset, right_subset = random_non_empty_split(vals)
    return self._create_boolean_expr_from_vals_of_same_type([
        self._combine_vals_of_same_type(left_subset),
        self._combine_vals_of_same_type(right_subset)])

  def _combine_vals_of_same_type(self, vals):
    '''Combine the given vals into a single expr of the same type. The input
       vals must be of the same base data type. For example, Ints must not be
       mixed with Strings.
    '''
    if not vals:
      raise Exception('At least one val is required')

    val_type = None
    for val in vals:
      if not val_type:
        if issubclass(val.type, Number):
          val_type = Number
        else:
          val_type = val.type
      elif not issubclass(val.type, val_type):
        raise Exception('Incompatible types %s and %s' % (val_type, val.type))

    if len(vals) == 1:
      return vals[0]

    if val_type == Number:
      funcs = MATH_OPERATORS
    elif val_type == Boolean:
      # TODO: Enable "OR" at some frequency
      funcs = [And]
    elif val_type == String:
      # TODO: Combining strings via BINARY_STRING_FUNCS is disabled for now;
      #       just return a single val.
      return vals[0]
    elif val_type == Timestamp:
      funcs = [Greatest]

    vals = list(vals)
    shuffle(vals)
    left = vals.pop()
    right = vals.pop()
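    # Fold the shuffled vals into a single expr one step at a time, picking a
    # random func each iteration, e.g. (hypothetically) a, b, c might become
    # Plus(Minus(a, b), c).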
    while True:
      func = choice(funcs)
      left = func(left, right)
      if not vals:
        return left
      right = vals.pop()

  def convert_val_expr_to_type(self, val_expr, resulting_type):
    if resulting_type not in TYPES:
      raise Exception('Unexpected type: {}'.format(resulting_type))
    val_type = val_expr.type
    if issubclass(val_type, resulting_type):
      return val_expr

    if issubclass(resulting_type, Int):
      if val_expr.returns_float:
        # Impala will FLOOR while Postgresql will ROUND. Use FLOOR to be
        # consistent.
        return Floor(val_expr)
    if issubclass(resulting_type, Number):
      if val_expr.returns_string:
        return Length(val_expr)
    if issubclass(resulting_type, String):
      if val_expr.returns_float:
        # Different databases may use different precision.
        return Cast(Floor(val_expr), resulting_type)
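    # No special case applied, so fall back to a plain cast. For example
    # (hypothetical), a Boolean val with resulting_type=Int would become
    # CAST(val AS INT).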
    return Cast(val_expr, resulting_type)

  def _create_having_clause(self, table_exprs):
    boolean_exprs = list()
    # Create one boolean expr per iteration...
    for _ in xrange(one_or_more()):
      agg_items = list()
      # ...using one or more agg exprs...
      for _ in xrange(one_or_more()):
        vals = [self._create_agg_val_expr(table_exprs) for _ in xrange(one_or_more())]
        agg_items.append(self._combine_val_exprs(vals))
      boolean_exprs.append(self._combine_val_exprs(agg_items, resulting_type=Boolean))
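    # The final clause might render as something like (hypothetical exprs):
    #   HAVING COUNT(a) + SUM(b) > MIN(c)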
    return HavingClause(self._combine_val_exprs(boolean_exprs))

  def _enable_distinct_on_random_agg_items(self, agg_items):
    '''Randomly choose an agg func and set it to use DISTINCT'''
    # Impala has a limitation where 'DISTINCT' may only be applied to one agg
    # expr. If that agg expr is used more than once, each usage may or may not
    # include DISTINCT.
    #
    # Examples:
    #   OK: SELECT COUNT(DISTINCT a) + SUM(DISTINCT a) + MAX(a)...
    #   Not OK: SELECT COUNT(DISTINCT a) + COUNT(DISTINCT b)...
    #
    # Given a select list like:
    #   COUNT(a), SUM(a), MAX(b)
    #
    # we want to output one of:
    #   COUNT(DISTINCT a), SUM(DISTINCT a), MAX(b)
    #   COUNT(DISTINCT a), SUM(a), MAX(b)
    #   COUNT(a), SUM(a), MAX(DISTINCT b)
    #
    # This is done by first grouping all agg funcs by their inner expr:
    #   {a: [COUNT(a), SUM(a)],
    #    b: [MAX(b)]}
    #
    # then choosing a random value (which is a list of aggs) from the above
    # dict, and finally randomly adding DISTINCT to items in that list.
    exprs_to_funcs = defaultdict(list)
    for item in agg_items:
      for expr, funcs in self._group_agg_funcs_by_expr(item.val_expr).iteritems():
        exprs_to_funcs[expr].extend(funcs)
    funcs = choice(exprs_to_funcs.values())
    for func in funcs:
      if random_boolean():
        func.distinct = True

  def _group_agg_funcs_by_expr(self, val_expr):
    '''Group agg funcs by their inner expr and return a dict mapping each expr
       to the agg funcs it is used in.

       Example: COUNT(a) * SUM(a) - MAX(b) + MIN(c) -> {a: [COUNT(a), SUM(a)],
                                                        b: [MAX(b)],
                                                        c: [MIN(c)]}
    '''
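    # Note: tuple(val_expr.args) is used as the dict key below because the args
    # are stored in a list, which is not hashable.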
    exprs_to_funcs = defaultdict(list)
    if isinstance(val_expr, AggFunc):
      exprs_to_funcs[tuple(val_expr.args)].append(val_expr)
    elif isinstance(val_expr, Func):
      for arg in val_expr.args:
        for expr, funcs in self._group_agg_funcs_by_expr(arg).iteritems():
          exprs_to_funcs[expr].extend(funcs)
    # else: The remaining case could happen if the original expr was something
    #       like "SUM(a) + b + 1" where b is a GROUP BY field.
    return exprs_to_funcs


if __name__ == '__main__':
  '''Generate some queries for manual inspection. The queries won't run
     anywhere because the tables used are fake. To make real queries, we'd need
     to connect to a database and read the table metadata and such.
  '''
  tables = list()
  data_types = list(TYPES)   # Copy so the shared TYPES list isn't mutated
  data_types.remove(Float)
  data_types.remove(Double)
  for table_idx in xrange(5):
    table = Table('table_%s' % table_idx)
    tables.append(table)
    for col_idx in xrange(3):
      col_type = choice(data_types)
      col = Column(table, '%s_col_%s' % (col_type.__name__.lower(), col_idx), col_type)
      table.cols.append(col)
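  # The fake schema built above looks something like (names vary per run):
  #   table_0(int_col_0 INT, string_col_1 STRING, boolean_col_2 BOOLEAN), ...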

  query_generator = QueryGenerator()
  from model_translator import SqlWriter   # deferred, presumably to avoid a
                                           # circular import
  sql_writer = SqlWriter.create()
  for _ in range(3000):
    query = query_generator.create_query(tables)
    print(sql_writer.write_query(query) + '\n')