diff --git a/bin/copy-test-data.sh b/bin/copy-test-data.sh
index 56524dc6a..8fe63cbdd 100755
--- a/bin/copy-test-data.sh
+++ b/bin/copy-test-data.sh
@@ -14,4 +14,4 @@ DATASRC="a2226.halxg.cloudera.com:/data/1/workspace/impala-data"
DATADST=$IMPALA_HOME/testdata/impala-data
mkdir -p $DATADST
-scp -i $IMPALA_HOME/ssh_keys/id_rsa_impala -o "StrictHostKeyChecking=no" -r $DATASRC/* $DATADST
+scp -i $HOME/.ssh/id_rsa_jenkins -o "StrictHostKeyChecking=no" -r systest@$DATASRC/* $DATADST
diff --git a/tests/comparison/README b/tests/comparison/README
new file mode 100644
index 000000000..059de094d
--- /dev/null
+++ b/tests/comparison/README
@@ -0,0 +1,164 @@
+Purpose:
+
+This package is intended to augment the standard test suite. The standard tests are
+more efficient in terms of features tested per unit of execution time, but they
+still leave gaps in query coverage. This package provides a random query generator
+and compares the results of a wide range of queries against a reference database
+engine. The queries range from very simple single-table selects to extremely
+complicated queries with multiple levels of nesting. This method of testing is
+slower but has a larger coverage area.
+
+
+Requirements:
+
+1) It's assumed that Impala is running locally.
+
+2) Impyla -- an implementation of DB API 2 for Impala.
+
+ sudo pip install git+http://github.com/laserson/impyla.git#impyla
+
+3) At least one python driver for a reference database.
+
+   sudo apt-get install python-mysqldb  # MySQL
+   sudo apt-get install python-psycopg2 # Postgresql
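+
+   The drivers are imported as impala.dbapi, MySQLdb, and psycopg2 respectively;
+   a quick way to verify they are all installed:
+
+   python -c 'import impala.dbapi, MySQLdb, psycopg2'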
+
+
+Usage:
+
+1) Generate test data
+
+ ./data_generator.py --use-mysql
+
+   This will generate tables and data in MySQL and Impala.
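+
+   Other flags defined in data_generator.py can adjust the output. Ex:
+
+   ./data_generator.py --use-postgresql --db-name randomness --table-count 5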
+
+
+2) Run the comparison
+
+ ./discrepancy_searcher.py
+
+ This will generate queries using the test database and compare the results against
+ MySQL (the default).
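+
+   The reference database and the number of queries can be changed with flags
+   defined in discrepancy_searcher.py. Ex:
+
+   ./discrepancy_searcher.py --reference-db-type POSTGRESQL --query-count 5000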
+
+
+Known Issues:
+
+1) Floats will produce false-positives. For example a query that contains
+
+ SELECT FLOOR(COUNT(...) * AVG(...)) AS col_1
+
+   will produce different results on Impala and MySQL if COUNT() == 3 and AVG() == 1/3.
+   One of the databases will compute FLOOR(1) while the other computes FLOOR(0.999).
+
+ Maybe this could be worked around or reduced by replacing all uses of AVG() with
+ "AVG() + foo", where foo is some number that makes it unlikely that
+ "COUNT() * (AVG() + foo)" will result in an int.
+
+ I'd guess this issue comes up in 1 out of 10-20k queries.
+
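+   The comparator already tolerates small float differences by rounding to two
+   decimal places and applying a relative epsilon. A rough sketch of the check,
+   mirroring QueryResultComparator.floats_are_equal in discrepancy_searcher.py:
+
+     def floats_are_equal(left, right, epsilon=0.1, decimal_places=2):
+       left, right = round(left, decimal_places), round(right, decimal_places)
+       diff = abs(left - right)
+       if left * right == 0:
+         return diff < epsilon
+       return diff / (abs(left) + abs(right)) < epsilon
+
+   FLOOR() defeats this tolerance because it turns a tiny difference into a whole
+   integer.
+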
+2) Impyla may fail with "Invalid query handle". Some queries will fail every time when run
+ through Impyla but run fine through the impala-shell. I need to research more and file
+ an issue with Impyla.
+
+3) Impyla will fail with "Invalid session". I'm pretty sure this is also an Impyla
+   issue but I need to investigate more.
+
+
+Things to Know:
+
+1) A good number of queries to run seems to be about 5k. Ideally each test run would
+ discover the complete list of known issues. From experience a 1k query test run may
+ complete without finding any issues that were discovered in previous runs. 5k seems
+   to be about the magic number where most issues will be rediscovered. This can take 1-2
+ hours. However as of this writing it's rare to run 1k queries without finding at
+ least one discrepancy.
+
+2) It's possible to provide a randomization seed so that the randomness is actually
+   reproducible. The data generation currently has a default seed so it will always
+   produce the same tables. This also means that if a new data type is added, the
+   generated tables will change.
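+
+   Ex (the flag is defined in data_generator.py):
+
+   ./data_generator.py --use-mysql --randomization-seed 42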
+
+3) There is a query log. It's possible that a sequence of queries is required to expose
+ a bug. If you come across a failure that can't be reproduced by rerunning the failed
+ query, try running the queries leading up to that query as well.
+
+
+Miscellaneous:
+
+1) Instead of generating new random queries with each run, it may be better to reuse a
+ list of queries from a previous run that are known to produce results. As of this
+ writing only about 50% of queries produce results. So it may be better to trade high
+ randomness for higher quality queries. For example it would be possible to build up a
+ library of 100k queries that produce results then randomly select 2.5k of those.
+ Maybe that would provide testing equivalent to 5k totally random queries in less
+ time.
+
+ This would also be useful in eliminating queries that have known issues above.
+
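+   Queries that produced a mismatch are already saved to a file-backed dict (see
+   query_shelve_path in discrepancy_searcher.py), keyed by query number. A minimal
+   sketch of reading them back, assuming the shelve was written to /tmp (the path
+   comes from tempfile.gettempdir()):
+
+     from contextlib import closing
+     from shelve import open as open_shelve
+
+     with closing(open_shelve('/tmp/query.shelve')) as query_shelve:
+       for query_number in sorted(query_shelve.keys()):
+         print('query #' + query_number)
+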
+
+Postgresql:
+
+1) Supports basically all Impala language features.
+
+2) Uses integer division for ints: 1 / 2 == 0.
+
+3) Has strange sorting of strings, '-1' > '1'. This may be important if ORDER BY is ever
+ used. The databases being compared would need to have the same collation, which is
+ probably configurable.
+
+4) This was the original reference database but I moved to MySQL while trying to add
+ support for floats and never moved back.
+
+
+MySQL:
+
+1) Supports basically all Impala language features, except that the WITH clause
+   requires emulation with inline views.
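+
+   Ex: "WITH t AS (SELECT 1 AS a) SELECT a FROM t" can be emulated as
+   "SELECT a FROM (SELECT 1 AS a) AS t".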
+
+2) Has poor boolean support. It may be worth switching back to Postgresql for this.
+
+
+Improvements:
+
+1) Add support for simplifying buggy queries. When a random query fails the comparison
+   check it is basically always much too complex to post directly in a bug report. It
+   is also time consuming to simplify the queries because there is a lot of trial and
+   error and manual editing of queries.
+
+2) Add more language features
+
+ a) SEMI JOIN
+ b) ORDER BY
+ c) LIMIT, OFFSET
+ d) CASE / WHEN
+
+3) Add common built-in functions. Ex: CAST, IF, NVL, ...
+
+4) Make randomization of the query generation configurable. As of this writing all the
+ probabilities are hard-coded. At a very minimum it should be easy to disable or force
+ the use of some language features such as CROSS JOIN, GROUP BY, etc.
+
+5) More investigation of using the existing "functional" test datasets. A very quick
+ trial run wasn't successful but another attempt with more effort should be made before
+ introducing a new dataset.
+
+ I suspect the problem with using the functional dataset was that I only imported a few
+ tables, maybe alltypes, alltypesagg, and something else. I don't think I imported the
+ tiny tables since the odds of them producing results from a random query would be
+ very low.
+
+6) If the functional dataset cannot be used, someone should think more about what the
+ random data should be like. Only a few minutes of thought were put into selecting
+ random value ranges (including number of tables and columns), and it's not clear how
+ important those ranges are.
+
+7) Add support for comparing results with codegen enabled and disabled. Uri recently added
+ support for query options in Impyla.
+
+8) Consider adding Oracle or SQL Server support; these could be useful in the future for
+ analytic queries.
+
+9) Try running with tables in various formats. Ex: parquet and/or avro.
+
+10) Support for more data types. Only int types are known to give good results.
+ Floats may work but non-numeric types are not supported yet.
+
diff --git a/tests/comparison/__init__.py b/tests/comparison/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/comparison/data_generator.py b/tests/comparison/data_generator.py
new file mode 100755
index 000000000..9944dd75d
--- /dev/null
+++ b/tests/comparison/data_generator.py
@@ -0,0 +1,409 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+'''This module provides random data generation and database population.
+
+ When this module is run directly for purposes of database population, the default is
+ to use a fixed seed for randomization. The result should be that the generated random
+ data is the same regardless of when or where the execution is done.
+
+'''
+
+from datetime import datetime, timedelta
+from logging import basicConfig, getLogger
+from random import choice, randint, random, seed, uniform
+
+from tests.comparison.db_connector import (
+ DbConnector,
+ IMPALA,
+ MYSQL,
+ POSTGRESQL)
+from tests.comparison.model import (
+ Boolean,
+ Column,
+ Float,
+ Int,
+ Number,
+ String,
+ Table,
+ Timestamp,
+ TYPES)
+
+LOG = getLogger(__name__)
+
+class RandomValGenerator(object):
+  '''This class will generate random data of various data types. Currently numeric,
+     string, timestamp, and boolean data types are supported.
+
+ '''
+
+ def __init__(self,
+ min_number=-1000,
+ max_number=1000,
+ min_date=datetime(1990, 1, 1),
+ max_date=datetime(2030, 1, 1),
+ null_val_percentage=0.1):
+ self.min_number = min_number
+ self.max_number = max_number
+ self.min_date = min_date
+ self.max_date = max_date
+ self.null_val_percentage = null_val_percentage
+
+ def generate_val(self, val_type):
+ '''Generate and return a single random val. Use the val_type parameter to
+ specify the type of val to generate. See model.DataType for valid val_type
+ options.
+
+ Ex:
+ generator = RandomValGenerator(min_number=1, max_number=5)
+ val = generator.generate_val(model.Int)
+ assert 1 <= val and val <= 5
+ '''
+ if issubclass(val_type, String):
+ val = self.generate_val(Int)
+ return None if val is None else str(val)
+ if random() < self.null_val_percentage:
+ return None
+ if issubclass(val_type, Int):
+ return randint(
+ max(self.min_number, val_type.MIN), min(val_type.MAX, self.max_number))
+ if issubclass(val_type, Number):
+ return uniform(self.min_number, self.max_number)
+ if issubclass(val_type, Timestamp):
+ delta = self.max_date - self.min_date
+ delta_in_seconds = delta.days * 24 * 60 * 60 + delta.seconds
+ offset_in_seconds = randint(0, delta_in_seconds)
+ val = self.min_date + timedelta(0, offset_in_seconds)
+ return datetime(val.year, val.month, val.day)
+ if issubclass(val_type, Boolean):
+ return randint(0, 1) == 1
+ raise Exception('Unsupported type %s' % val_type.__name__)
+
+
+class DatabasePopulator(object):
+ '''This class will populate a database with randomly generated data. The population
+ includes table creation and data generation. Table names are hard coded as
+     table_<table number>.
+
+ '''
+
+ def __init__(self):
+ self.val_generator = RandomValGenerator()
+
+ def populate_db_with_random_data(self,
+ db_name,
+ db_connectors,
+ number_of_tables=10,
+ allowed_data_types=TYPES,
+ create_files=False):
+ '''Create tables with a random number of cols with data types chosen from
+ allowed_data_types, then fill the tables with data.
+
+ The given db_name must have already been created.
+
+ '''
+ connections = [connector.create_connection(db_name=db_name)
+ for connector in db_connectors]
+ for table_idx in xrange(number_of_tables):
+ table = self.create_random_table(
+ 'table_%s' % (table_idx + 1),
+ allowed_data_types=allowed_data_types)
+ for connection in connections:
+ sql = self.make_create_table_sql(table, dialect=connection.db_type)
+ LOG.info('Creating %s table %s', connection.db_type, table.name)
+ if create_files:
+ with open('%s_%s.sql' % (table.name, connection.db_type.lower()), 'w') \
+ as f:
+ f.write(sql + '\n')
+ connection.execute(sql)
+ LOG.info('Inserting data into %s', table.name)
+ for _ in xrange(100): # each iteration will insert 100 rows
+ rows = self.generate_table_data(table)
+ for connection in connections:
+ sql = self.make_insert_sql_from_data(
+ table, rows, dialect=connection.db_type)
+ if create_files:
+ with open('%s_%s.sql' %
+ (table.name, connection.db_type.lower()), 'a') as f:
+ f.write(sql + '\n')
+ try:
+ connection.execute(sql)
+ except:
+ LOG.error('Error executing SQL: %s', sql)
+ raise
+
+ self.index_tables_in_database(connections)
+
+ for connection in connections:
+ connection.close()
+
+ def migrate_database(self,
+ db_name,
+ source_db_connector,
+ destination_db_connectors,
+ include_table_names=None):
+ '''Read table metadata and data from the source database and create a replica in
+       the destination databases. For example, the Impala functional test database could
+ be copied into Postgresql.
+
+ source_db_connector and items in destination_db_connectors should be
+ of type db_connector.DbConnector. destination_db_connectors and
+ include_table_names should be iterables.
+ '''
+ source_connection = source_db_connector.create_connection(db_name)
+
+ cursors = [connector.create_connection(db_name=db_name).create_cursor()
+ for connector in destination_db_connectors]
+
+ for table_name in source_connection.list_table_names():
+ if include_table_names and table_name not in include_table_names:
+ continue
+ try:
+ table = source_connection.describe_table(table_name)
+ except Exception as e:
+ LOG.warn('Error fetching metadata for %s: %s', table_name, e)
+ continue
+ for destination_cursor in cursors:
+ sql = self.make_create_table_sql(
+ table, dialect=destination_cursor.connection.db_type)
+ destination_cursor.execute(sql)
+ with source_connection.open_cursor() as source_cursor:
+ try:
+ source_cursor.execute('SELECT * FROM ' + table_name)
+ while True:
+ rows = source_cursor.fetchmany(size=100)
+ if not rows:
+ break
+ for destination_cursor in cursors:
+ sql = self.make_insert_sql_from_data(
+ table, rows, dialect=destination_cursor.connection.db_type)
+ destination_cursor.execute(sql)
+ except Exception as e:
+ LOG.error('Error fetching data for %s: %s', table_name, e)
+ continue
+
+ self.index_tables_in_database([cursor.connection for cursor in cursors])
+
+ for cursor in cursors:
+ cursor.close()
+ cursor.connection.close()
+
+ def create_random_table(self, table_name, allowed_data_types):
+ '''Create and return a Table with a random number of cols chosen from the
+ given allowed_data_types.
+
+ '''
+ data_type_count = len(allowed_data_types)
+ col_count = randint(data_type_count / 2, data_type_count * 2)
+ table = Table(table_name)
+ for col_idx in xrange(col_count):
+ col_type = choice(allowed_data_types)
+ col = Column(
+ table,
+ '%s_col_%s' % (col_type.__name__.lower(), col_idx + 1),
+ col_type)
+ table.cols.append(col)
+ return table
+
+ def make_create_table_sql(self, table, dialect=IMPALA):
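+    # Ex output (Impala dialect): CREATE TABLE table_1 (int_col_1 INT, ...).
+    # Other dialects get an explicit NULL after each col, and MySQL additionally
+    # gets "ENGINE = MYISAM".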
+ sql = 'CREATE TABLE %s (%s)' % (
+ table.name,
+ ', '.join('%s %s' %
+ (col.name, self.get_sql_for_data_type(col.type, dialect)) +
+ ('' if dialect == IMPALA else ' NULL')
+ for col in table.cols))
+ if dialect == MYSQL:
+ sql += ' ENGINE = MYISAM'
+ return sql
+
+ def get_sql_for_data_type(self, data_type, dialect=IMPALA):
+ # Check to see if there is an alias and if so, use the first one
+ if hasattr(data_type, dialect):
+ return getattr(data_type, dialect)[0]
+ return data_type.__name__.upper()
+
+ def make_insert_sql_from_data(self, table, rows, dialect=IMPALA):
+ # TODO: Consider using parameterized inserts so the database connector handles
+ # formatting the data. For example the CAST to workaround IMPALA-803 can
+ # probably be removed. The vals were generated this way so a data file
+ # could be made and attached to jiras.
+ if not rows:
+ raise Exception('At least one row is required')
+ if not table.cols:
+ raise Exception('At least one col is required')
+
+ sql = 'INSERT INTO %s VALUES ' % table.name
+ for row_idx, row in enumerate(rows):
+ if row_idx > 0:
+ sql += ', '
+ sql += '('
+ for col_idx, col in enumerate(table.cols):
+ if col_idx > 0:
+ sql += ', '
+ val = row[col_idx]
+ if val is None:
+ sql += 'NULL'
+ elif issubclass(col.type, Timestamp):
+ if dialect != IMPALA:
+ sql += 'TIMESTAMP '
+ sql += "'%s'" % val
+ elif issubclass(col.type, String):
+ val = val.replace("'", "''")
+ if dialect == POSTGRESQL:
+ val = val.replace('\\', '\\\\')
+ sql += "'%s'" % val
+ elif dialect == IMPALA \
+ and issubclass(col.type, Float):
+ # https://issues.cloudera.org/browse/IMPALA-803
+ sql += 'CAST(%s AS FLOAT)' % val
+ else:
+ sql += str(val)
+ sql += ')'
+ return sql
+
+ def generate_table_data(self, table, number_of_rows=100):
+ rows = list()
+ for row_idx in xrange(number_of_rows):
+ row = list()
+ for col in table.cols:
+ row.append(self.val_generator.generate_val(col.type))
+ rows.append(row)
+ return rows
+
+ def drop_and_create_database(self, db_name, db_connectors):
+ for connector in db_connectors:
+ with connector.open_connection() as connection:
+ connection.drop_db_if_exists(db_name)
+ connection.execute('CREATE DATABASE ' + db_name)
+
+ def index_tables_in_database(self, connections):
+ for connection in connections:
+ if connection.supports_index_creation:
+ for table_name in connection.list_table_names():
+ LOG.info('Indexing %s on %s' % (table_name, connection.db_type))
+ connection.index_table(table_name)
+
+if __name__ == '__main__':
+ from optparse import NO_DEFAULT, OptionGroup, OptionParser
+
+ parser = OptionParser(
+ usage='usage: \n'
+ ' %prog [options] [populate]\n\n'
+ ' Create and populate database(s). The Impala database will always be \n'
+ ' included, the other database types are optional.\n\n'
+ ' %prog [options] migrate\n\n'
+ ' Migrate an Impala database to another database type. The destination \n'
+ ' database will be dropped and recreated.')
+ parser.add_option('--log-level', default='INFO',
+ help='The log level to use.', choices=('DEBUG', 'INFO', 'WARN', 'ERROR'))
+ parser.add_option('--db-name', default='randomness',
+ help='The name of the database to use. Ex: functional.')
+
+ group = OptionGroup(parser, 'MySQL Options')
+ group.add_option('--use-mysql', action='store_true', default=False,
+ help='Use MySQL')
+ group.add_option('--mysql-host', default='localhost',
+ help='The name of the host running the MySQL database.')
+ group.add_option('--mysql-port', default=3306, type=int,
+ help='The port of the host running the MySQL database.')
+ group.add_option('--mysql-user', default='root',
+ help='The user name to use when connecting to the MySQL database.')
+ group.add_option('--mysql-password',
+ help='The password to use when connecting to the MySQL database.')
+ parser.add_option_group(group)
+
+ group = OptionGroup(parser, 'Postgresql Options')
+ group.add_option('--use-postgresql', action='store_true', default=False,
+ help='Use Postgresql')
+ group.add_option('--postgresql-host', default='localhost',
+ help='The name of the host running the Postgresql database.')
+ group.add_option('--postgresql-port', default=5432, type=int,
+ help='The port of the host running the Postgresql database.')
+ group.add_option('--postgresql-user', default='postgres',
+ help='The user name to use when connecting to the Postgresql database.')
+ group.add_option('--postgresql-password',
+ help='The password to use when connecting to the Postgresql database.')
+ parser.add_option_group(group)
+
+ group = OptionGroup(parser, 'Database Population Options')
+ group.add_option('--randomization-seed', default=1, type='int',
+ help='The randomization will be initialized with this seed. Using the same seed '
+ 'will produce the same results across runs.')
+ group.add_option('--create-data-files', default=False, action='store_true',
+    help='Create files that can be used to repopulate the databases elsewhere.')
+ group.add_option('--table-count', default=10, type='int',
+ help='The number of tables to generate.')
+ parser.add_option_group(group)
+
+ group = OptionGroup(parser, 'Database Migration Options')
+ group.add_option('--migrate-table-names',
+ help='Table names should be separated with commas. The default is to migrate all '
+ 'tables.')
+ parser.add_option_group(group)
+
+ for group in parser.option_groups + [parser]:
+ for option in group.option_list:
+ if option.default != NO_DEFAULT:
+ option.help += ' [default: %default]'
+
+ options, args = parser.parse_args()
+ command = args[0] if args else 'populate'
+ if len(args) > 1 or command not in ['populate', 'migrate']:
+ raise Exception('Command must either be "populate" or "migrate" but was "%s"' %
+ ' '.join(args))
+ if command == 'migrate' and not any((options.use_mysql, options.use_postgresql)):
+ raise Exception('At least one destination database must be chosen with '
+        '--use-<database type>')
+
+ basicConfig(level=options.log_level)
+
+ seed(options.randomization_seed)
+
+ impala_connector = DbConnector(IMPALA)
+ db_connectors = []
+ if options.use_postgresql:
+ db_connectors.append(DbConnector(POSTGRESQL,
+ user_name=options.postgresql_user,
+ password=options.postgresql_password,
+ host_name=options.postgresql_host,
+ port=options.postgresql_port))
+ if options.use_mysql:
+ db_connectors.append(DbConnector(MYSQL,
+ user_name=options.mysql_user,
+ password=options.mysql_password,
+ host_name=options.mysql_host,
+ port=options.mysql_port))
+
+ populator = DatabasePopulator()
+ if command == 'populate':
+ db_connectors.append(impala_connector)
+ populator.drop_and_create_database(options.db_name, db_connectors)
+ populator.populate_db_with_random_data(
+ options.db_name,
+ db_connectors,
+ number_of_tables=options.table_count,
+ create_files=options.create_data_files)
+ else:
+ populator.drop_and_create_database(options.db_name, db_connectors)
+ if options.migrate_table_names:
+ table_names = options.migrate_table_names.split(',')
+ else:
+ table_names = None
+ populator.migrate_database(
+ options.db_name,
+ impala_connector,
+ db_connectors,
+ include_table_names=table_names)
diff --git a/tests/comparison/db_connector.py b/tests/comparison/db_connector.py
new file mode 100644
index 000000000..5e9c98aa4
--- /dev/null
+++ b/tests/comparison/db_connector.py
@@ -0,0 +1,411 @@
+# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+'''This module is intended to standardize workflows when working with various databases
+ such as Impala, Postgresql, etc. Even with pep-249 (DB API 2), workflows differ
+ slightly. For example Postgresql does not allow changing databases from within a
+ connection, instead a new connection must be made. However Impala does not allow
+ specifying a database upon connection, instead a cursor must be created and a USE
+ command must be issued.
+
+'''
+
+from contextlib import contextmanager
+try:
+ from impala.dbapi import connect as impala_connect
+except ImportError:
+ print('Error importing impyla. Please make sure it is installed. '
+ 'See the README for details.')
+ raise
+from itertools import izip
+from logging import getLogger
+from tests.comparison.model import Column, Table, TYPES, String
+
+LOG = getLogger(__name__)
+
+IMPALA = 'IMPALA'
+POSTGRESQL = 'POSTGRESQL'
+MYSQL = 'MYSQL'
+
+DATABASES = [IMPALA, POSTGRESQL, MYSQL]
+
+mysql_connect = None
+postgresql_connect = None
+
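+# A minimal usage sketch (assuming a local MySQL server; DbConnector and the
+# connection classes are defined below):
+#
+#   connector = DbConnector(MYSQL, user_name='root', host_name='localhost', port=3306)
+#   with connector.open_connection(db_name='functional') as connection:
+#     print(connection.list_table_names())
+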
+class DbConnector(object):
+ '''Wraps a DB API 2 implementation to provide a standard way of obtaining a
+ connection and selecting a database.
+
+ Any database that supports transactions will have auto-commit enabled.
+
+ '''
+
+ def __init__(self, db_type, user_name=None, password=None, host_name=None, port=None):
+ self.db_type = db_type.upper()
+ if self.db_type not in DATABASES:
+ raise Exception('Unsupported database: %s' % db_type)
+ self.user_name = user_name
+ self.password = password
+ self.host_name = host_name or 'localhost'
+ self.port = port
+
+ def create_connection(self, db_name=None):
+ if self.db_type == IMPALA:
+ connection_class = ImpalaDbConnection
+ connection = impala_connect(host=self.host_name, port=self.port or 21050)
+ elif self.db_type == POSTGRESQL:
+ connection_class = PostgresqlDbConnection
+ connection_args = {'user': self.user_name or 'postgres'}
+ if self.password:
+ connection_args['password'] = self.password
+ if db_name:
+ connection_args['database'] = db_name
+ if self.host_name:
+ connection_args['host'] = self.host_name
+ if self.port:
+ connection_args['port'] = self.port
+ global postgresql_connect
+ if not postgresql_connect:
+ try:
+ from psycopg2 import connect as postgresql_connect
+        except ImportError:
+ print('Error importing psycopg2. Please make sure it is installed. '
+ 'See the README for details.')
+ raise
+ connection = postgresql_connect(**connection_args)
+ connection.autocommit = True
+ elif self.db_type == MYSQL:
+ connection_class = MySQLDbConnection
+ connection_args = {'user': self.user_name or 'root'}
+ if self.password:
+ connection_args['passwd'] = self.password
+ if db_name:
+ connection_args['db'] = db_name
+ if self.host_name:
+ connection_args['host'] = self.host_name
+ if self.port:
+ connection_args['port'] = self.port
+ global mysql_connect
+ if not mysql_connect:
+ try:
+ from MySQLdb import connect as mysql_connect
+        except ImportError:
+ print('Error importing MySQLdb. Please make sure it is installed. '
+ 'See the README for details.')
+ raise
+ connection = mysql_connect(**connection_args)
+ else:
+ raise Exception('Unexpected database type: %s' % self.db_type)
+ return connection_class(self, connection, db_name=db_name)
+
+ @contextmanager
+ def open_connection(self, db_name=None):
+ connection = None
+ try:
+ connection = self.create_connection(db_name=db_name)
+ yield connection
+ finally:
+ if connection:
+ try:
+ connection.close()
+ except Exception as e:
+ LOG.debug('Error closing connection: %s', e, exc_info=True)
+
+
+class DbConnection(object):
+ '''Wraps a DB API 2 connection. Instances should only be obtained through the
+ DbConnector.create_connection(...) method.
+
+ '''
+
+ @staticmethod
+ def describe_common_tables(db_connections, filter_col_types=[]):
+ '''Find and return a list of Table objects that the given connections have in
+ common.
+
+ @param filter_col_types: Ignore any cols if they are of a data type contained
+ in this collection.
+
+ '''
+ common_table_names = None
+ for db_connection in db_connections:
+ table_names = set(db_connection.list_table_names())
+ if common_table_names is None:
+ common_table_names = table_names
+ else:
+ common_table_names &= table_names
+ common_table_names = sorted(common_table_names)
+
+ tables = list()
+ for table_name in common_table_names:
+ common_table = None
+ mismatch = False
+ for db_connection in db_connections:
+ table = db_connection.describe_table(table_name)
+ table.cols = [col for col in table.cols if col.type not in filter_col_types]
+ if common_table is None:
+ common_table = table
+ continue
+ if len(common_table.cols) != len(table.cols):
+ LOG.debug('Ignoring table %s.'
+ ' It has a different number of columns across databases.', table_name)
+ mismatch = True
+ break
+ for left, right in izip(common_table.cols, table.cols):
+        if left.name != right.name or left.type != right.type:
+ LOG.debug('Ignoring table %s. It has different columns %s vs %s.' %
+ (table_name, left, right))
+ mismatch = True
+ break
+ if mismatch:
+ break
+ if not mismatch:
+ tables.append(common_table)
+
+ return tables
+
+ def __init__(self, connector, connection, db_name=None):
+ self.connector = connector
+ self.connection = connection
+ self.db_name = db_name
+
+ @property
+ def db_type(self):
+ return self.connector.db_type
+
+ def create_cursor(self):
+ return DatabaseCursor(self.connection.cursor(), self)
+
+ @contextmanager
+ def open_cursor(self):
+ '''Returns a new cursor for use in a "with" statement. When the "with" statement ends,
+ the cursor will be closed.
+
+ '''
+ cursor = None
+ try:
+ cursor = self.create_cursor()
+ yield cursor
+ finally:
+ self.close_cursor_quietly(cursor)
+
+ def close_cursor_quietly(self, cursor):
+ if cursor:
+ try:
+ cursor.close()
+ except Exception as e:
+ LOG.debug('Error closing cursor: %s', e, exc_info=True)
+
+ def list_db_names(self):
+ '''Return a list of database names always in lowercase.'''
+ rows = self.execute_and_fetchall(self.make_list_db_names_sql())
+ return [row[0].lower() for row in rows]
+
+ def make_list_db_names_sql(self):
+ return 'SHOW DATABASES'
+
+ def list_table_names(self):
+ '''Return a list of table names always in lowercase.'''
+ rows = self.execute_and_fetchall(self.make_list_table_names_sql())
+ return [row[0].lower() for row in rows]
+
+ def make_list_table_names_sql(self):
+ return 'SHOW TABLES'
+
+ def describe_table(self, table_name):
+ '''Return a Table with table and col names always in lowercase.'''
+ rows = self.execute_and_fetchall(self.make_describe_table_sql(table_name))
+ table = Table(table_name.lower())
+ for row in rows:
+ col_name, data_type = row[:2]
+ table.cols.append(Column(table, col_name.lower(), self.parse_data_type(data_type)))
+ return table
+
+ def make_describe_table_sql(self, table_name):
+ return 'DESCRIBE ' + table_name
+
+ def parse_data_type(self, sql):
+ sql = sql.upper()
+ # Types may have declared a database specific alias
+ for type_ in TYPES:
+ if sql in getattr(type_, self.db_type, []):
+ return type_
+ for type_ in TYPES:
+ if type_.__name__.upper() == sql:
+ return type_
+ if 'CHAR' in sql:
+ return String
+ raise Exception('Unknown data type: ' + sql)
+
+ def create_database(self, db_name):
+ db_name = db_name.lower()
+ with self.open_cursor() as cursor:
+ cursor.execute('CREATE DATABASE ' + db_name)
+
+ def drop_db_if_exists(self, db_name):
+ '''This should not be called from a connection to the database being dropped.'''
+ db_name = db_name.lower()
+ if db_name not in self.list_db_names():
+ return
+ if self.db_name and self.db_name.lower() == db_name:
+ raise Exception('Cannot drop database while still connected to it')
+ self.drop_database(db_name)
+
+ def drop_database(self, db_name):
+ db_name = db_name.lower()
+ self.execute('DROP DATABASE ' + db_name)
+
+ @property
+ def supports_index_creation(self):
+ return True
+
+ def index_table(self, table_name):
+ table = self.describe_table(table_name)
+ with self.open_cursor() as cursor:
+ for col in table.cols:
+ index_name = '%s_%s' % (table_name, col.name)
+ if self.db_name:
+ index_name = '%s_%s' % (self.db_name, index_name)
+ cursor.execute('CREATE INDEX %s ON %s(%s)' % (index_name, table_name, col.name))
+
+ @property
+ def supports_kill_connection(self):
+ return False
+
+ def kill_connection(self):
+    '''Kill the current connection and any currently running queries associated with the
+ connection.
+ '''
+ raise Exception('Killing connection is not supported')
+
+ def materialize_query(self, query_as_text, table_name):
+ self.execute('CREATE TABLE %s AS %s' % (table_name.lower(), query_as_text))
+
+ def drop_table(self, table_name):
+ self.execute('DROP TABLE ' + table_name.lower())
+
+ def execute(self, sql):
+ with self.open_cursor() as cursor:
+ cursor.execute(sql)
+
+ def execute_and_fetchall(self, sql):
+ with self.open_cursor() as cursor:
+ cursor.execute(sql)
+ return cursor.fetchall()
+
+ def close(self):
+ '''Close the underlying connection.'''
+ self.connection.close()
+
+ def reconnect(self):
+ self.close()
+ other = self.connector.create_connection(db_name=self.db_name)
+ self.connection = other.connection
+
+
+class DatabaseCursor(object):
+ '''Wraps a DB API 2 cursor to provide access to the related connection. This class
+ implements the DB API 2 interface by delegation.
+
+ '''
+
+ def __init__(self, cursor, connection):
+ self.cursor = cursor
+ self.connection = connection
+
+ def __getattr__(self, attr):
+ return getattr(self.cursor, attr)
+
+
+class ImpalaDbConnection(DbConnection):
+
+ def create_cursor(self):
+ cursor = DbConnection.create_cursor(self)
+ if self.db_name:
+ cursor.execute('USE %s' % self.db_name)
+ return cursor
+
+ def drop_database(self, db_name):
+ '''This should not be called from a connection to the database being dropped.'''
+ db_name = db_name.lower()
+ with self.connector.open_connection(db_name) as list_tables_connection:
+ with list_tables_connection.open_cursor() as drop_table_cursor:
+ for table_name in list_tables_connection.list_table_names():
+ drop_table_cursor.execute('DROP TABLE ' + table_name)
+ self.execute('DROP DATABASE ' + db_name)
+
+ @property
+ def supports_index_creation(self):
+ return False
+
+
+class PostgresqlDbConnection(DbConnection):
+
+ def make_list_db_names_sql(self):
+ return 'SELECT datname FROM pg_database'
+
+ def make_list_table_names_sql(self):
+ return '''
+ SELECT table_name
+ FROM information_schema.tables
+ WHERE table_schema = 'public' '''
+
+ def make_describe_table_sql(self, table_name):
+ return '''
+ SELECT column_name, data_type
+ FROM information_schema.columns
+ WHERE table_name = '%s'
+ ORDER BY ordinal_position''' % table_name
+
+
+class MySQLDbConnection(DbConnection):
+
+ def __init__(self, connector, connection, db_name=None):
+ DbConnection.__init__(self, connector, connection, db_name=db_name)
+ self.session_id = self.execute_and_fetchall('SELECT connection_id()')[0][0]
+
+ def describe_table(self, table_name):
+ '''Return a Table with table and col names always in lowercase.'''
+ rows = self.execute_and_fetchall(self.make_describe_table_sql(table_name))
+ table = Table(table_name.lower())
+ for row in rows:
+ col_name, data_type = row[:2]
+ if data_type == 'tinyint(1)':
+ # Just assume this is a boolean...
+ data_type = 'boolean'
+ if '(' in data_type:
+ # Strip the size of the data type
+ data_type = data_type[:data_type.index('(')]
+ table.cols.append(Column(table, col_name.lower(), self.parse_data_type(data_type)))
+ return table
+
+ @property
+ def supports_kill_connection(self):
+ return True
+
+ def kill_connection(self):
+ with self.connector.open_connection(db_name=self.db_name) as connection:
+ connection.execute('KILL %s' % (self.session_id))
+
+ def index_table(self, table_name):
+ table = self.describe_table(table_name)
+ with self.open_cursor() as cursor:
+ for col in table.cols:
+ try:
+ cursor.execute('ALTER TABLE %s ADD INDEX (%s)' % (table_name, col.name))
+ except Exception as e:
+ if 'Incorrect index name' not in str(e):
+ raise
+ # Some sort of MySQL bug...
+ LOG.warn('Could not create index on %s.%s: %s' % (table_name, col.name, e))
diff --git a/tests/comparison/discrepancy_searcher.py b/tests/comparison/discrepancy_searcher.py
new file mode 100755
index 000000000..1ebfb3fd9
--- /dev/null
+++ b/tests/comparison/discrepancy_searcher.py
@@ -0,0 +1,529 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+'''This module will run random queries against existing databases and compare the
+ results.
+
+'''
+
+from contextlib import closing
+from decimal import Decimal
+from itertools import izip, izip_longest
+from logging import basicConfig, getLogger
+from math import isinf, isnan
+from os import getenv, remove
+from os.path import exists, join
+from shelve import open as open_shelve
+from subprocess import call
+from threading import current_thread, Thread
+from tempfile import gettempdir
+from time import time
+
+from tests.comparison.db_connector import (
+ DbConnection,
+ DbConnector,
+ IMPALA,
+ MYSQL,
+ POSTGRESQL)
+from tests.comparison.model import BigInt, TYPES
+from tests.comparison.query_generator import QueryGenerator
+from tests.comparison.model_translator import SqlWriter
+
+LOG = getLogger(__name__)
+
+class QueryResultComparator(object):
+
+ # If the number of rows * cols is greater than this val, then the comparison will
+ # be aborted. Raising this value also raises the risk of python being OOM killed. At
+ # 10M python would get OOM killed occasionally even on a physical machine with 32GB
+ # ram.
+ TOO_MUCH_DATA = 1000 * 1000
+
+ # Used when comparing float vals
+ EPSILON = 0.1
+
+ # The decimal vals will be rounded before comparison
+ DECIMAL_PLACES = 2
+
+ def __init__(self, impala_connection, reference_connection):
+ self.reference_db_type = reference_connection.db_type
+
+ self.impala_cursor = impala_connection.create_cursor()
+ self.reference_cursor = reference_connection.create_cursor()
+
+ self.impala_sql_writer = SqlWriter.create(dialect=impala_connection.db_type)
+ self.reference_sql_writer = SqlWriter.create(dialect=reference_connection.db_type)
+
+    # If a query runs longer than this, the connection will be killed and the
+    # comparison result will be a timeout.
+ self.query_timeout_seconds = 3 * 60
+
+ def compare_query_results(self, query):
+ '''Execute the query, compare the data, and return a summary of the result.'''
+ comparison_result = ComparisonResult(query, self.reference_db_type)
+
+ reference_data_set = None
+ impala_data_set = None
+ # Impala doesn't support getting the row count without getting the rows too. So run
+ # the query on the other database first.
+ try:
+ for sql_writer, cursor in ((self.reference_sql_writer, self.reference_cursor),
+ (self.impala_sql_writer, self.impala_cursor)):
+ self.execute_query(cursor, sql_writer.write_query(query))
+ if (cursor.rowcount * len(query.select_clause.select_items)) > self.TOO_MUCH_DATA:
+ comparison_result.exception = Exception('Too much data to compare')
+ return comparison_result
+ if reference_data_set is None:
+ # MySQL returns a tuple of rows but a list is needed for sorting
+ reference_data_set = list(cursor.fetchall())
+ comparison_result.reference_row_count = len(reference_data_set)
+ else:
+ impala_data_set = cursor.fetchall()
+ comparison_result.impala_row_count = len(impala_data_set)
+ except Exception as e:
+ comparison_result.exception = e
+ LOG.debug('Error running query: %s', e, exc_info=True)
+ return comparison_result
+
+ comparison_result.query_resulted_in_data = (comparison_result.impala_row_count > 0
+ or comparison_result.reference_row_count > 0)
+
+ if comparison_result.impala_row_count != comparison_result.reference_row_count:
+ return comparison_result
+
+ for data_set in (reference_data_set, impala_data_set):
+ for row_idx, row in enumerate(data_set):
+ data_set[row_idx] = [self.standardize_data(data) for data in row]
+ data_set.sort(cmp=self.row_sort_cmp)
+
+ for impala_row, reference_row in \
+ izip_longest(impala_data_set, reference_data_set):
+ for col_idx, (impala_val, reference_val) \
+ in enumerate(izip_longest(impala_row, reference_row)):
+ if not self.vals_are_equal(impala_val, reference_val):
+ if isinstance(impala_val, int) \
+ and isinstance(reference_val, (int, float, Decimal)) \
+ and abs(reference_val) > BigInt.MAX:
+ # Impala will return incorrect results if the val is greater than max bigint
+ comparison_result.exception = KnownError(
+ 'https://issues.cloudera.org/browse/IMPALA-865')
+ elif isinstance(impala_val, float) \
+ and (isinf(impala_val) or isnan(impala_val)):
+ # In some cases, Impala gives NaNs and Infs instead of NULLs
+ comparison_result.exception = KnownError(
+ 'https://issues.cloudera.org/browse/IMPALA-724')
+ comparison_result.impala_row = impala_row
+ comparison_result.reference_row = reference_row
+ comparison_result.mismatch_at_row_number = row_idx + 1
+ comparison_result.mismatch_at_col_number = col_idx + 1
+ return comparison_result
+
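+    # Treat a single row containing only falsy vals (such as an all-NULL row from
+    # an aggregate over an empty selection) as producing no meaningful data.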
+ if len(impala_data_set) == 1:
+ for val in impala_data_set[0]:
+ if val:
+ break
+ else:
+ comparison_result.query_resulted_in_data = False
+
+ return comparison_result
+
+ def execute_query(self, cursor, sql):
+    '''Execute the query, raising QueryTimeout if it does not finish within
+       self.query_timeout_seconds. On timeout the cursor and connection are replaced.'''
+ def _execute_query():
+ try:
+ cursor.execute(sql)
+ except Exception as e:
+ current_thread().exception = e
+ query_thread = Thread(target=_execute_query, name='Query execution thread')
+ query_thread.daemon = True
+ query_thread.start()
+ query_thread.join(self.query_timeout_seconds)
+ if query_thread.is_alive():
+ if cursor.connection.supports_kill_connection:
+ LOG.debug('Attempting to kill connection')
+ cursor.connection.kill_connection()
+        LOG.debug('Killed connection')
+ cursor.close()
+ cursor.connection.close()
+ cursor = cursor\
+ .connection\
+ .connector\
+ .create_connection(db_name=cursor.connection.db_name)\
+ .create_cursor()
+ if cursor.connection.db_type == IMPALA:
+ self.impala_cursor = cursor
+ else:
+ self.reference_cursor = cursor
+ raise QueryTimeout('Query timed out after %s seconds' % self.query_timeout_seconds)
+ if hasattr(query_thread, 'exception'):
+ raise query_thread.exception
+
+ def standardize_data(self, data):
+ '''Return a val that is suitable for comparison.'''
+ # For float data we need to round otherwise differences in precision will cause errors
+ if isinstance(data, (float, Decimal)):
+ return round(data, self.DECIMAL_PLACES)
+ return data
+
+ def row_sort_cmp(self, left_row, right_row):
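+    '''Comparator passed to list.sort(); sorts None (NULL) before any other val so
+       both result sets end up in the same order before comparison.'''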
+ for left, right in izip(left_row, right_row):
+ if left is None and right is not None:
+ return -1
+ if left is not None and right is None:
+ return 1
+ result = cmp(left, right)
+ if result:
+ return result
+ return 0
+
+ def vals_are_equal(self, left, right):
+ if left == right:
+ return True
+ if isinstance(left, (int, float, Decimal)) and \
+ isinstance(right, (int, float, Decimal)):
+ return self.floats_are_equal(left, right)
+ return False
+
+ def floats_are_equal(self, left, right):
+ left = round(left, self.DECIMAL_PLACES)
+ right = round(right, self.DECIMAL_PLACES)
+ diff = abs(left - right)
+ if left * right == 0:
+ return diff < self.EPSILON
+ return diff / (abs(left) + abs(right)) < self.EPSILON
+
+
+class ComparisonResult(object):
+
+ def __init__(self, query, reference_db_type):
+ self.query = query
+ self.reference_db_type = reference_db_type
+ self.query_resulted_in_data = False
+ self.impala_row_count = None
+ self.reference_row_count = None
+ self.mismatch_at_row_number = None
+ self.mismatch_at_col_number = None
+ self.impala_row = None
+ self.reference_row = None
+ self.exception = None
+ self._error_message = None
+
+ @property
+ def error(self):
+ if not self._error_message:
+ if self.exception:
+ self._error_message = str(self.exception)
+ elif self.impala_row_count and \
+ self.impala_row_count != self.reference_row_count:
+ self._error_message = 'Row counts do not match: %s Impala rows vs %s %s rows' \
+            % (self.impala_row_count,
+               self.reference_row_count,
+               self.reference_db_type)
+ elif self.mismatch_at_row_number is not None:
+        # Write a row like "[a, b, <<c>>, d]" where c is a bad value
+ impala_row = '[' + ', '.join(
+ '<<' + str(val) + '>>' if idx == self.mismatch_at_col_number - 1 else str(val)
+ for idx, val in enumerate(self.impala_row)
+ ) + ']'
+ reference_row = '[' + ', '.join(
+ '<<' + str(val) + '>>' if idx == self.mismatch_at_col_number - 1 else str(val)
+ for idx, val in enumerate(self.reference_row)
+ ) + ']'
+ self._error_message = \
+ 'Column %s in row %s does not match: %s Impala row vs %s %s row' \
+ % (self.mismatch_at_col_number,
+ self.mismatch_at_row_number,
+ impala_row,
+ reference_row,
+ self.reference_db_type)
+ return self._error_message
+
+ @property
+ def is_known_error(self):
+ return isinstance(self.exception, KnownError)
+
+ @property
+ def query_timed_out(self):
+ return isinstance(self.exception, QueryTimeout)
+
+
+class QueryTimeout(Exception):
+ pass
+
+
+class KnownError(Exception):
+
+ def __init__(self, jira_url):
+ Exception.__init__(self, 'Known issue: ' + jira_url)
+ self.jira_url = jira_url
+
+
+class QueryResultDiffSearcher(object):
+
+ # Sometimes things get into a bad state and the same error loops forever
+ ABORT_ON_REPEAT_ERROR_COUNT = 2
+
+ def __init__(self, impala_connection, reference_connection, filter_col_types=[]):
+ self.impala_connection = impala_connection
+ self.reference_connection = reference_connection
+ self.common_tables = DbConnection.describe_common_tables(
+ [impala_connection, reference_connection],
+ filter_col_types=filter_col_types)
+
+    # A file-backed dict of queries that produced a discrepancy, keyed by query number
+    # (in string form, as required by shelve).
+ self.query_shelve_path = gettempdir() + '/query.shelve'
+
+ # A list of all queries attempted
+ self.query_log_path = gettempdir() + '/impala_query_log.sql'
+
+ def search(self, number_of_test_queries, stop_on_result_mismatch, stop_on_crash):
+ if exists(self.query_shelve_path):
+ # Ensure a clean shelve will be created
+ remove(self.query_shelve_path)
+
+ start_time = time()
+ impala_sql_writer = SqlWriter.create(dialect=IMPALA)
+ reference_sql_writer = SqlWriter.create(
+ dialect=self.reference_connection.db_type)
+ query_result_comparator = QueryResultComparator(
+ self.impala_connection, self.reference_connection)
+ query_generator = QueryGenerator()
+ query_count = 0
+ queries_resulted_in_data_count = 0
+ mismatch_count = 0
+ query_timeout_count = 0
+ known_error_count = 0
+ impala_crash_count = 0
+ last_error = None
+ repeat_error_count = 0
+ with open(self.query_log_path, 'w') as impala_query_log:
+ impala_query_log.write(
+ '--\n'
+          '-- Starting new run\n'
+ '--\n')
+ while number_of_test_queries > query_count:
+ query = query_generator.create_query(self.common_tables)
+ impala_sql = impala_sql_writer.write_query(query)
+ if 'FULL OUTER JOIN' in impala_sql and self.reference_connection.db_type == MYSQL:
+ # Not supported by MySQL
+ continue
+
+ query_count += 1
+ LOG.info('Running query #%s', query_count)
+ impala_query_log.write(impala_sql + ';\n')
+ result = query_result_comparator.compare_query_results(query)
+ if result.query_resulted_in_data:
+ queries_resulted_in_data_count += 1
+ if result.error:
+ # TODO: These first two come from psycopg2, the postgres driver. Maybe we should
+ # try a different driver? Or maybe the usage of the driver isn't correct.
+ # Anyhow ignore these failures.
+ if 'division by zero' in result.error \
+ or 'out of range' in result.error \
+ or 'Too much data' in result.error:
+ LOG.debug('Ignoring error: %s', result.error)
+ query_count -= 1
+ continue
+
+ if result.is_known_error:
+ known_error_count += 1
+ elif result.query_timed_out:
+ query_timeout_count += 1
+ else:
+ mismatch_count += 1
+ with closing(open_shelve(self.query_shelve_path)) as query_shelve:
+ query_shelve[str(query_count)] = query
+
+ print('---Impala Query---\n')
+ print(impala_sql_writer.write_query(query, pretty=True) + '\n')
+ print('---Reference Query---\n')
+ print(reference_sql_writer.write_query(query, pretty=True) + '\n')
+ print('---Error---\n')
+ print(result.error + '\n')
+ print('------\n')
+
+ if 'Could not connect' in result.error \
+ or "Couldn't open transport for" in result.error:
+ # if stop_on_crash:
+ # break
+ # Assume Impala crashed and try restarting
+ impala_crash_count += 1
+ LOG.info('Restarting Impala')
+ call([join(getenv('IMPALA_HOME'), 'bin/start-impala-cluster.py'),
+ '--log_dir=%s' % getenv('LOG_DIR', "/tmp/")])
+ self.impala_connection.reconnect()
+ query_result_comparator.impala_cursor = self.impala_connection.create_cursor()
+ result = query_result_comparator.compare_query_results(query)
+ if result.error:
+ LOG.info('Restarting Impala')
+ call([join(getenv('IMPALA_HOME'), 'bin/start-impala-cluster.py'),
+ '--log_dir=%s' % getenv('LOG_DIR', "/tmp/")])
+ self.impala_connection.reconnect()
+ query_result_comparator.impala_cursor = self.impala_connection.create_cursor()
+ else:
+ break
+
+ if stop_on_result_mismatch and \
+ not (result.is_known_error or result.query_timed_out):
+ break
+
+ if last_error == result.error \
+ and not (result.is_known_error or result.query_timed_out):
+ repeat_error_count += 1
+ if repeat_error_count == self.ABORT_ON_REPEAT_ERROR_COUNT:
+ break
+ else:
+ last_error = result.error
+ repeat_error_count = 0
+ else:
+ if result.query_resulted_in_data:
+ LOG.info('Results matched (%s rows)', result.impala_row_count)
+ else:
+ LOG.info('Query did not produce meaningful data')
+ last_error = None
+ repeat_error_count = 0
+
+ return SearchResults(
+ query_count,
+ queries_resulted_in_data_count,
+ mismatch_count,
+ query_timeout_count,
+ known_error_count,
+ impala_crash_count,
+ time() - start_time)
+
+
+class SearchResults(object):
+ '''This class holds information about the outcome of a search run.'''
+
+ def __init__(self,
+ query_count,
+ queries_resulted_in_data_count,
+ mismatch_count,
+ query_timeout_count,
+ known_error_count,
+ impala_crash_count,
+ run_time_in_seconds):
+ # Approx number of queries run, some queries may have been ignored
+ self.query_count = query_count
+ self.queries_resulted_in_data_count = queries_resulted_in_data_count
+ # Number of queries that had an error or result mismatch
+ self.mismatch_count = mismatch_count
+ self.query_timeout_count = query_timeout_count
+ self.known_error_count = known_error_count
+ self.impala_crash_count = impala_crash_count
+ self.run_time_in_seconds = run_time_in_seconds
+
+ def get_summary_text(self):
+ mins, secs = divmod(self.run_time_in_seconds, 60)
+ hours, mins = divmod(mins, 60)
+ hours = int(hours)
+ mins = int(mins)
+ if hours:
+      run_time = '%s hours and %s minutes' % (hours, mins)
+ else:
+ secs = int(secs)
+ run_time = '%s seconds' % secs
+ if mins:
+ run_time = '%s mins and ' % mins + run_time
+ summary_params = self.__dict__
+ summary_params['run_time'] = run_time
+ return (
+ '%(mismatch_count)s mismatches found after running %(query_count)s queries in '
+ '%(run_time)s.\n'
+ '%(queries_resulted_in_data_count)s of %(query_count)s queries produced results.'
+ '\n'
+ '%(impala_crash_count)s Impala crashes occurred.\n'
+ '%(known_error_count)s queries were excluded from the mismatch count because '
+ 'they are known errors.\n'
+ '%(query_timeout_count)s queries timed out and were excluded from all counts.') \
+ % summary_params
+
+
+if __name__ == '__main__':
+ import sys
+ from optparse import NO_DEFAULT, OptionGroup, OptionParser
+
+ parser = OptionParser()
+ parser.add_option('--log-level', default='INFO',
+ help='The log level to use.', choices=('DEBUG', 'INFO', 'WARN', 'ERROR'))
+ parser.add_option('--db-name', default='randomness',
+      help='The name of the database to use. Ex: functional.')
+
+ parser.add_option('--reference-db-type', default=MYSQL, choices=(MYSQL, POSTGRESQL),
+ help='The type of the reference database to use. Ex: MYSQL.')
+ parser.add_option('--stop-on-mismatch', default=False, action='store_true',
+      help='Exit immediately upon finding a discrepancy in a query result.')
+ parser.add_option('--stop-on-crash', default=False, action='store_true',
+ help='Exit immediately if Impala crashes.')
+ parser.add_option('--query-count', default=1000, type=int,
+ help='Exit after running the given number of queries.')
+ parser.add_option('--exclude-types', default='Double,Float,TinyInt',
+ help='A comma separated list of data types to exclude while generating queries.')
+
+ group = OptionGroup(parser, 'MySQL Options')
+ group.add_option('--mysql-host', default='localhost',
+ help='The name of the host running the MySQL database.')
+ group.add_option('--mysql-port', default=3306, type=int,
+ help='The port of the host running the MySQL database.')
+ group.add_option('--mysql-user', default='root',
+ help='The user name to use when connecting to the MySQL database.')
+ group.add_option('--mysql-password',
+ help='The password to use when connecting to the MySQL database.')
+ parser.add_option_group(group)
+
+ group = OptionGroup(parser, 'Postgresql Options')
+ group.add_option('--postgresql-host', default='localhost',
+ help='The name of the host running the Postgresql database.')
+ group.add_option('--postgresql-port', default=5432, type=int,
+ help='The port of the host running the Postgresql database.')
+ group.add_option('--postgresql-user', default='postgres',
+ help='The user name to use when connecting to the Postgresql database.')
+ group.add_option('--postgresql-password',
+ help='The password to use when connecting to the Postgresql database.')
+ parser.add_option_group(group)
+
+ for group in parser.option_groups + [parser]:
+ for option in group.option_list:
+ if option.default != NO_DEFAULT:
+ option.help += " [default: %default]"
+
+ options, args = parser.parse_args()
+
+ basicConfig(level=options.log_level)
+
+ impala_connection = DbConnector(IMPALA).create_connection(options.db_name)
+ db_connector_param_key = options.reference_db_type.lower()
+ reference_connection = DbConnector(options.reference_db_type,
+ user_name=getattr(options, db_connector_param_key + '_user'),
+ password=getattr(options, db_connector_param_key + '_password'),
+ host_name=getattr(options, db_connector_param_key + '_host'),
+ port=getattr(options, db_connector_param_key + '_port')) \
+ .create_connection(options.db_name)
+ if options.exclude_types:
+ exclude_types = set(type_name.lower() for type_name
+ in options.exclude_types.split(','))
+ filter_col_types = [type_ for type_ in TYPES
+ if type_.__name__.lower() in exclude_types]
+ else:
+ filter_col_types = []
+ diff_searcher = QueryResultDiffSearcher(
+ impala_connection, reference_connection, filter_col_types=filter_col_types)
+ search_results = diff_searcher.search(
+ options.query_count, options.stop_on_mismatch, options.stop_on_crash)
+ print(search_results.get_summary_text())
+ sys.exit(search_results.mismatch_count)
diff --git a/tests/comparison/model.py b/tests/comparison/model.py
new file mode 100644
index 000000000..7249a2017
--- /dev/null
+++ b/tests/comparison/model.py
@@ -0,0 +1,741 @@
+# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+
+class Query(object):
+  '''A representation of the structure of a SQL query. Only the select_clause and
+ from_clause are required for a valid query.
+
+ '''
+
+ def __init__(self, select_clause, from_clause):
+ self.with_clause = None
+ self.select_clause = select_clause
+ self.from_clause = from_clause
+ self.where_clause = None
+ self.group_by_clause = None
+ self.having_clause = None
+ self.union_clause = None
+
+ @property
+ def table_exprs(self):
+ '''Provides a list of all table_exprs that are declared by this query. This
+ includes table_exprs in the WITH and FROM sections.
+ '''
+ table_exprs = self.from_clause.table_exprs
+ if self.with_clause:
+ table_exprs += self.with_clause.table_exprs
+ return table_exprs
+
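+# A hand-assembled example of the model, as a sketch (Table, Column, and Int are
+# defined further down in this module):
+#
+#   table = Table('table_1')
+#   int_col = Column(table, 'int_col_1', Int)
+#   query = Query(SelectClause(non_agg_items=[SelectItem(int_col)]),
+#                 FromClause(table))
+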
+
+class SelectClause(object):
+  '''This encapsulates the SELECT part of a query. It is convenient to separate
+ non-agg items from agg items so that it is simple to know if the query
+ is an agg query or not.
+
+ '''
+
+ def __init__(self, non_agg_items=None, agg_items=None):
+ self.non_agg_items = non_agg_items or list()
+ self.agg_items = agg_items or list()
+ self.distinct = False
+
+ @property
+ def select_items(self):
+ '''Provides a consolidated view of all select items.'''
+ return self.non_agg_items + self.agg_items
+
+
+class SelectItem(object):
+  '''A representation of any possible expr that would be valid in
+
+     SELECT <expr> [, ...] FROM ...
+
+     Each SelectItem contains a ValExpr which will either be an instance of a
+ DataType (representing a constant), a Column, or a Func.
+
+ Ex: "SELECT int_col + smallint_col FROM alltypes" would have a val_expr of
+ Plus(Column(), Column()).
+
+ '''
+
+ def __init__(self, val_expr, alias=None):
+ self.val_expr = val_expr
+ self.alias = alias
+
+ @property
+ def type(self):
+ '''Returns the DataType of this item.'''
+ return self.val_expr.type
+
+ @property
+ def is_agg(self):
+ '''Evaluates to True if this item contains an aggregate expression.'''
+ return self.val_expr.is_agg
+
+
+class ValExpr(object):
+ '''This is an AbstractClass that represents a generic expr that results in a
+ scalar. The abc module was not used because it caused problems for the pickle
+ module.
+
+ '''
+
+ @property
+ def type(self):
+ '''This is declared for documentation purposes; subclasses should override it to
+ return the DataType that this expr represents.
+ '''
+ pass
+
+ @property
+ def base_type(self):
+ '''Return the most fundamental data type that the expr evaluates to. Only
+ numeric types will result in a different val than would be returned by self.type.
+
+ Ex:
+ if self.type == BigInt:
+ assert self.base_type == Int
+ if self.type == Double:
+ assert self.base_type == Float
+ if self.type == String:
+ assert self.base_type == self.type
+ '''
+ if self.returns_int:
+ return Int
+ if self.returns_float:
+ return Float
+ return self.type
+
+ @property
+ def is_func(self):
+ return isinstance(self, Func)
+
+ @property
+ def is_agg(self):
+ '''Evaluates to True if this expression contains an aggregate function.'''
+ if isinstance(self, AggFunc):
+ return True
+ if self.is_func:
+ for arg in self.args:
+ if arg.is_agg:
+ return True
+ return False
+
+ @property
+ def is_col(self):
+ return isinstance(self, Column)
+
+ @property
+ def is_constant(self):
+ return isinstance(self, DataType)
+
+ @property
+ def returns_boolean(self):
+ return issubclass(self.type, Boolean)
+
+ @property
+ def returns_number(self):
+ return issubclass(self.type, Number)
+
+ @property
+ def returns_int(self):
+ return issubclass(self.type, Int)
+
+ @property
+ def returns_float(self):
+ return issubclass(self.type, Float)
+
+ @property
+ def returns_string(self):
+ return issubclass(self.type, String)
+
+ @property
+ def returns_timestamp(self):
+ return issubclass(self.type, Timestamp)
+
+
+class Column(ValExpr):
+ '''A representation of a col. All TableExprs will have Columns. So a Column
+ may belong to an InlineView as well as a standard Table.
+
+ This class is used in two ways:
+
+ 1) As a piece of metadata in a table definition. In this usage the col isn't
+ intended to represent a val.
+
+ 2) As an expr in a query, for example an item being selected or as part of
+ a join condition. In this usage the col is more like a val, which is why
+ it implements/extends ValExpr.
+
+ '''
+
+ def __init__(self, owner, name, type_):
+ self.owner = owner
+ self.name = name
+ self._type = type_
+
+ @property
+ def type(self):
+ return self._type
+
+ def __hash__(self):
+ return hash(self.name)
+
+ def __eq__(self, other):
+ if not isinstance(other, Column):
+ return False
+ if self is other:
+ return True
+ return self.name == other.name and self.owner.identifier == other.owner.identifier
+
+ def __repr__(self):
+ return '%s<name: %s, type: %s>' % (
+ type(self).__name__, self.name, self._type.__name__)
+
+
+class FromClause(object):
+ '''A representation of a FROM clause. The member variable join_clauses may optionally
+ contain JoinClause items.
+
+ '''
+
+ def __init__(self, table_expr, join_clauses=None):
+ self.table_expr = table_expr
+ self.join_clauses = join_clauses or list()
+
+ @property
+ def table_exprs(self):
+ '''Provides a list of all table_exprs that are declared within this FROM
+ block.
+ '''
+ table_exprs = [join_clause.table_expr for join_clause in self.join_clauses]
+ table_exprs.append(self.table_expr)
+ return table_exprs
+
+
+class TableExpr(object):
+ '''An abstract class that represents something a query may use to select
+ from or join on. The abc module was not used because it caused problems for the
+ pickle module.
+
+ '''
+
+ @property
+ def identifier(self):
+ '''Returns either a table name or alias if one has been declared.'''
+ pass
+
+ @property
+ def cols(self):
+ '''Returns the list of Columns provided by this table expr.'''
+ pass
+
+ @property
+ def cols_by_base_type(self):
+ '''Group cols by their basic data type and return a dict of the results.
+
+ As an example, a "BigInt" would be considered as an "Int".
+ '''
+ return DataType.group_by_base_type(self.cols)
+
+ @property
+ def is_table(self):
+ return isinstance(self, Table)
+
+ @property
+ def is_inline_view(self):
+ return isinstance(self, InlineView)
+
+ @property
+ def is_with_clause_inline_view(self):
+ return isinstance(self, WithClauseInlineView)
+
+ def __eq__(self, other):
+ if not isinstance(other, type(self)):
+ return False
+ return self.identifier == other.identifier
+
+
+class Table(TableExpr):
+ '''Represents a standard database table.'''
+
+ def __init__(self, name):
+ self.name = name
+ self._cols = []
+ self.alias = None
+
+ @property
+ def identifier(self):
+ return self.alias or self.name
+
+ @property
+ def cols(self):
+ return self._cols
+
+ @cols.setter
+ def cols(self, cols):
+ self._cols = cols
+
+
+class InlineView(TableExpr):
+ '''Represents an inline view.
+
+ Ex: In the query "SELECT * FROM (SELECT * FROM foo) AS bar",
+ "(SELECT * FROM foo) AS bar" would be an inline view.
+
+ '''
+
+ def __init__(self, query):
+ self.query = query
+ self.alias = None
+
+ @property
+ def identifier(self):
+ return self.alias
+
+ @property
+ def cols(self):
+ return [Column(self, item.alias, item.type) for item in
+ self.query.select_clause.non_agg_items + self.query.select_clause.agg_items]
+
+
+class WithClause(object):
+ '''Represents a WITH clause.
+
+ Ex: In the query "WITH bar AS (SELECT * FROM foo) SELECT * FROM bar",
+ "WITH bar AS (SELECT * FROM foo)" would be the with clause.
+
+ '''
+
+ def __init__(self, with_clause_inline_views):
+ self.with_clause_inline_views = with_clause_inline_views
+
+ @property
+ def table_exprs(self):
+ return self.with_clause_inline_views
+
+
+class WithClauseInlineView(InlineView):
+ '''Represents the entries in a WITH clause. These are very similar to InlineViews but
+ may have an additional alias.
+
+ Ex: WITH bar AS (SELECT * FROM foo)
+ SELECT *
+ FROM bar as r
+ JOIN (SELECT * FROM baz) AS z ON ...
+
+ The WithClauseInlineView has aliases "bar" and "r" while the InlineView has
+ only the alias "z".
+
+ '''
+
+ def __init__(self, query, with_clause_alias):
+ self.query = query
+ self.with_clause_alias = with_clause_alias
+ self.alias = None
+
+ @property
+ def identifier(self):
+ return self.alias or self.with_clause_alias
+
+
+class JoinClause(object):
+ '''A representation of a JOIN clause.
+
+ Ex: SELECT * FROM foo JOIN <table_expr> [ON <boolean_expr>]
+
+ The member variable boolean_expr will be an instance of a boolean func
+ defined below.
+
+ '''
+
+ JOINS_TYPES = ['INNER', 'LEFT', 'RIGHT', 'FULL OUTER', 'CROSS']
+
+ def __init__(self, join_type, table_expr, boolean_expr=None):
+ self.join_type = join_type
+ self.table_expr = table_expr
+ self.boolean_expr = boolean_expr
+
+
+class WhereClause(object):
+ '''The member variable boolean_expr will be an instance of a boolean func
+ defined below.
+
+ '''
+
+ def __init__(self, boolean_expr):
+ self.boolean_expr = boolean_expr
+
+
+class GroupByClause(object):
+
+ def __init__(self, select_items):
+ self.group_by_items = select_items
+
+
+class HavingClause(object):
+ '''The member variable boolean_expr will be an instance of a boolean func
+ defined below.
+
+ '''
+
+ def __init__(self, boolean_expr):
+ self.boolean_expr = boolean_expr
+
+
+class UnionClause(object):
+ '''A representation of a UNION clause.
+
+ If the member variable "all" is True, the instance represents a "UNION ALL".
+
+ '''
+
+ def __init__(self, query):
+ self.query = query
+ self.all = False
+
+ @property
+ def queries(self):
+ queries = list()
+ query = self.query
+ while True:
+ queries.append(query)
+ if not query.union_clause:
+ break
+ query = query.union_clause.query
+ return queries
+
+
+class DataTypeMetaclass(type):
+ '''Provides sorting of classes used to determine upcasting.'''
+
+ def __cmp__(cls, other):
+ return cmp(
+ getattr(cls, 'CMP_VALUE', cls.__name__),
+ getattr(other, 'CMP_VALUE', other.__name__))
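+
+# Ex: because of CMP_VALUE, sorting numeric type classes picks the "wider" type:
+# max([TinyInt, Int, BigInt]) == BigInt. UpcastingFunc.type below relies on this.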
+
+
+class DataType(ValExpr):
+ '''Base class for data types.
+
+ Data types are represented as classes so inheritance can be used.
+
+ '''
+
+ __metaclass__ = DataTypeMetaclass
+
+ def __init__(self, val):
+ self.val = val
+
+ @property
+ def type(self):
+ return type(self)
+
+ @staticmethod
+ def group_by_base_type(vals):
+ '''Group vals by their base data type and return a dict of the results.
+
+ As an example, a "BigInt" would be grouped under "Int".
+ '''
+ vals_by_type = defaultdict(list)
+ for val in vals:
+ type_ = val.type
+ if issubclass(type_, Int):
+ type_ = Int
+ elif issubclass(type_, Float):
+ type_ = Float
+ vals_by_type[type_].append(val)
+ return vals_by_type
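+
+# Ex: DataType.group_by_base_type([Int(1), BigInt(2), String('a')]) returns
+# {Int: [Int(1), BigInt(2)], String: [String('a')]}.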
+
+
+class Boolean(DataType):
+ pass
+
+
+class Number(DataType):
+ pass
+
+
+class Int(Number):
+
+ # Used to compare with other numbers for determining upcasting
+ CMP_VALUE = 2
+
+ # Used during data generation to keep vals in range
+ MIN = -2 ** 31
+ MAX = -MIN - 1
+
+ # Aliases used when reading and writing table definitions
+ POSTGRESQL = ['INTEGER']
+
+
+class TinyInt(Int):
+
+ CMP_VALUE = 0
+
+ MIN = -2 ** 7
+ MAX = -MIN - 1
+
+ POSTGRESQL = ['SMALLINT']
+
+
+class SmallInt(Int):
+
+ CMP_VALUE = 1
+
+ MIN = -2 ** 15
+ MAX = -MIN - 1
+
+
+class BigInt(Int):
+
+ CMP_VALUE = 3
+
+ MIN = -2 ** 63
+ MAX = -MIN - 1
+
+
+class Float(Number):
+
+ CMP_VALUE = 4
+
+ POSTGRESQL = ['REAL']
+
+
+class Double(Float):
+
+ CMP_VALUE = 5
+
+ MYSQL = ['DOUBLE', 'DECIMAL'] # Use double by default but add decimal synonym
+ POSTGRESQL = ['DOUBLE PRECISION']
+
+
+class String(DataType):
+
+ MIN = 0
+ # The Impala limit is 32,767 but MySQL has a row size limit of 65,535. To allow 3+
+ # String cols per table, the limit will be lowered to 1,000. That should be fine
+ # for testing anyhow.
+ MAX = 1000
+
+ MYSQL = ['VARCHAR(%s)' % MAX]
+ POSTGRESQL = MYSQL + ['CHARACTER VARYING']
+
+
+class Timestamp(DataType):
+
+ MYSQL = ['DATETIME']
+ POSTGRESQL = ['TIMESTAMP WITHOUT TIME ZONE']
+
+
+NUMBER_TYPES = [Int, TinyInt, SmallInt, BigInt, Float, Double]
+TYPES = NUMBER_TYPES + [Boolean, String, Timestamp]
+
+class Func(ValExpr):
+ '''Base class for funcs'''
+
+ def __init__(self, *args):
+ self.args = list(args)
+
+ def __hash__(self):
+ return hash(type(self)) + hash(tuple(self.args))
+
+ def __eq__(self, other):
+ if not isinstance(other, type(self)):
+ return False
+ if self is other:
+ return True
+ return self.args == other.args
+
+
+class UnaryFunc(Func):
+
+ def __init__(self, arg):
+ Func.__init__(self, arg)
+
+
+class BinaryFunc(Func):
+
+ def __init__(self, left, right):
+ Func.__init__(self, left, right)
+
+ @property
+ def left(self):
+ return self.args[0]
+
+ @left.setter
+ def left(self, left):
+ self.args[0] = left
+
+ @property
+ def right(self):
+ return self.args[1]
+
+ @right.setter
+ def right(self, right):
+ self.args[1] = right
+
+
+class BooleanFunc(Func):
+
+ @property
+ def type(self):
+ return Boolean
+
+
+class IntFunc(Func):
+
+ @property
+ def type(self):
+ return Int
+
+
+class DoubleFunc(Func):
+
+ @property
+ def type(self):
+ return Double
+
+
+class StringFunc(Func):
+
+ @property
+ def type(self):
+ return String
+
+
+class UpcastingFunc(Func):
+
+ @property
+ def type(self):
+ return max(arg.type for arg in self.args)
+
+
+class AggFunc(Func):
+
+ # Avoid having a self.distinct because it would need to be __init__'d explicitly,
+ # which none of the AggFunc subclasses do (ex: Avg doesn't have its
+ # own __init__).
+
+ @property
+ def distinct(self):
+ return getattr(self, '_distinct', False)
+
+ @distinct.setter
+ def distinct(self, val):
+ return setattr(self, '_distinct', val)
+
+# The classes below diverge from those above by including the SQL representation
+# (FORMAT). It's a lot easier this way because there are a lot of funcs and they all
+# have the same structure. Non-standard funcs, such as string concatenation, need to
+# have their representation information elsewhere (like the classes above).
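+#
+# Ex: these classes render through SqlWriter._write_func (model_translator.py),
+# which fills FORMAT with the written args (col names illustrative):
+#   Plus(col_a, col_b) -> 't1.a + t1.b'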
+
+Parentheses = type('Parentheses', (UnaryFunc, UpcastingFunc), {'FORMAT': '({})'})
+
+IsNull = type('IsNull', (UnaryFunc, BooleanFunc), {'FORMAT': '{} IS NULL'})
+IsNotNull = type('IsNotNull', (UnaryFunc, BooleanFunc), {'FORMAT': '{} IS NOT NULL'})
+And = type('And', (BinaryFunc, BooleanFunc), {'FORMAT': '{} AND {}'})
+Or = type('Or', (BinaryFunc, BooleanFunc), {'FORMAT': '{} OR {}'})
+Equals = type('Equals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} = {}'})
+NotEquals = type('NotEquals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} != {}'})
+GreaterThan = type('GreaterThan', (BinaryFunc, BooleanFunc), {'FORMAT': '{} > {}'})
+LessThan = type('LessThan', (BinaryFunc, BooleanFunc), {'FORMAT': '{} < {}'})
+GreaterThanOrEquals = type(
+ 'GreaterThanOrEquals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} >= {}'})
+LessThanOrEquals = type(
+ 'LessThanOrEquals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} <= {}'})
+
+Plus = type('Plus', (BinaryFunc, UpcastingFunc), {'FORMAT': '{} + {}'})
+Minus = type('Minus', (BinaryFunc, UpcastingFunc), {'FORMAT': '{} - {}'})
+Multiply = type('Multiply', (BinaryFunc, UpcastingFunc), {'FORMAT': '{} * {}'})
+Divide = type('Divide', (BinaryFunc, DoubleFunc), {'FORMAT': '{} / {}'})
+
+Floor = type('Floor', (UnaryFunc, IntFunc), {'FORMAT': 'FLOOR({})'})
+
+Concat = type('Concat', (BinaryFunc, StringFunc), {'FORMAT': 'CONCAT({}, {})'})
+Length = type('Length', (UnaryFunc, IntFunc), {'FORMAT': 'LENGTH({})'})
+
+ExtractYear = type(
+ 'ExtractYear', (UnaryFunc, IntFunc), {'FORMAT': "EXTRACT('YEAR' FROM {})"})
+
+# Formatting of agg funcs is a little trickier since they may have a distinct
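+# attached; they render through SqlWriter._write_agg_func instead of a FORMAT string.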
+Avg = type('Avg', (UnaryFunc, DoubleFunc, AggFunc), {})
+Count = type('Count', (UnaryFunc, IntFunc, AggFunc), {})
+Max = type('Max', (UnaryFunc, UpcastingFunc, AggFunc), {})
+Min = type('Min', (UnaryFunc, UpcastingFunc, AggFunc), {})
+Sum = type('Sum', (UnaryFunc, UpcastingFunc, AggFunc), {})
+
+UNARY_BOOLEAN_FUNCS = [IsNull, IsNotNull]
+BINARY_BOOLEAN_FUNCS = [And, Or]
+RELATIONAL_OPERATORS = [
+ Equals, NotEquals, GreaterThan, LessThan, GreaterThanOrEquals, LessThanOrEquals]
+MATH_OPERATORS = [Plus, Minus, Multiply] # Leaving out Divide
+BINARY_STRING_FUNCS = [Concat]
+AGG_FUNCS = [Avg, Count, Max, Min, Sum]
+
+
+class If(Func):
+
+ FORMAT = 'CASE WHEN {} THEN {} ELSE {} END'
+
+ def __init__(self, boolean_expr, consequent_expr, alternative_expr):
+ Func.__init__(
+ self, boolean_expr, consequent_expr, alternative_expr)
+
+ @property
+ def boolean_expr(self):
+ return self.args[0]
+
+ @property
+ def consequent_expr(self):
+ return self.args[1]
+
+ @property
+ def alternative_expr(self):
+ return self.args[2]
+
+ @property
+ def type(self):
+ return max((self.consequent_expr.type, self.alternative_expr.type))
+
+
+class Greatest(BinaryFunc, UpcastingFunc, If):
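+ '''The greater of two vals, expressed as a CASE expr.
+
+ Since If.__init__ overwrites self.args, this renders through If's FORMAT as
+ "CASE WHEN left > right THEN left ELSE right END".
+ '''
+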
+
+ def __init__(self, left, right):
+ BinaryFunc.__init__(self, left, right)
+ If.__init__(self, GreaterThan(left, right), left, right)
+
+ @property
+ def type(self):
+ return UpcastingFunc.type.fget(self)
+
+
+class Cast(Func):
+
+ FORMAT = 'CAST({} AS {})'
+
+ def __init__(self, val_expr, resulting_type):
+ if resulting_type not in TYPES:
+ raise Exception('Unexpected type: {}'.format(resulting_type))
+ Func.__init__(self, val_expr, resulting_type)
+
+ @property
+ def val_expr(self):
+ return self.args[0]
+
+ @property
+ def resulting_type(self):
+ return self.args[1]
+
+ @property
+ def type(self):
+ return self.resulting_type
diff --git a/tests/comparison/model_translator.py b/tests/comparison/model_translator.py
new file mode 100644
index 000000000..0c9194ed4
--- /dev/null
+++ b/tests/comparison/model_translator.py
@@ -0,0 +1,367 @@
+# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from inspect import getmro
+from logging import getLogger
+from re import sub
+from sqlparse import format
+
+from tests.comparison.model import (
+ Boolean,
+ Float,
+ Int,
+ Number,
+ Query,
+ String,
+ Timestamp)
+
+LOG = getLogger(__name__)
+
+class SqlWriter(object):
+ '''Subclasses of SqlWriter will take a Query and provide the SQL representation for a
+ specific database such as Impala or MySQL. The SqlWriter.create([dialect=])
+ factory method may be used instead of specifying the concrete class.
+
+ Another important function of this class is to ensure that CASTs produce the same
+ results across different databases. Sometimes the CASTs implemented here produce odd
+ results. For example, the result of CAST(date_col AS INT) in MySQL may be an int of
+ YYYYMMDD whereas in Impala it may be seconds since the epoch. For comparison purposes
+ the CAST could be transformed into EXTRACT(DAY from date_col).
+ '''
+
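+ # Ex (illustrative): SqlWriter.create('mysql').write_query(query, pretty=True)
+ # returns the MySQL-dialect SQL string for the given Query model.
+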
+ @staticmethod
+ def create(dialect='impala'):
+ '''Create and return a new SqlWriter appropriate for the given SQL dialect. "dialect"
+ refers to database-specific deviations of SQL; the val should be one of
+ "IMPALA", "MYSQL", or "POSTGRESQL".
+ '''
+ dialect = dialect.upper()
+ if dialect == 'IMPALA':
+ return SqlWriter()
+ if dialect == 'POSTGRESQL':
+ return PostgresqlSqlWriter()
+ if dialect == 'MYSQL':
+ return MySQLSqlWriter()
+ raise Exception('Unknown dialect: %s' % dialect)
+
+ def write_query(self, query, pretty=False):
+ '''Return SQL as a string for the given query.'''
+ sql = list()
+ # Write out each section in the proper order
+ for clause in (
+ query.with_clause,
+ query.select_clause,
+ query.from_clause,
+ query.where_clause,
+ query.group_by_clause,
+ query.having_clause,
+ query.union_clause):
+ if clause:
+ sql.append(self._write(clause))
+ sql = '\n'.join(sql)
+ if pretty:
+ sql = self.make_pretty_sql(sql)
+ return sql
+
+ def make_pretty_sql(self, sql):
+ try:
+ sql = format(sql, reindent=True)
+ except Exception as e:
+ LOG.warn('Unable to format sql: %s', e)
+ return sql
+
+ def _write_with_clause(self, with_clause):
+ return 'WITH ' + ',\n'.join('%s AS (%s)' % (view.identifier, self._write(view.query))
+ for view in with_clause.with_clause_inline_views)
+
+ def _write_select_clause(self, select_clause):
+ items = select_clause.non_agg_items + select_clause.agg_items
+ sql = 'SELECT'
+ if select_clause.distinct:
+ sql += ' DISTINCT'
+ sql += '\n' + ',\n'.join(self._write(item) for item in items)
+ return sql
+
+ def _write_select_item(self, select_item):
+ # If the query is nested, the items will have aliases so that the outer query can
+ # easily reference them.
+ if not select_item.alias:
+ raise Exception('An alias is required')
+ return '%s AS %s' % (self._write(select_item.val_expr), select_item.alias)
+
+ def _write_column(self, col):
+ return '%s.%s' % (col.owner.identifier, col.name)
+
+ def _write_from_clause(self, from_clause):
+ sql = 'FROM %s' % self._write(from_clause.table_expr)
+ if from_clause.join_clauses:
+ sql += '\n' + '\n'.join(self._write(join) for join in from_clause.join_clauses)
+ return sql
+
+ def _write_table(self, table):
+ if table.alias:
+ return '%s AS %s' % (table.name, table.identifier)
+ return table.name
+
+ def _write_inline_view(self, inline_view):
+ if not inline_view.identifier:
+ raise Exception('An inline view requires an identifier')
+ return '(\n%s\n) AS %s' % (self._write(inline_view.query), inline_view.identifier)
+
+ def _write_with_clause_inline_view(self, with_clause_inline_view):
+ if not with_clause_inline_view.with_clause_alias:
+ raise Exception('A with clause entry requires an identifier')
+ sql = with_clause_inline_view.with_clause_alias
+ if with_clause_inline_view.alias:
+ sql += ' AS ' + with_clause_inline_view.alias
+ return sql
+
+ def _write_join_clause(self, join_clause):
+ sql = '%s JOIN %s' % (join_clause.join_type, self._write(join_clause.table_expr))
+ if join_clause.boolean_expr:
+ sql += ' ON ' + self._write(join_clause.boolean_expr)
+ return sql
+
+ def _write_where_clause(self, where_clause):
+ return 'WHERE\n' + self._write(where_clause.boolean_expr)
+
+ def _write_group_by_clause(self, group_by_clause):
+ return 'GROUP BY\n' + ',\n'.join(self._write(item.val_expr)
+ for item in group_by_clause.group_by_items)
+
+ def _write_having_clause(self, having_clause):
+ return 'HAVING\n' + self._write(having_clause.boolean_expr)
+
+ def _write_union_clause(self, union_clause):
+ sql = 'UNION'
+ if union_clause.all:
+ sql += ' ALL'
+ sql += '\n' + self._write(union_clause.query)
+ return sql
+
+ def _write_data_type(self, data_type):
+ '''Write a literal value.'''
+ if data_type.returns_string:
+ return "'{}'".format(data_type.val)
+ if data_type.returns_timestamp:
+ return "CAST('{}' AS TIMESTAMP)".format(data_type.val)
+ return str(data_type.val)
+
+ def _write_func(self, func):
+ return func.FORMAT.format(*[self._write(arg) for arg in func.args])
+
+ def _write_cast(self, cast):
+ # Handle casts that produce different results across database types, or that just
+ # don't make sense, like casting a DATE as a BOOLEAN.
+ if cast.val_expr.returns_boolean:
+ if issubclass(cast.resulting_type, Timestamp):
+ return "CAST(CASE WHEN {} THEN '2000-01-01' ELSE '1999-01-01' END AS TIMESTAMP)"\
+ .format(self._write(cast.val_expr))
+ elif cast.val_expr.returns_number:
+ if issubclass(cast.resulting_type, Timestamp):
+ return ("CAST(CONCAT('2000-01-', "
+ "LPAD(CAST(ABS(FLOOR({})) % 31 + 1 AS STRING), 2, '0')) "
+ "AS TIMESTAMP)").format(self._write(cast.val_expr))
+ elif cast.val_expr.returns_string:
+ if issubclass(cast.resulting_type, Boolean):
+ return "(LENGTH({}) > 2)".format(self._write(cast.val_expr))
+ if issubclass(cast.resulting_type, Timestamp):
+ return ("CAST(CONCAT('2000-01-', LPAD(CAST(LENGTH({}) % 31 + 1 AS STRING), "
+ "2, '0')) AS TIMESTAMP)").format(self._write(cast.val_expr))
+ elif cast.val_expr.returns_timestamp:
+ if issubclass(cast.resulting_type, Boolean):
+ return '(DAY({0}) > MONTH({0}))'.format(self._write(cast.val_expr))
+ if issubclass(cast.resulting_type, Number):
+ return ('(DAY({0}) + 100 * MONTH({0}) + 100 * 100 * YEAR({0}))').format(
+ self._write(cast.val_expr))
+ return self._write_func(cast)
+
+ def _write_agg_func(self, agg_func):
+ sql = type(agg_func).__name__.upper() + '('
+ if agg_func.distinct:
+ sql += 'DISTINCT '
+ # All agg funcs only have a single arg
+ sql += self._write(agg_func.args[0]) + ')'
+ return sql
+
+ def _write_data_type_metaclass(self, data_type_class):
+ '''Write a data type class such as Int or Boolean.'''
+ return data_type_class.__name__.upper()
+
+ def _write(self, object_):
+ '''Return a sql string representation of the given object.'''
+ # What's below is effectively a giant switch statement. It works based on a func
+ # naming and signature convention. It should match the incoming object with the
+ # corresponding func defined, then call the func and return the result.
+ #
+ # Ex:
+ # a = model.And(...)
+ # _write(a) should call _write_func(a) because "And" is a subclass of "Func" and no
+ # other _write_* methods have been defined higher up the method
+ # resolution order (MRO). If _write_and(...) were to be defined, it would be called
+ # instead.
+ for type_ in getmro(type(object_)):
+ writer_func_name = '_write' + sub('([A-Z])', r'_\1', type_.__name__).lower()
+ writer_func = getattr(self, writer_func_name, None)
+ if writer_func:
+ return writer_func(object_)
+
+ # Handle any remaining cases
+ if isinstance(object_, Query):
+ return self.write_query(object_)
+
+ raise Exception('Unsupported object: %s<%s>' % (type(object_).__name__, object_))
+
+
+class PostgresqlSqlWriter(SqlWriter):
+ # TODO: This class is out of date since switching to MySQL. This is left here as is
+ # in case there is a desire to switch back in the future (it should be better than
+ # starting from nothing).
+
+ def _write_divide(self, divide):
+ # For ints, Postgresql does int division but Impala does float division.
+ return 'CAST({} AS REAL) / {}' \
+ .format(*[self._write(arg) for arg in divide.args])
+
+ def _write_data_type_metaclass(self, data_type_class):
+ '''Write a data type class such as Int or Boolean.'''
+ if hasattr(data_type_class, 'POSTGRESQL'):
+ return data_type_class.POSTGRESQL[0]
+ return data_type_class.__name__.upper()
+
+ def _write_cast(self, cast):
+ # Handle casts that produce different results across database types, or that just
+ # don't make sense, like casting a DATE as a BOOLEAN.
+ if cast.val_expr.returns_boolean:
+ if issubclass(cast.resulting_type, Float):
+ return "CASE {} WHEN TRUE THEN 1.0 WHEN FALSE THEN 0.0 END".format(
+ self._write(cast.val_expr))
+ if issubclass(cast.resulting_type, Timestamp):
+ return "CASE WHEN {} THEN '2000-01-01' ELSE '1999-01-01' END".format(
+ self._write(cast.val_expr))
+ if issubclass(cast.resulting_type, String):
+ return "CASE {} WHEN TRUE THEN '1' WHEN FALSE THEN '0' END".format(
+ self._write(cast.val_expr))
+ elif cast.val_expr.returns_number:
+ if issubclass(cast.resulting_type, Boolean):
+ return 'CASE WHEN ({0}) != 0 THEN TRUE WHEN ({0}) = 0 THEN FALSE END'.format(
+ self._write(cast.val_expr))
+ if issubclass(cast.resulting_type, Timestamp):
+ return "CASE WHEN ({}) > 0 THEN '2000-01-01' ELSE '1999-01-01' END".format(
+ self._write(cast.val_expr))
+ elif cast.val_expr.returns_string:
+ if issubclass(cast.resulting_type, Boolean):
+ return "(LENGTH({}) > 2)".format(self._write(cast.val_expr))
+ elif cast.val_expr.returns_timestamp:
+ if issubclass(cast.resulting_type, Boolean):
+ return '(EXTRACT(DAY FROM {0}) > EXTRACT(MONTH FROM {0}))'.format(
+ self._write(cast.val_expr))
+ if issubclass(cast.resulting_type, Number):
+ return ('(EXTRACT(DAY FROM {0}) '
+ '+ 100 * EXTRACT(MONTH FROM {0}) '
+ '+ 100 * 100 * EXTRACT(YEAR FROM {0}))').format(
+ self._write(cast.val_expr))
+ return self._write_func(cast)
+
+
+class MySQLSqlWriter(SqlWriter):
+
+ def write_query(self, query, pretty=False):
+ # MySQL doesn't support WITH clauses so they need to be converted into inline views.
+ # We are going to cheat by making use of the fact that the query generator creates
+ # with clause entries with unique aliases, even across nested queries.
+ sql = list()
+ for clause in (
+ query.select_clause,
+ query.from_clause,
+ query.where_clause,
+ query.group_by_clause,
+ query.having_clause,
+ query.union_clause):
+ if clause:
+ sql.append(self._write(clause))
+ sql = '\n'.join(sql)
+ if query.with_clause:
+ # Just replace the named references with inline views. Go in reverse order because
+ # entries at the bottom of the WITH clause definition may reference entries above.
+ for with_clause_inline_view in reversed(query.with_clause.with_clause_inline_views):
+ replacement_sql = '(' + self.write_query(with_clause_inline_view.query) + ')'
+ sql = sql.replace(with_clause_inline_view.identifier, replacement_sql)
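+ # Ex: "FROM with_1_42 AS t1" becomes "FROM (SELECT ...) AS t1"; the replaced
+ # identifier is the WITH entry's unique alias, so "t1" col references still work.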
+ if pretty:
+ sql = self.make_pretty_sql(sql)
+ return sql
+
+ def _write_data_type_metaclass(self, data_type_class):
+ '''Write a data type class such as Int or Boolean.'''
+ if issubclass(data_type_class, Int):
+ return 'INTEGER'
+ if issubclass(data_type_class, Float):
+ return 'DECIMAL(65, 15)'
+ if issubclass(data_type_class, String):
+ return 'CHAR'
+ if hasattr(data_type_class, 'MYSQL'):
+ return data_type_class.MYSQL[0]
+ return data_type_class.__name__.upper()
+
+ def _write_data_type(self, data_type):
+ '''Write a literal value.'''
+ if data_type.returns_timestamp:
+ return "CAST('{}' AS DATETIME)".format(data_type.val)
+ if data_type.returns_boolean:
+ # MySQL will error if the literal "FALSE" is used as a GROUP BY field
+ return '(0 = 0)' if data_type.val else '(1 = 0)'
+ return SqlWriter._write_data_type(self, data_type)
+
+ def _write_cast(self, cast):
+ # Handle casts that produce different results across database types, or that just
+ # don't make sense, like casting a DATE as a BOOLEAN.
+ if cast.val_expr.returns_boolean:
+ if issubclass(cast.resulting_type, Timestamp):
+ return "CAST(CASE WHEN {} THEN '2000-01-01' ELSE '1999-01-01' END AS DATETIME)"\
+ .format(self._write(cast.val_expr))
+ elif cast.val_expr.returns_number:
+ if issubclass(cast.resulting_type, Boolean):
+ return ("CASE WHEN ({0}) != 0 THEN TRUE WHEN ({0}) = 0 THEN FALSE END").format(
+ self._write(cast.val_expr))
+ if issubclass(cast.resulting_type, Timestamp):
+ return "CAST(CONCAT('2000-01-', ABS(FLOOR({})) % 31 + 1) AS DATETIME)"\
+ .format(self._write(cast.val_expr))
+ elif cast.val_expr.returns_string:
+ if issubclass(cast.resulting_type, Boolean):
+ return "(LENGTH({}) > 2)".format(self._write(cast.val_expr))
+ if issubclass(cast.resulting_type, Timestamp):
+ return ("CAST(CONCAT('2000-01-', LENGTH({}) % 31 + 1) AS DATETIME)").format(
+ self._write(cast.val_expr))
+ elif cast.val_expr.returns_timestamp:
+ if issubclass(cast.resulting_type, Number):
+ return ('(EXTRACT(DAY FROM {0}) '
+ '+ 100 * EXTRACT(MONTH FROM {0}) '
+ '+ 100 * 100 * EXTRACT(YEAR FROM {0}))').format(
+ self._write(cast.val_expr))
+ if issubclass(cast.resulting_type, Boolean):
+ return '(EXTRACT(DAY FROM {0}) > EXTRACT(MONTH FROM {0}))'.format(
+ self._write(cast.val_expr))
+
+ # MySQL uses different type names when casting...
+ if issubclass(cast.resulting_type, Boolean):
+ data_type = 'UNSIGNED'
+ elif issubclass(cast.resulting_type, Float):
+ data_type = 'DECIMAL(65, 15)'
+ elif issubclass(cast.resulting_type, Int):
+ data_type = 'SIGNED'
+ elif issubclass(cast.resulting_type, String):
+ data_type = 'CHAR'
+ elif issubclass(cast.resulting_type, Timestamp):
+ data_type = 'DATETIME'
+ return cast.FORMAT.format(self._write(cast.val_expr), data_type)
diff --git a/tests/comparison/query_generator.py b/tests/comparison/query_generator.py
new file mode 100644
index 000000000..9d238a419
--- /dev/null
+++ b/tests/comparison/query_generator.py
@@ -0,0 +1,556 @@
+# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+from copy import deepcopy
+from itertools import chain
+from random import choice, randint, shuffle
+
+from tests.comparison.model import (
+ AGG_FUNCS,
+ AggFunc,
+ And,
+ BINARY_STRING_FUNCS,
+ BigInt,
+ Boolean,
+ Cast,
+ Column,
+ Count,
+ DataType,
+ Double,
+ Equals,
+ Float,
+ Floor,
+ FromClause,
+ Func,
+ Greatest,
+ GroupByClause,
+ HavingClause,
+ InlineView,
+ Int,
+ JoinClause,
+ Length,
+ MATH_OPERATORS,
+ Number,
+ Query,
+ RELATIONAL_OPERATORS,
+ SelectClause,
+ SelectItem,
+ String,
+ Table,
+ Timestamp,
+ TYPES,
+ UNARY_BOOLEAN_FUNCS,
+ UnionClause,
+ WhereClause,
+ WithClause,
+ WithClauseInlineView)
+
+def random_boolean():
+ '''Return a val that evaluates to True 50% of the time'''
+ return randint(0, 1)
+
+
+def zero_or_more():
+ '''The chance that the return val is n is 1 / 2 ^ (n + 1)'''
+ val = 0
+ while random_boolean():
+ val += 1
+ return val
+
+
+def one_or_more():
+ return zero_or_more() + 1
+
+
+def random_non_empty_split(iterable):
+ '''Split a sequence into two non-empty lists'''
+ if len(iterable) < 2:
+ raise Exception('The iterable must contain at least two items')
+ split_index = randint(1, len(iterable) - 1)
+ left, right = list(), list()
+ for idx, item in enumerate(iterable):
+ if idx < split_index:
+ left.append(item)
+ else:
+ right.append(item)
+ return left, right
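+
+# Ex: random_non_empty_split([1, 2, 3]) returns either ([1], [2, 3]) or
+# ([1, 2], [3]), each with equal probability.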
+
+
+class QueryGenerator(object):
+
+ def create_query(self,
+ table_exprs,
+ allow_with_clause=True,
+ select_item_data_types=None):
+ '''Create a random query using various language features.
+
+ The initial call to this method should only use tables in the table_exprs
+ parameter, and not inline views or "with" definitions. The other types of
+ table exprs may be added as part of the query generation.
+
+ If select_item_data_types is specified, it must be a sequence or iterable of
+ DataType. The generated query.select_clause.select_items will have data
+ types suitable for use in a UNION.
+
+ '''
+ # Make a copy so tables can be added if a "with" clause is used
+ table_exprs = list(table_exprs)
+
+ with_clause = None
+ if allow_with_clause and randint(1, 10) == 1:
+ with_clause = self._create_with_clause(table_exprs)
+ table_exprs.extend(with_clause.table_exprs)
+
+ from_clause = self._create_from_clause(table_exprs)
+
+ select_clause = self._create_select_clause(
+ from_clause.table_exprs,
+ select_item_data_types=select_item_data_types)
+
+ query = Query(select_clause, from_clause)
+
+ if with_clause:
+ query.with_clause = with_clause
+
+ if random_boolean():
+ query.where_clause = self._create_where_clause(from_clause.table_exprs)
+
+ if select_clause.agg_items and select_clause.non_agg_items:
+ query.group_by_clause = GroupByClause(list(select_clause.non_agg_items))
+
+ if randint(1, 10) == 1:
+ if select_clause.agg_items:
+ self._enable_distinct_on_random_agg_items(select_clause.agg_items)
+ else:
+ select_clause.distinct = True
+
+ if random_boolean() and (query.group_by_clause or select_clause.agg_items):
+ query.having_clause = self._create_having_clause(from_clause.table_exprs)
+
+ if randint(1, 10) == 1:
+ select_item_data_types = list()
+ for select_item in select_clause.select_items:
+ # For numbers, choose the largest possible data type in case a CAST is needed.
+ if select_item.val_expr.returns_float:
+ select_item_data_types.append(Double)
+ elif select_item.val_expr.returns_int:
+ select_item_data_types.append(BigInt)
+ else:
+ select_item_data_types.append(select_item.val_expr.type)
+ query.union_clause = UnionClause(self.create_query(
+ table_exprs,
+ allow_with_clause=False,
+ select_item_data_types=select_item_data_types))
+ query.union_clause.all = random_boolean()
+
+ return query
+
+ def _create_with_clause(self, table_exprs):
+ # Make a copy so newly created tables can be added and made available for use in
+ # future table definitions.
+ table_exprs = list(table_exprs)
+ with_clause_inline_views = list()
+ for with_clause_inline_view_idx in xrange(one_or_more()):
+ query = self.create_query(table_exprs)
+ # To help prevent nested WITH clauses from having entries with the same alias,
+ # choose a random alias. Of course it would be much better to know which aliases
+ # were already chosen but that information isn't easy to get from here.
+ with_clause_alias = 'with_%s_%s' % \
+ (with_clause_inline_view_idx + 1, randint(1, 1000))
+ with_clause_inline_view = WithClauseInlineView(query, with_clause_alias)
+ table_exprs.append(with_clause_inline_view)
+ with_clause_inline_views.append(with_clause_inline_view)
+ return WithClause(with_clause_inline_views)
+
+ def _create_select_clause(self, table_exprs, select_item_data_types=None):
+ while True:
+ non_agg_items = [self._create_non_agg_select_item(table_exprs)
+ for _ in xrange(zero_or_more())]
+ agg_items = [self._create_agg_select_item(table_exprs)
+ for _ in xrange(zero_or_more())]
+ if non_agg_items or agg_items:
+ if select_item_data_types:
+ if len(select_item_data_types) > len(non_agg_items) + len(agg_items):
+ # Not enough items generated, try again
+ continue
+ while len(select_item_data_types) < len(non_agg_items) + len(agg_items):
+ items = choice([non_agg_items, agg_items])
+ if items:
+ items.pop()
+ for data_type_idx, data_type in enumerate(select_item_data_types):
+ if data_type_idx < len(non_agg_items):
+ item = non_agg_items[data_type_idx]
+ else:
+ item = agg_items[data_type_idx - len(non_agg_items)]
+ if not issubclass(item.type, data_type):
+ item.val_expr = self.convert_val_expr_to_type(item.val_expr, data_type)
+ for idx, item in enumerate(chain(non_agg_items, agg_items)):
+ item.alias = '%s_col_%s' % (item.type.__name__.lower(), idx + 1)
+ return SelectClause(non_agg_items=non_agg_items, agg_items=agg_items)
+
+ def _choose_col(self, table_exprs):
+ table_expr = choice(table_exprs)
+ return choice(table_expr.cols)
+
+ def _create_non_agg_select_item(self, table_exprs):
+ return SelectItem(self._create_val_expr(table_exprs))
+
+ def _create_val_expr(self, table_exprs):
+ vals = [self._choose_col(table_exprs) for _ in xrange(one_or_more())]
+ return self._combine_val_exprs(vals)
+
+ def _create_agg_select_item(self, table_exprs):
+ vals = [self._create_agg_val_expr(table_exprs) for _ in xrange(one_or_more())]
+ return SelectItem(self._combine_val_exprs(vals))
+
+ def _create_agg_val_expr(self, table_exprs):
+ val = self._create_val_expr(table_exprs)
+ if issubclass(val.type, Number):
+ funcs = list(AGG_FUNCS)
+ else:
+ funcs = [Count]
+ return choice(funcs)(val)
+
+ def _create_from_clause(self, table_exprs):
+ table_expr = self._create_table_expr(table_exprs)
+ table_expr_count = 1
+ table_expr.alias = 't%s' % table_expr_count
+ from_clause = FromClause(table_expr)
+ for join_idx in xrange(zero_or_more()):
+ join_clause = self._create_join_clause(from_clause, table_exprs)
+ table_expr_count += 1
+ join_clause.table_expr.alias = 't%s' % table_expr_count
+ from_clause.join_clauses.append(join_clause)
+ return from_clause
+
+ def _create_table_expr(self, table_exprs):
+ if randint(1, 10) == 1:
+ return self._create_inline_view(table_exprs)
+ return self._choose_table(table_exprs)
+
+ def _choose_table(self, table_exprs):
+ return deepcopy(choice(table_exprs))
+
+ def _create_inline_view(self, table_exprs):
+ return InlineView(self.create_query(table_exprs))
+
+ def _create_join_clause(self, from_clause, table_exprs):
+ table_expr = self._create_table_expr(table_exprs)
+ # Increase the chance of using the first join type, which is INNER
+ join_type_idx = (zero_or_more() / 2) % len(JoinClause.JOINS_TYPES)
+ join_type = JoinClause.JOINS_TYPES[join_type_idx]
+ join_clause = JoinClause(join_type, table_expr)
+
+ # Prefer non-boolean cols for the first condition. Boolean cols produce too
+ # many results so it's unlikely that someone would want to join tables only using
+ # boolean cols.
+ non_boolean_types = set(type_ for type_ in TYPES if not issubclass(type_, Boolean))
+
+ if join_type != 'CROSS':
+ join_clause.boolean_expr = self._combine_val_exprs(
+ [self._create_relational_join_condition(
+ table_expr,
+ choice(from_clause.table_exprs),
+ preferred_data_types=(non_boolean_types if idx == 0 else set()))
+ for idx in xrange(one_or_more())],
+ resulting_type=Boolean)
+ return join_clause
+
+ def _create_relational_join_condition(self,
+ left_table_expr,
+ right_table_expr,
+ preferred_data_types):
+ # "base type" means condense all int types into just int, same for floats
+ left_cols_by_base_type = left_table_expr.cols_by_base_type
+ right_cols_by_base_type = right_table_expr.cols_by_base_type
+ common_col_types = set(left_cols_by_base_type) & set(right_cols_by_base_type)
+ if preferred_data_types:
+ common_col_types &= preferred_data_types
+ if common_col_types:
+ col_type = choice(list(common_col_types))
+ left = choice(left_cols_by_base_type[col_type])
+ right = choice(right_cols_by_base_type[col_type])
+ else:
+ col_type = None
+ if preferred_data_types:
+ for available_col_types in (left_cols_by_base_type, right_cols_by_base_type):
+ preferred_available_col_types = set(available_col_types) & preferred_data_types
+ if preferred_available_col_types:
+ col_type = choice(list(preferred_available_col_types))
+ break
+ if not col_type:
+ col_type = choice(left_cols_by_base_type.keys())
+
+ if col_type in left_cols_by_base_type:
+ left = choice(left_cols_by_base_type[col_type])
+ else:
+ left = choice(choice(left_cols_by_base_type.values()))
+ left = self.convert_val_expr_to_type(left, col_type)
+ if col_type in right_cols_by_base_type:
+ right = choice(right_cols_by_base_type[col_type])
+ else:
+ right = choice(choice(right_cols_by_base_type.values()))
+ right = self.convert_val_expr_to_type(right, col_type)
+ return Equals(left, right)
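+ # Ex (illustrative): if both tables share an Int col this returns an Equals like
+ # "t1.int_col = t2.int_col"; when no common type exists, one side is converted,
+ # e.g. "t1.int_col = LENGTH(t2.string_col)".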
+
+ def _create_where_clause(self, table_exprs):
+ boolean_exprs = list()
+ # Create one boolean expr per iteration...
+ for _ in xrange(one_or_more()):
+ col_type = None
+ cols = list()
+ # ...using one or more cols...
+ for _ in xrange(one_or_more()):
+ # ...from any random table, inline view, etc.
+ table_expr = choice(table_exprs)
+ if not col_type:
+ col_type = choice(list(table_expr.cols_by_base_type))
+ if col_type in table_expr.cols_by_base_type:
+ col = choice(table_expr.cols_by_base_type[col_type])
+ else:
+ col = choice(table_expr.cols)
+ cols.append(col)
+ boolean_exprs.append(self._combine_val_exprs(cols, resulting_type=Boolean))
+ return WhereClause(self._combine_val_exprs(boolean_exprs))
+
+ def _combine_val_exprs(self, vals, resulting_type=None):
+ '''Combine the given vals into a single val.
+
+ If resulting_type is specified, the returned val will be of that type. If it is
+ not specified, the type will be randomly chosen from the types of the
+ input vals.
+
+ '''
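+ # Ex (illustrative): combining [int_col, string_col] with resulting_type=Boolean
+ # could produce "int_col IS NULL AND string_col IS NOT NULL".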
+ if not vals:
+ raise Exception('At least one val is required')
+
+ types_to_vals = DataType.group_by_base_type(vals)
+
+ if not resulting_type:
+ resulting_type = choice(types_to_vals.keys())
+
+ vals_of_resulting_type = list()
+
+ for val_type, typed_vals in types_to_vals.iteritems():
+ if issubclass(val_type, resulting_type):
+ vals_of_resulting_type.extend(typed_vals)
+ elif resulting_type == Boolean:
+ # To produce other result types, the vals are combined into a single val and
+ # then converted into the desired type. However, to make a boolean, relational
+ # operators can be used on the vals to make a more realistic query.
+ val = self._create_boolean_expr_from_vals_of_same_type(typed_vals)
+ vals_of_resulting_type.append(val)
+ else:
+ val = self._combine_vals_of_same_type(typed_vals)
+ if not (issubclass(val.type, Number) and issubclass(resulting_type, Number)):
+ val = self.convert_val_expr_to_type(val, resulting_type)
+ vals_of_resulting_type.append(val)
+
+ return self._combine_vals_of_same_type(vals_of_resulting_type)
+
+ def _create_boolean_expr_from_vals_of_same_type(self, vals):
+ if not vals:
+ raise Exception('At least one val is required')
+
+ if len(vals) == 1:
+ val = vals[0]
+ if Boolean == val.type:
+ return val
+ # Convert a single non-boolean val into a boolean using a func like
+ # IsNull or IsNotNull.
+ return choice(UNARY_BOOLEAN_FUNCS)(val)
+ if len(vals) == 2:
+ left, right = vals
+ if left.type == right.type:
+ if left.type == String:
+ # Databases may vary in how string comparisons are done. Results may differ
+ # when using operators like > or <, so just always use =.
+ return Equals(left, right)
+ if left.type == Boolean:
+ # TODO: Enable "OR" at some frequency, using OR at 50% will probably produce
+ # too many slow queries.
+ return And(left, right)
+ # At this point we've got two vals of the same type, so any relational
+ # operator can be applied and will produce a boolean.
+ return choice(RELATIONAL_OPERATORS)(left, right)
+ elif issubclass(left.type, Number) and issubclass(right.type, Number):
+ # Numbers need not be of the same type. SmallInt, BigInt, etc can all be compared.
+ # Note: For now ints are the only numbers enabled and division is disabled
+ # though AVG() is in use. If floats are enabled this will likely need to be
+ # updated to do some rounding based comparison.
+ return choice(RELATIONAL_OPERATORS)(left, right)
+ raise Exception('Vals are not of the same type: %s<%s> vs %s<%s>'
+ % (left, left.type, right, right.type))
+ # Reduce the number of inputs and try again...
+ left_subset, right_subset = random_non_empty_split(vals)
+ return self._create_boolean_expr_from_vals_of_same_type([
+ self._combine_vals_of_same_type(left_subset),
+ self._combine_vals_of_same_type(right_subset)])
+
+ def _combine_vals_of_same_type(self, vals):
+ '''Combine the given vals into a single expr of the same type. The input
+ vals must be of the same base data type. For example, Ints must not be mixed
+ with Strings.
+
+ '''
+ if not vals:
+ raise Exception('At least one val is required')
+
+ val_type = None
+ for val in vals:
+ if not val_type:
+ if issubclass(val.type, Number):
+ val_type = Number
+ else:
+ val_type = val.type
+ elif not issubclass(val.type, val_type):
+ raise Exception('Incompatible types %s and %s' % (val_type, val.type))
+
+ if len(vals) == 1:
+ return vals[0]
+
+ if val_type == Number:
+ funcs = MATH_OPERATORS
+ elif val_type == Boolean:
+ # TODO: Enable "OR" at some frequency
+ funcs = [And]
+ elif val_type == String:
+ # TODO: Combine strings using BINARY_STRING_FUNCS; for now just use the first val.
+ return vals[0]
+ elif val_type == Timestamp:
+ funcs = [Greatest]
+
+ vals = list(vals)
+ shuffle(vals)
+ left = vals.pop()
+ right = vals.pop()
+ while True:
+ func = choice(funcs)
+ left = func(left, right)
+ if not vals:
+ return left
+ right = vals.pop()
+
+ def convert_val_expr_to_type(self, val_expr, resulting_type):
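+ '''Return a val expr of resulting_type, wrapping val_expr in a conversion if needed.
+
+ Ex: a Float expr converted to Int becomes Floor(expr); a String converted to a
+ Number becomes Length(expr); other combinations fall back to a plain Cast.
+ '''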
+ if resulting_type not in TYPES:
+ raise Exception('Unexpected type: {}'.format(resulting_type))
+ val_type = val_expr.type
+ if issubclass(val_type, resulting_type):
+ return val_expr
+
+ if issubclass(resulting_type, Int):
+ if val_expr.returns_float:
+ # Impala will FLOOR while Postgresql will ROUND. Use FLOOR to be consistent.
+ return Floor(val_expr)
+ if issubclass(resulting_type, Number):
+ if val_expr.returns_string:
+ return Length(val_expr)
+ if issubclass(resulting_type, String):
+ if val_expr.returns_float:
+ # Different databases may use different precision.
+ return Cast(Floor(val_expr), resulting_type)
+
+ return Cast(val_expr, resulting_type)
+
+ def _create_having_clause(self, table_exprs):
+ boolean_exprs = list()
+ # Create one boolean expr per iteration...
+ for _ in xrange(one_or_more()):
+ agg_items = list()
+ # ...using one or more agg exprs...
+ for _ in xrange(one_or_more()):
+ vals = [self._create_agg_val_expr(table_exprs) for _ in xrange(one_or_more())]
+ agg_items.append(self._combine_val_exprs(vals))
+ boolean_exprs.append(self._combine_val_exprs(agg_items, resulting_type=Boolean))
+ return HavingClause(self._combine_val_exprs(boolean_exprs))
+
+ def _enable_distinct_on_random_agg_items(self, agg_items):
+ '''Randomly choose an agg func and set it to use DISTINCT'''
+ # Impala has a limitation where 'DISTINCT' may only be applied to one agg
+ # expr. If an agg expr is used more than once, each usage may
+ # or may not include DISTINCT.
+ #
+ # Examples:
+ # OK: SELECT COUNT(DISTINCT a) + SUM(DISTINCT a) + MAX(a)...
+ # Not OK: SELECT COUNT(DISTINCT a) + COUNT(DISTINCT b)...
+ #
+ # Given a select list like:
+ # COUNT(a), SUM(a), MAX(b)
+ #
+ # We want to output one of:
+ # COUNT(DISTINCT a), SUM(DISTINCT a), MAX(b)
+ # COUNT(DISTINCT a), SUM(a), MAX(b)
+ # COUNT(a), SUM(a), MAX(DISTINCT b)
+ #
+ # This will be done by first grouping all agg funcs by their inner
+ # expr:
+ # {a: [COUNT(a), SUM(a)],
+ # b: [MAX(b)]}
+ #
+ # then choosing a random val (which is a list of aggs) in the above dict, and
+ # finally randomly adding DISTINCT to items in the list.
+ exprs_to_funcs = defaultdict(list)
+ for item in agg_items:
+ for expr, funcs in self._group_agg_funcs_by_expr(item.val_expr).iteritems():
+ exprs_to_funcs[expr].extend(funcs)
+ funcs = choice(exprs_to_funcs.values())
+ for func in funcs:
+ if random_boolean():
+ func.distinct = True
+
+ def _group_agg_funcs_by_expr(self, val_expr):
+ '''Group exprs and return a dict mapping the expr to the agg items
+ it is used in.
+
+ Example: COUNT(a) * SUM(a) - MAX(b) + MIN(c) -> {a: [COUNT(a), SUM(a)],
+ b: [MAX(b)],
+ c: [MIN(c)]}
+
+ '''
+ exprs_to_funcs = defaultdict(list)
+ if isinstance(val_expr, AggFunc):
+ exprs_to_funcs[tuple(val_expr.args)].append(val_expr)
+ elif isinstance(val_expr, Func):
+ for arg in val_expr.args:
+ for expr, funcs in self._group_agg_funcs_by_expr(arg).iteritems():
+ exprs_to_funcs[expr].extend(funcs)
+ # else: The remaining case could happen if the original expr was something like
+ # "SUM(a) + b + 1" where b is a GROUP BY field.
+ return exprs_to_funcs
+
+
+if __name__ == '__main__':
+ '''Generate some queries for manual inspection. The queries won't run anywhere because
+ the tables used are fake. To make real queries, we'd need to connect to a database and
+ read the table metadata and such.
+ '''
+ tables = list()
+ data_types = list(TYPES)  # Copy so the module-level TYPES list is not mutated
+ data_types.remove(Float)
+ data_types.remove(Double)
+ for table_idx in xrange(5):
+ table = Table('table_%s' % table_idx)
+ tables.append(table)
+ for col_idx in xrange(3):
+ col_type = choice(data_types)
+ col = Column(table, '%s_col_%s' % (col_type.__name__.lower(), col_idx), col_type)
+ table.cols.append(col)
+
+ query_generator = QueryGenerator()
+ from model_translator import SqlWriter
+ sql_writer = SqlWriter.create()
+ for _ in range(3000):
+ query = query_generator.create_query(tables)
+ print(sql_writer.write_query(query) + '\n')