diff --git a/bin/copy-test-data.sh b/bin/copy-test-data.sh
index 56524dc6a..8fe63cbdd 100755
--- a/bin/copy-test-data.sh
+++ b/bin/copy-test-data.sh
@@ -14,4 +14,4 @@
 DATASRC="a2226.halxg.cloudera.com:/data/1/workspace/impala-data"
 DATADST=$IMPALA_HOME/testdata/impala-data
 mkdir -p $DATADST
-scp -i $IMPALA_HOME/ssh_keys/id_rsa_impala -o "StrictHostKeyChecking=no" -r $DATASRC/* $DATADST
+scp -i $HOME/.ssh/id_rsa_jenkins -o "StrictHostKeyChecking=no" -r systest@$DATASRC/* $DATADST
diff --git a/tests/comparison/README b/tests/comparison/README
new file mode 100644
index 000000000..059de094d
--- /dev/null
+++ b/tests/comparison/README
@@ -0,0 +1,164 @@
+Purpose:
+
+This package is intended to augment the standard test suite. The standard tests are
+more efficient with regard to features tested versus execution time. However, their
+coverage as a test suite still leaves gaps in query coverage. This package provides a
+random query generator to compare the results of a wide range of queries against a
+reference database engine. The queries range from very simple single-table selects to
+extremely complicated queries with multiple levels of nesting. This method of testing
+will be slower but has a larger coverage area.
+
+
+Requirements:
+
+1) It's assumed that Impala is running locally.
+
+2) Impyla -- an implementation of DB API 2 for Impala.
+
+   sudo pip install git+http://github.com/laserson/impyla.git#impyla
+
+3) At least one python driver for a reference database.
+
+   sudo apt-get install python-mysqldb
+   sudo apt-get install python-psycopg2   # Postgresql
+
+
+Usage:
+
+1) Generate test data
+
+   ./data_generator.py --use-mysql
+
+   This will generate tables and data in MySQL and Impala.
+
+
+2) Run the comparison
+
+   ./discrepancy_searcher.py
+
+   This will generate queries using the test database and compare the results against
+   MySQL (the default).
+
+
+Known Issues:
+
+1) Floats will produce false positives.
For example, the results of a query that has
+
+      SELECT FLOOR(COUNT(...) * AVG(...)) AS col_1
+
+   will produce different results on Impala and MySQL if COUNT() == 3 and AVG() == 1/3.
+   One of the databases will FLOOR(1) while the other will FLOOR(0.999).
+
+   Maybe this could be worked around or reduced by replacing all uses of AVG() with
+   "AVG() + foo", where foo is some number that makes it unlikely that
+   "COUNT() * (AVG() + foo)" will result in an int.
+
+   I'd guess this issue comes up in 1 out of 10-20k queries.
+
+2) Impyla may fail with "Invalid query handle". Some queries will fail every time when run
+   through Impyla but run fine through the impala-shell. I need to research more and file
+   an issue with Impyla.
+
+3) Impyla will fail with "Invalid session". I'm pretty sure this is also an Impyla issue,
+   but I also need to investigate more.
+
+
+Things to Know:
+
+1) A good number of queries to run seems to be about 5k. Ideally each test run would
+   discover the complete list of known issues. From experience, a 1k query test run may
+   complete without finding any issues that were discovered in previous runs. 5k seems
+   to be about the magic number where most issues will be rediscovered. This can take 1-2
+   hours. However, as of this writing it's rare to run 1k queries without finding at
+   least one discrepancy.
+
+2) It's possible to provide a randomization seed so that the randomness is actually
+   reproducible. The data generation currently has a default seed, so it will always
+   produce the same tables. This also means that if a new data type is added, those
+   generated tables will change.
+
+3) There is a query log. It's possible that a sequence of queries is required to expose
+   a bug. If you come across a failure that can't be reproduced by rerunning the failed
+   query, try running the queries leading up to that query as well.
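The FLOOR(COUNT(...) * AVG(...)) false positive in Known Issues above can be reproduced outside of any database. The sketch below is a minimal, hypothetical illustration in plain Python (it is not part of this package), assuming IEEE 754 doubles; `avg_lossy` stands in for an engine that carries AVG() at slightly reduced precision:

```python
from math import floor

# COUNT() == 3 and AVG() == 1/3, as in the example above.
count = 3
avg_full = 1 / 3.0        # an engine keeping full double precision
avg_lossy = 0.999999 / 3  # hypothetical stand-in for an engine that lost a few digits

# 3 * (1/3.0) rounds up to exactly 1.0 in IEEE 754 doubles, while the
# lossy product stays just below 1, so FLOOR() lands on different integers.
print(floor(count * avg_full))   # 1
print(floor(count * avg_lossy))  # 0
```

Both engines are "correct" to within float error; the discrepancy only appears because FLOOR() amplifies a sub-epsilon difference into a whole integer, which is why an epsilon-based comparison cannot catch this case.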
+
+
+Miscellaneous:
+
+1) Instead of generating new random queries with each run, it may be better to reuse a
+   list of queries from a previous run that are known to produce results. As of this
+   writing only about 50% of queries produce results. So it may be better to trade high
+   randomness for higher quality queries. For example, it would be possible to build up a
+   library of 100k queries that produce results, then randomly select 2.5k of those.
+   Maybe that would provide testing equivalent to 5k totally random queries in less
+   time.
+
+   This would also be useful in eliminating queries that have known issues above.
+
+
+Postgresql:
+
+1) Supports basically all Impala language features.
+
+2) Does int division, 1 / 2 == 0.
+
+3) Has strange sorting of strings, '-1' > '1'. This may be important if ORDER BY is ever
+   used. The databases being compared would need to have the same collation, which is
+   probably configurable.
+
+4) This was the original reference database, but I moved to MySQL while trying to add
+   support for floats and never moved back.
+
+
+MySQL:
+
+1) Supports basically all Impala language features, except the WITH clause requires
+   emulation with inline views.
+
+2) Has poor boolean support. It may be worth switching back to Postgresql for this.
+
+
+Improvements:
+
+1) Add support for simplifying buggy queries. When a random query fails the comparison
+   check, it is basically always much too complex for directly posting a bug report. It
+   is also time consuming to simplify the queries because there is a lot of trial and
+   error and manual editing of queries.
+
+2) Add more language features:
+
+   a) SEMI JOIN
+   b) ORDER BY
+   c) LIMIT, OFFSET
+   d) CASE / WHEN
+
+3) Add common built-in functions. Ex: CAST, IF, NVL, ...
+
+4) Make randomization of the query generation configurable. As of this writing all the
+   probabilities are hard-coded.
At a very minimum it should be easy to disable or force
+   the use of some language features such as CROSS JOIN, GROUP BY, etc.
+
+5) More investigation of using the existing "functional" test datasets. A very quick
+   trial run wasn't successful, but another attempt with more effort should be made
+   before introducing a new dataset.
+
+   I suspect the problem with using the functional dataset was that I only imported a few
+   tables, maybe alltypes, alltypesagg, and something else. I don't think I imported the
+   tiny tables since the odds of them producing results from a random query would be
+   very low.
+
+6) If the functional dataset cannot be used, someone should think more about what the
+   random data should be like. Only a few minutes of thought were put into selecting
+   random value ranges (including number of tables and columns), and it's not clear how
+   important those ranges are.
+
+7) Add support for comparing results with codegen enabled and disabled. Uri recently added
+   support for query options in Impyla.
+
+8) Consider adding Oracle or SQL Server support; these could be useful in the future for
+   analytic queries.
+
+9) Try running with tables in various formats. Ex: parquet and/or avro.
+
+10) Support more data types. Only int types are known to give good results.
+    Floats may work, but non-numeric types are not supported yet.
+
diff --git a/tests/comparison/__init__.py b/tests/comparison/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/comparison/data_generator.py b/tests/comparison/data_generator.py
new file mode 100755
index 000000000..9944dd75d
--- /dev/null
+++ b/tests/comparison/data_generator.py
@@ -0,0 +1,409 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2014 Cloudera, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''This module provides random data generation and database population. + + When this module is run directly for purposes of database population, the default is + to use a fixed seed for randomization. The result should be that the generated random + data is the same regardless of when or where the execution is done. + +''' + +from datetime import datetime, timedelta +from logging import basicConfig, getLogger +from random import choice, randint, random, seed, uniform + +from tests.comparison.db_connector import ( + DbConnector, + IMPALA, + MYSQL, + POSTGRESQL) +from tests.comparison.model import ( + Boolean, + Column, + Float, + Int, + Number, + String, + Table, + Timestamp, + TYPES) + +LOG = getLogger(__name__) + +class RandomValGenerator(object): + '''This class will generate random data of various data types. Currently only numeric + and string data types are supported. + + ''' + + def __init__(self, + min_number=-1000, + max_number=1000, + min_date=datetime(1990, 1, 1), + max_date=datetime(2030, 1, 1), + null_val_percentage=0.1): + self.min_number = min_number + self.max_number = max_number + self.min_date = min_date + self.max_date = max_date + self.null_val_percentage = null_val_percentage + + def generate_val(self, val_type): + '''Generate and return a single random val. Use the val_type parameter to + specify the type of val to generate. See model.DataType for valid val_type + options. 
+ + Ex: + generator = RandomValGenerator(min_number=1, max_number=5) + val = generator.generate_val(model.Int) + assert 1 <= val and val <= 5 + ''' + if issubclass(val_type, String): + val = self.generate_val(Int) + return None if val is None else str(val) + if random() < self.null_val_percentage: + return None + if issubclass(val_type, Int): + return randint( + max(self.min_number, val_type.MIN), min(val_type.MAX, self.max_number)) + if issubclass(val_type, Number): + return uniform(self.min_number, self.max_number) + if issubclass(val_type, Timestamp): + delta = self.max_date - self.min_date + delta_in_seconds = delta.days * 24 * 60 * 60 + delta.seconds + offset_in_seconds = randint(0, delta_in_seconds) + val = self.min_date + timedelta(0, offset_in_seconds) + return datetime(val.year, val.month, val.day) + if issubclass(val_type, Boolean): + return randint(0, 1) == 1 + raise Exception('Unsupported type %s' % val_type.__name__) + + +class DatabasePopulator(object): + '''This class will populate a database with randomly generated data. The population + includes table creation and data generation. Table names are hard coded as + table_. + + ''' + + def __init__(self): + self.val_generator = RandomValGenerator() + + def populate_db_with_random_data(self, + db_name, + db_connectors, + number_of_tables=10, + allowed_data_types=TYPES, + create_files=False): + '''Create tables with a random number of cols with data types chosen from + allowed_data_types, then fill the tables with data. + + The given db_name must have already been created. 
+
+    '''
+    connections = [connector.create_connection(db_name=db_name)
+        for connector in db_connectors]
+    for table_idx in xrange(number_of_tables):
+      table = self.create_random_table(
+          'table_%s' % (table_idx + 1),
+          allowed_data_types=allowed_data_types)
+      for connection in connections:
+        sql = self.make_create_table_sql(table, dialect=connection.db_type)
+        LOG.info('Creating %s table %s', connection.db_type, table.name)
+        if create_files:
+          with open('%s_%s.sql' % (table.name, connection.db_type.lower()), 'w') \
+              as f:
+            f.write(sql + '\n')
+        connection.execute(sql)
+      LOG.info('Inserting data into %s', table.name)
+      for _ in xrange(100):   # each iteration will insert 100 rows
+        rows = self.generate_table_data(table)
+        for connection in connections:
+          sql = self.make_insert_sql_from_data(
+              table, rows, dialect=connection.db_type)
+          if create_files:
+            with open('%s_%s.sql' %
+                (table.name, connection.db_type.lower()), 'a') as f:
+              f.write(sql + '\n')
+          try:
+            connection.execute(sql)
+          except:
+            LOG.error('Error executing SQL: %s', sql)
+            raise
+
+    self.index_tables_in_database(connections)
+
+    for connection in connections:
+      connection.close()
+
+  def migrate_database(self,
+      db_name,
+      source_db_connector,
+      destination_db_connectors,
+      include_table_names=None):
+    '''Read table metadata and data from the source database and create a replica in
+       the destination databases. For example, the Impala functional test database could
+       be copied into Postgresql.
+
+       source_db_connector and items in destination_db_connectors should be
+       of type db_connector.DbConnector. destination_db_connectors and
+       include_table_names should be iterables.
+ ''' + source_connection = source_db_connector.create_connection(db_name) + + cursors = [connector.create_connection(db_name=db_name).create_cursor() + for connector in destination_db_connectors] + + for table_name in source_connection.list_table_names(): + if include_table_names and table_name not in include_table_names: + continue + try: + table = source_connection.describe_table(table_name) + except Exception as e: + LOG.warn('Error fetching metadata for %s: %s', table_name, e) + continue + for destination_cursor in cursors: + sql = self.make_create_table_sql( + table, dialect=destination_cursor.connection.db_type) + destination_cursor.execute(sql) + with source_connection.open_cursor() as source_cursor: + try: + source_cursor.execute('SELECT * FROM ' + table_name) + while True: + rows = source_cursor.fetchmany(size=100) + if not rows: + break + for destination_cursor in cursors: + sql = self.make_insert_sql_from_data( + table, rows, dialect=destination_cursor.connection.db_type) + destination_cursor.execute(sql) + except Exception as e: + LOG.error('Error fetching data for %s: %s', table_name, e) + continue + + self.index_tables_in_database([cursor.connection for cursor in cursors]) + + for cursor in cursors: + cursor.close() + cursor.connection.close() + + def create_random_table(self, table_name, allowed_data_types): + '''Create and return a Table with a random number of cols chosen from the + given allowed_data_types. 
+ + ''' + data_type_count = len(allowed_data_types) + col_count = randint(data_type_count / 2, data_type_count * 2) + table = Table(table_name) + for col_idx in xrange(col_count): + col_type = choice(allowed_data_types) + col = Column( + table, + '%s_col_%s' % (col_type.__name__.lower(), col_idx + 1), + col_type) + table.cols.append(col) + return table + + def make_create_table_sql(self, table, dialect=IMPALA): + sql = 'CREATE TABLE %s (%s)' % ( + table.name, + ', '.join('%s %s' % + (col.name, self.get_sql_for_data_type(col.type, dialect)) + + ('' if dialect == IMPALA else ' NULL') + for col in table.cols)) + if dialect == MYSQL: + sql += ' ENGINE = MYISAM' + return sql + + def get_sql_for_data_type(self, data_type, dialect=IMPALA): + # Check to see if there is an alias and if so, use the first one + if hasattr(data_type, dialect): + return getattr(data_type, dialect)[0] + return data_type.__name__.upper() + + def make_insert_sql_from_data(self, table, rows, dialect=IMPALA): + # TODO: Consider using parameterized inserts so the database connector handles + # formatting the data. For example the CAST to workaround IMPALA-803 can + # probably be removed. The vals were generated this way so a data file + # could be made and attached to jiras. 
+ if not rows: + raise Exception('At least one row is required') + if not table.cols: + raise Exception('At least one col is required') + + sql = 'INSERT INTO %s VALUES ' % table.name + for row_idx, row in enumerate(rows): + if row_idx > 0: + sql += ', ' + sql += '(' + for col_idx, col in enumerate(table.cols): + if col_idx > 0: + sql += ', ' + val = row[col_idx] + if val is None: + sql += 'NULL' + elif issubclass(col.type, Timestamp): + if dialect != IMPALA: + sql += 'TIMESTAMP ' + sql += "'%s'" % val + elif issubclass(col.type, String): + val = val.replace("'", "''") + if dialect == POSTGRESQL: + val = val.replace('\\', '\\\\') + sql += "'%s'" % val + elif dialect == IMPALA \ + and issubclass(col.type, Float): + # https://issues.cloudera.org/browse/IMPALA-803 + sql += 'CAST(%s AS FLOAT)' % val + else: + sql += str(val) + sql += ')' + return sql + + def generate_table_data(self, table, number_of_rows=100): + rows = list() + for row_idx in xrange(number_of_rows): + row = list() + for col in table.cols: + row.append(self.val_generator.generate_val(col.type)) + rows.append(row) + return rows + + def drop_and_create_database(self, db_name, db_connectors): + for connector in db_connectors: + with connector.open_connection() as connection: + connection.drop_db_if_exists(db_name) + connection.execute('CREATE DATABASE ' + db_name) + + def index_tables_in_database(self, connections): + for connection in connections: + if connection.supports_index_creation: + for table_name in connection.list_table_names(): + LOG.info('Indexing %s on %s' % (table_name, connection.db_type)) + connection.index_table(table_name) + +if __name__ == '__main__': + from optparse import NO_DEFAULT, OptionGroup, OptionParser + + parser = OptionParser( + usage='usage: \n' + ' %prog [options] [populate]\n\n' + ' Create and populate database(s). 
The Impala database will always be \n' + ' included, the other database types are optional.\n\n' + ' %prog [options] migrate\n\n' + ' Migrate an Impala database to another database type. The destination \n' + ' database will be dropped and recreated.') + parser.add_option('--log-level', default='INFO', + help='The log level to use.', choices=('DEBUG', 'INFO', 'WARN', 'ERROR')) + parser.add_option('--db-name', default='randomness', + help='The name of the database to use. Ex: functional.') + + group = OptionGroup(parser, 'MySQL Options') + group.add_option('--use-mysql', action='store_true', default=False, + help='Use MySQL') + group.add_option('--mysql-host', default='localhost', + help='The name of the host running the MySQL database.') + group.add_option('--mysql-port', default=3306, type=int, + help='The port of the host running the MySQL database.') + group.add_option('--mysql-user', default='root', + help='The user name to use when connecting to the MySQL database.') + group.add_option('--mysql-password', + help='The password to use when connecting to the MySQL database.') + parser.add_option_group(group) + + group = OptionGroup(parser, 'Postgresql Options') + group.add_option('--use-postgresql', action='store_true', default=False, + help='Use Postgresql') + group.add_option('--postgresql-host', default='localhost', + help='The name of the host running the Postgresql database.') + group.add_option('--postgresql-port', default=5432, type=int, + help='The port of the host running the Postgresql database.') + group.add_option('--postgresql-user', default='postgres', + help='The user name to use when connecting to the Postgresql database.') + group.add_option('--postgresql-password', + help='The password to use when connecting to the Postgresql database.') + parser.add_option_group(group) + + group = OptionGroup(parser, 'Database Population Options') + group.add_option('--randomization-seed', default=1, type='int', + help='The randomization will be initialized 
with this seed. Using the same seed '
+      'will produce the same results across runs.')
+  group.add_option('--create-data-files', default=False, action='store_true',
+      help='Create files that can be used to repopulate the databases elsewhere.')
+  group.add_option('--table-count', default=10, type='int',
+      help='The number of tables to generate.')
+  parser.add_option_group(group)
+
+  group = OptionGroup(parser, 'Database Migration Options')
+  group.add_option('--migrate-table-names',
+      help='Table names should be separated with commas. The default is to migrate all '
+      'tables.')
+  parser.add_option_group(group)
+
+  for group in parser.option_groups + [parser]:
+    for option in group.option_list:
+      if option.default != NO_DEFAULT:
+        option.help += ' [default: %default]'
+
+  options, args = parser.parse_args()
+  command = args[0] if args else 'populate'
+  if len(args) > 1 or command not in ['populate', 'migrate']:
+    raise Exception('Command must either be "populate" or "migrate" but was "%s"' %
+        ' '.join(args))
+  if command == 'migrate' and not any((options.use_mysql, options.use_postgresql)):
+    raise Exception('At least one destination database must be chosen with '
+        '--use-')
+
+  basicConfig(level=options.log_level)
+
+  seed(options.randomization_seed)
+
+  impala_connector = DbConnector(IMPALA)
+  db_connectors = []
+  if options.use_postgresql:
+    db_connectors.append(DbConnector(POSTGRESQL,
+        user_name=options.postgresql_user,
+        password=options.postgresql_password,
+        host_name=options.postgresql_host,
+        port=options.postgresql_port))
+  if options.use_mysql:
+    db_connectors.append(DbConnector(MYSQL,
+        user_name=options.mysql_user,
+        password=options.mysql_password,
+        host_name=options.mysql_host,
+        port=options.mysql_port))
+
+  populator = DatabasePopulator()
+  if command == 'populate':
+    db_connectors.append(impala_connector)
+    populator.drop_and_create_database(options.db_name, db_connectors)
+    populator.populate_db_with_random_data(
+        options.db_name,
+
db_connectors, + number_of_tables=options.table_count, + create_files=options.create_data_files) + else: + populator.drop_and_create_database(options.db_name, db_connectors) + if options.migrate_table_names: + table_names = options.migrate_table_names.split(',') + else: + table_names = None + populator.migrate_database( + options.db_name, + impala_connector, + db_connectors, + include_table_names=table_names) diff --git a/tests/comparison/db_connector.py b/tests/comparison/db_connector.py new file mode 100644 index 000000000..5e9c98aa4 --- /dev/null +++ b/tests/comparison/db_connector.py @@ -0,0 +1,411 @@ +# Copyright (c) 2014 Cloudera, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''This module is intended to standardize workflows when working with various databases + such as Impala, Postgresql, etc. Even with pep-249 (DB API 2), workflows differ + slightly. For example Postgresql does not allow changing databases from within a + connection, instead a new connection must be made. However Impala does not allow + specifying a database upon connection, instead a cursor must be created and a USE + command must be issued. + +''' + +from contextlib import contextmanager +try: + from impala.dbapi import connect as impala_connect +except: + print('Error importing impyla. Please make sure it is installed. 
' + 'See the README for details.') + raise +from itertools import izip +from logging import getLogger +from tests.comparison.model import Column, Table, TYPES, String + +LOG = getLogger(__name__) + +IMPALA = 'IMPALA' +POSTGRESQL = 'POSTGRESQL' +MYSQL = 'MYSQL' + +DATABASES = [IMPALA, POSTGRESQL, MYSQL] + +mysql_connect = None +postgresql_connect = None + +class DbConnector(object): + '''Wraps a DB API 2 implementation to provide a standard way of obtaining a + connection and selecting a database. + + Any database that supports transactions will have auto-commit enabled. + + ''' + + def __init__(self, db_type, user_name=None, password=None, host_name=None, port=None): + self.db_type = db_type.upper() + if self.db_type not in DATABASES: + raise Exception('Unsupported database: %s' % db_type) + self.user_name = user_name + self.password = password + self.host_name = host_name or 'localhost' + self.port = port + + def create_connection(self, db_name=None): + if self.db_type == IMPALA: + connection_class = ImpalaDbConnection + connection = impala_connect(host=self.host_name, port=self.port or 21050) + elif self.db_type == POSTGRESQL: + connection_class = PostgresqlDbConnection + connection_args = {'user': self.user_name or 'postgres'} + if self.password: + connection_args['password'] = self.password + if db_name: + connection_args['database'] = db_name + if self.host_name: + connection_args['host'] = self.host_name + if self.port: + connection_args['port'] = self.port + global postgresql_connect + if not postgresql_connect: + try: + from psycopg2 import connect as postgresql_connect + except: + print('Error importing psycopg2. Please make sure it is installed. 
' + 'See the README for details.') + raise + connection = postgresql_connect(**connection_args) + connection.autocommit = True + elif self.db_type == MYSQL: + connection_class = MySQLDbConnection + connection_args = {'user': self.user_name or 'root'} + if self.password: + connection_args['passwd'] = self.password + if db_name: + connection_args['db'] = db_name + if self.host_name: + connection_args['host'] = self.host_name + if self.port: + connection_args['port'] = self.port + global mysql_connect + if not mysql_connect: + try: + from MySQLdb import connect as mysql_connect + except: + print('Error importing MySQLdb. Please make sure it is installed. ' + 'See the README for details.') + raise + connection = mysql_connect(**connection_args) + else: + raise Exception('Unexpected database type: %s' % self.db_type) + return connection_class(self, connection, db_name=db_name) + + @contextmanager + def open_connection(self, db_name=None): + connection = None + try: + connection = self.create_connection(db_name=db_name) + yield connection + finally: + if connection: + try: + connection.close() + except Exception as e: + LOG.debug('Error closing connection: %s', e, exc_info=True) + + +class DbConnection(object): + '''Wraps a DB API 2 connection. Instances should only be obtained through the + DbConnector.create_connection(...) method. + + ''' + + @staticmethod + def describe_common_tables(db_connections, filter_col_types=[]): + '''Find and return a list of Table objects that the given connections have in + common. + + @param filter_col_types: Ignore any cols if they are of a data type contained + in this collection. 
+
+    '''
+    common_table_names = None
+    for db_connection in db_connections:
+      table_names = set(db_connection.list_table_names())
+      if common_table_names is None:
+        common_table_names = table_names
+      else:
+        common_table_names &= table_names
+    common_table_names = sorted(common_table_names)
+
+    tables = list()
+    for table_name in common_table_names:
+      common_table = None
+      mismatch = False
+      for db_connection in db_connections:
+        table = db_connection.describe_table(table_name)
+        table.cols = [col for col in table.cols if col.type not in filter_col_types]
+        if common_table is None:
+          common_table = table
+          continue
+        if len(common_table.cols) != len(table.cols):
+          LOG.debug('Ignoring table %s.'
+              ' It has a different number of columns across databases.', table_name)
+          mismatch = True
+          break
+        for left, right in izip(common_table.cols, table.cols):
+          # Ignore the table if either the col name or the col type differs.
+          if left.name != right.name or left.type != right.type:
+            LOG.debug('Ignoring table %s. It has different columns %s vs %s.' %
+                (table_name, left, right))
+            mismatch = True
+            break
+        if mismatch:
+          break
+      if not mismatch:
+        tables.append(common_table)
+
+    return tables
+
+  def __init__(self, connector, connection, db_name=None):
+    self.connector = connector
+    self.connection = connection
+    self.db_name = db_name
+
+  @property
+  def db_type(self):
+    return self.connector.db_type
+
+  def create_cursor(self):
+    return DatabaseCursor(self.connection.cursor(), self)
+
+  @contextmanager
+  def open_cursor(self):
+    '''Returns a new cursor for use in a "with" statement. When the "with" statement ends,
+       the cursor will be closed.
+ + ''' + cursor = None + try: + cursor = self.create_cursor() + yield cursor + finally: + self.close_cursor_quietly(cursor) + + def close_cursor_quietly(self, cursor): + if cursor: + try: + cursor.close() + except Exception as e: + LOG.debug('Error closing cursor: %s', e, exc_info=True) + + def list_db_names(self): + '''Return a list of database names always in lowercase.''' + rows = self.execute_and_fetchall(self.make_list_db_names_sql()) + return [row[0].lower() for row in rows] + + def make_list_db_names_sql(self): + return 'SHOW DATABASES' + + def list_table_names(self): + '''Return a list of table names always in lowercase.''' + rows = self.execute_and_fetchall(self.make_list_table_names_sql()) + return [row[0].lower() for row in rows] + + def make_list_table_names_sql(self): + return 'SHOW TABLES' + + def describe_table(self, table_name): + '''Return a Table with table and col names always in lowercase.''' + rows = self.execute_and_fetchall(self.make_describe_table_sql(table_name)) + table = Table(table_name.lower()) + for row in rows: + col_name, data_type = row[:2] + table.cols.append(Column(table, col_name.lower(), self.parse_data_type(data_type))) + return table + + def make_describe_table_sql(self, table_name): + return 'DESCRIBE ' + table_name + + def parse_data_type(self, sql): + sql = sql.upper() + # Types may have declared a database specific alias + for type_ in TYPES: + if sql in getattr(type_, self.db_type, []): + return type_ + for type_ in TYPES: + if type_.__name__.upper() == sql: + return type_ + if 'CHAR' in sql: + return String + raise Exception('Unknown data type: ' + sql) + + def create_database(self, db_name): + db_name = db_name.lower() + with self.open_cursor() as cursor: + cursor.execute('CREATE DATABASE ' + db_name) + + def drop_db_if_exists(self, db_name): + '''This should not be called from a connection to the database being dropped.''' + db_name = db_name.lower() + if db_name not in self.list_db_names(): + return + if self.db_name 
and self.db_name.lower() == db_name:
+      raise Exception('Cannot drop database while still connected to it')
+    self.drop_database(db_name)
+
+  def drop_database(self, db_name):
+    db_name = db_name.lower()
+    self.execute('DROP DATABASE ' + db_name)
+
+  @property
+  def supports_index_creation(self):
+    return True
+
+  def index_table(self, table_name):
+    table = self.describe_table(table_name)
+    with self.open_cursor() as cursor:
+      for col in table.cols:
+        index_name = '%s_%s' % (table_name, col.name)
+        if self.db_name:
+          index_name = '%s_%s' % (self.db_name, index_name)
+        cursor.execute('CREATE INDEX %s ON %s(%s)' % (index_name, table_name, col.name))
+
+  @property
+  def supports_kill_connection(self):
+    return False
+
+  def kill_connection(self):
+    '''Kill the current connection and any currently running queries associated with the
+       connection.
+    '''
+    raise Exception('Killing connection is not supported')
+
+  def materialize_query(self, query_as_text, table_name):
+    self.execute('CREATE TABLE %s AS %s' % (table_name.lower(), query_as_text))
+
+  def drop_table(self, table_name):
+    self.execute('DROP TABLE ' + table_name.lower())
+
+  def execute(self, sql):
+    with self.open_cursor() as cursor:
+      cursor.execute(sql)
+
+  def execute_and_fetchall(self, sql):
+    with self.open_cursor() as cursor:
+      cursor.execute(sql)
+      return cursor.fetchall()
+
+  def close(self):
+    '''Close the underlying connection.'''
+    self.connection.close()
+
+  def reconnect(self):
+    self.close()
+    other = self.connector.create_connection(db_name=self.db_name)
+    self.connection = other.connection
+
+
+class DatabaseCursor(object):
+  '''Wraps a DB API 2 cursor to provide access to the related connection. This class
+     implements the DB API 2 interface by delegation.
+ + ''' + + def __init__(self, cursor, connection): + self.cursor = cursor + self.connection = connection + + def __getattr__(self, attr): + return getattr(self.cursor, attr) + + +class ImpalaDbConnection(DbConnection): + + def create_cursor(self): + cursor = DbConnection.create_cursor(self) + if self.db_name: + cursor.execute('USE %s' % self.db_name) + return cursor + + def drop_database(self, db_name): + '''This should not be called from a connection to the database being dropped.''' + db_name = db_name.lower() + with self.connector.open_connection(db_name) as list_tables_connection: + with list_tables_connection.open_cursor() as drop_table_cursor: + for table_name in list_tables_connection.list_table_names(): + drop_table_cursor.execute('DROP TABLE ' + table_name) + self.execute('DROP DATABASE ' + db_name) + + @property + def supports_index_creation(self): + return False + + +class PostgresqlDbConnection(DbConnection): + + def make_list_db_names_sql(self): + return 'SELECT datname FROM pg_database' + + def make_list_table_names_sql(self): + return ''' + SELECT table_name + FROM information_schema.tables + WHERE table_schema = 'public' ''' + + def make_describe_table_sql(self, table_name): + return ''' + SELECT column_name, data_type + FROM information_schema.columns + WHERE table_name = '%s' + ORDER BY ordinal_position''' % table_name + + +class MySQLDbConnection(DbConnection): + + def __init__(self, connector, connection, db_name=None): + DbConnection.__init__(self, connector, connection, db_name=db_name) + self.session_id = self.execute_and_fetchall('SELECT connection_id()')[0][0] + + def describe_table(self, table_name): + '''Return a Table with table and col names always in lowercase.''' + rows = self.execute_and_fetchall(self.make_describe_table_sql(table_name)) + table = Table(table_name.lower()) + for row in rows: + col_name, data_type = row[:2] + if data_type == 'tinyint(1)': + # Just assume this is a boolean... 
+ data_type = 'boolean' + if '(' in data_type: + # Strip the size of the data type + data_type = data_type[:data_type.index('(')] + table.cols.append(Column(table, col_name.lower(), self.parse_data_type(data_type))) + return table + + @property + def supports_kill_connection(self): + return True + + def kill_connection(self): + with self.connector.open_connection(db_name=self.db_name) as connection: + connection.execute('KILL %s' % (self.session_id)) + + def index_table(self, table_name): + table = self.describe_table(table_name) + with self.open_cursor() as cursor: + for col in table.cols: + try: + cursor.execute('ALTER TABLE %s ADD INDEX (%s)' % (table_name, col.name)) + except Exception as e: + if 'Incorrect index name' not in str(e): + raise + # Some sort of MySQL bug... + LOG.warn('Could not create index on %s.%s: %s' % (table_name, col.name, e)) diff --git a/tests/comparison/discrepancy_searcher.py b/tests/comparison/discrepancy_searcher.py new file mode 100755 index 000000000..1ebfb3fd9 --- /dev/null +++ b/tests/comparison/discrepancy_searcher.py @@ -0,0 +1,529 @@ +#!/usr/bin/env python + +# Copyright (c) 2014 Cloudera, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''This module will run random queries against existing databases and compare the + results. 
+
+'''
+
+from contextlib import closing
+from decimal import Decimal
+from itertools import izip, izip_longest
+from logging import basicConfig, getLogger
+from math import isinf, isnan
+from os import getenv, remove
+from os.path import exists, join
+from shelve import open as open_shelve
+from subprocess import call
+from threading import current_thread, Thread
+from tempfile import gettempdir
+from time import time
+
+from tests.comparison.db_connector import (
+    DbConnection,
+    DbConnector,
+    IMPALA,
+    MYSQL,
+    POSTGRESQL)
+from tests.comparison.model import BigInt, TYPES
+from tests.comparison.query_generator import QueryGenerator
+from tests.comparison.model_translator import SqlWriter
+
+LOG = getLogger(__name__)
+
+class QueryResultComparator(object):
+
+  # If the number of rows * cols is greater than this val, then the comparison will
+  # be aborted. Raising this value also raises the risk of Python being OOM killed. At
+  # 10M, Python would occasionally get OOM killed even on a physical machine with 32GB
+  # of RAM.
+  TOO_MUCH_DATA = 1000 * 1000
+
+  # Used when comparing float vals
+  EPSILON = 0.1
+
+  # The decimal vals will be rounded before comparison
+  DECIMAL_PLACES = 2
+
+  def __init__(self, impala_connection, reference_connection):
+    self.reference_db_type = reference_connection.db_type
+
+    self.impala_cursor = impala_connection.create_cursor()
+    self.reference_cursor = reference_connection.create_cursor()
+
+    self.impala_sql_writer = SqlWriter.create(dialect=impala_connection.db_type)
+    self.reference_sql_writer = SqlWriter.create(dialect=reference_connection.db_type)
+
+    # After this much time the connection will be killed and the comparison result will
+    # be a timeout.
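The comment above describes the strategy that `execute_query()` implements further down in this module: DB API drivers offer no portable way to interrupt a blocking `execute()`, so the query runs on a daemon thread and the caller abandons it (and kills the connection) once the deadline passes. A minimal standalone sketch of that pattern — `run_with_timeout` is a hypothetical helper for illustration, not part of this patch:

```python
import threading

def run_with_timeout(fn, timeout_seconds):
    # Run fn on a daemon worker thread so a hung call cannot block interpreter exit.
    result = {}

    def _target():
        try:
            result['value'] = fn()
        except Exception as e:
            # Capture the exception so it can be re-raised on the caller's thread.
            result['error'] = e

    worker = threading.Thread(target=_target)
    worker.daemon = True
    worker.start()
    worker.join(timeout_seconds)
    if worker.is_alive():
        # The worker is simply abandoned; the real code additionally kills the DB
        # connection here so the server stops executing the query.
        raise RuntimeError('timed out after %s seconds' % timeout_seconds)
    if 'error' in result:
        raise result['error']
    return result.get('value')
```

The thread must be a daemon because an abandoned query thread may never finish; a non-daemon thread would keep the process alive.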
+ self.query_timeout_seconds = 3 * 60 + + def compare_query_results(self, query): + '''Execute the query, compare the data, and return a summary of the result.''' + comparison_result = ComparisonResult(query, self.reference_db_type) + + reference_data_set = None + impala_data_set = None + # Impala doesn't support getting the row count without getting the rows too. So run + # the query on the other database first. + try: + for sql_writer, cursor in ((self.reference_sql_writer, self.reference_cursor), + (self.impala_sql_writer, self.impala_cursor)): + self.execute_query(cursor, sql_writer.write_query(query)) + if (cursor.rowcount * len(query.select_clause.select_items)) > self.TOO_MUCH_DATA: + comparison_result.exception = Exception('Too much data to compare') + return comparison_result + if reference_data_set is None: + # MySQL returns a tuple of rows but a list is needed for sorting + reference_data_set = list(cursor.fetchall()) + comparison_result.reference_row_count = len(reference_data_set) + else: + impala_data_set = cursor.fetchall() + comparison_result.impala_row_count = len(impala_data_set) + except Exception as e: + comparison_result.exception = e + LOG.debug('Error running query: %s', e, exc_info=True) + return comparison_result + + comparison_result.query_resulted_in_data = (comparison_result.impala_row_count > 0 + or comparison_result.reference_row_count > 0) + + if comparison_result.impala_row_count != comparison_result.reference_row_count: + return comparison_result + + for data_set in (reference_data_set, impala_data_set): + for row_idx, row in enumerate(data_set): + data_set[row_idx] = [self.standardize_data(data) for data in row] + data_set.sort(cmp=self.row_sort_cmp) + + for impala_row, reference_row in \ + izip_longest(impala_data_set, reference_data_set): + for col_idx, (impala_val, reference_val) \ + in enumerate(izip_longest(impala_row, reference_row)): + if not self.vals_are_equal(impala_val, reference_val): + if isinstance(impala_val, int) \ 
+ and isinstance(reference_val, (int, float, Decimal)) \ + and abs(reference_val) > BigInt.MAX: + # Impala will return incorrect results if the val is greater than max bigint + comparison_result.exception = KnownError( + 'https://issues.cloudera.org/browse/IMPALA-865') + elif isinstance(impala_val, float) \ + and (isinf(impala_val) or isnan(impala_val)): + # In some cases, Impala gives NaNs and Infs instead of NULLs + comparison_result.exception = KnownError( + 'https://issues.cloudera.org/browse/IMPALA-724') + comparison_result.impala_row = impala_row + comparison_result.reference_row = reference_row + comparison_result.mismatch_at_row_number = row_idx + 1 + comparison_result.mismatch_at_col_number = col_idx + 1 + return comparison_result + + if len(impala_data_set) == 1: + for val in impala_data_set[0]: + if val: + break + else: + comparison_result.query_resulted_in_data = False + + return comparison_result + + def execute_query(self, cursor, sql): + '''Execute the query and throw a timeout if needed.''' + def _execute_query(): + try: + cursor.execute(sql) + except Exception as e: + current_thread().exception = e + query_thread = Thread(target=_execute_query, name='Query execution thread') + query_thread.daemon = True + query_thread.start() + query_thread.join(self.query_timeout_seconds) + if query_thread.is_alive(): + if cursor.connection.supports_kill_connection: + LOG.debug('Attempting to kill connection') + cursor.connection.kill_connection() + LOG.debug('Kill connection') + cursor.close() + cursor.connection.close() + cursor = cursor\ + .connection\ + .connector\ + .create_connection(db_name=cursor.connection.db_name)\ + .create_cursor() + if cursor.connection.db_type == IMPALA: + self.impala_cursor = cursor + else: + self.reference_cursor = cursor + raise QueryTimeout('Query timed out after %s seconds' % self.query_timeout_seconds) + if hasattr(query_thread, 'exception'): + raise query_thread.exception + + def standardize_data(self, data): + '''Return a val 
that is suitable for comparison.'''
+    # For float data we need to round, otherwise differences in precision will cause errors
+    if isinstance(data, (float, Decimal)):
+      return round(data, self.DECIMAL_PLACES)
+    return data
+
+  def row_sort_cmp(self, left_row, right_row):
+    for left, right in izip(left_row, right_row):
+      if left is None and right is not None:
+        return -1
+      if left is not None and right is None:
+        return 1
+      result = cmp(left, right)
+      if result:
+        return result
+    return 0
+
+  def vals_are_equal(self, left, right):
+    if left == right:
+      return True
+    if isinstance(left, (int, float, Decimal)) and \
+        isinstance(right, (int, float, Decimal)):
+      return self.floats_are_equal(left, right)
+    return False
+
+  def floats_are_equal(self, left, right):
+    left = round(left, self.DECIMAL_PLACES)
+    right = round(right, self.DECIMAL_PLACES)
+    diff = abs(left - right)
+    if left * right == 0:
+      return diff < self.EPSILON
+    return diff / (abs(left) + abs(right)) < self.EPSILON
+
+
+class ComparisonResult(object):
+
+  def __init__(self, query, reference_db_type):
+    self.query = query
+    self.reference_db_type = reference_db_type
+    self.query_resulted_in_data = False
+    self.impala_row_count = None
+    self.reference_row_count = None
+    self.mismatch_at_row_number = None
+    self.mismatch_at_col_number = None
+    self.impala_row = None
+    self.reference_row = None
+    self.exception = None
+    self._error_message = None
+
+  @property
+  def error(self):
+    if not self._error_message:
+      if self.exception:
+        self._error_message = str(self.exception)
+      elif self.impala_row_count and \
+          self.impala_row_count != self.reference_row_count:
+        self._error_message = 'Row counts do not match: %s Impala rows vs %s %s rows' \
+            % (self.impala_row_count,
+               self.reference_db_type,
+               self.reference_row_count)
+      elif self.mismatch_at_row_number is not None:
+        # Write a row like "[a, b, <<c>>, d]" where "c" is a bad value
+        impala_row = '[' + ', '.join(
+            '<<' + str(val) + '>>' if idx ==
self.mismatch_at_col_number - 1 else str(val) + for idx, val in enumerate(self.impala_row) + ) + ']' + reference_row = '[' + ', '.join( + '<<' + str(val) + '>>' if idx == self.mismatch_at_col_number - 1 else str(val) + for idx, val in enumerate(self.reference_row) + ) + ']' + self._error_message = \ + 'Column %s in row %s does not match: %s Impala row vs %s %s row' \ + % (self.mismatch_at_col_number, + self.mismatch_at_row_number, + impala_row, + reference_row, + self.reference_db_type) + return self._error_message + + @property + def is_known_error(self): + return isinstance(self.exception, KnownError) + + @property + def query_timed_out(self): + return isinstance(self.exception, QueryTimeout) + + +class QueryTimeout(Exception): + pass + + +class KnownError(Exception): + + def __init__(self, jira_url): + Exception.__init__(self, 'Known issue: ' + jira_url) + self.jira_url = jira_url + + +class QueryResultDiffSearcher(object): + + # Sometimes things get into a bad state and the same error loops forever + ABORT_ON_REPEAT_ERROR_COUNT = 2 + + def __init__(self, impala_connection, reference_connection, filter_col_types=[]): + self.impala_connection = impala_connection + self.reference_connection = reference_connection + self.common_tables = DbConnection.describe_common_tables( + [impala_connection, reference_connection], + filter_col_types=filter_col_types) + + # A file-backed dict of queries that produced a discrepancy, keyed by query number + # (in string form, as required by the dict). 
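As the comment above notes, `shelve` provides a persistent dict-like store whose keys must be strings, which is why the searcher saves each failing query under `str(query_count)`. A small sketch of the idiom (the file name is illustrative):

```python
from contextlib import closing
import os
import shelve
import tempfile

path = os.path.join(tempfile.gettempdir(), 'example_query.shelve')

# Write: shelve pickles each value to disk under a string key.
with closing(shelve.open(path)) as query_shelve:
    query_shelve['1'] = {'sql': 'SELECT 1'}

# Read back in a separate session, e.g. to re-run a failing query later.
with closing(shelve.open(path)) as query_shelve:
    recovered = query_shelve['1']

assert recovered == {'sql': 'SELECT 1'}
```

Wrapping the shelf in `closing()` guarantees the backing file is flushed even if an exception interrupts the run.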
+    self.query_shelve_path = gettempdir() + '/query.shelve'
+
+    # A list of all queries attempted
+    self.query_log_path = gettempdir() + '/impala_query_log.sql'
+
+  def search(self, number_of_test_queries, stop_on_result_mismatch, stop_on_crash):
+    if exists(self.query_shelve_path):
+      # Ensure a clean shelve will be created
+      remove(self.query_shelve_path)
+
+    start_time = time()
+    impala_sql_writer = SqlWriter.create(dialect=IMPALA)
+    reference_sql_writer = SqlWriter.create(
+        dialect=self.reference_connection.db_type)
+    query_result_comparator = QueryResultComparator(
+        self.impala_connection, self.reference_connection)
+    query_generator = QueryGenerator()
+    query_count = 0
+    queries_resulted_in_data_count = 0
+    mismatch_count = 0
+    query_timeout_count = 0
+    known_error_count = 0
+    impala_crash_count = 0
+    last_error = None
+    repeat_error_count = 0
+    with open(self.query_log_path, 'w') as impala_query_log:
+      impala_query_log.write(
+          '--\n'
+          '-- Starting new run\n'
+          '--\n')
+      while number_of_test_queries > query_count:
+        query = query_generator.create_query(self.common_tables)
+        impala_sql = impala_sql_writer.write_query(query)
+        if 'FULL OUTER JOIN' in impala_sql and self.reference_connection.db_type == MYSQL:
+          # Not supported by MySQL
+          continue
+
+        query_count += 1
+        LOG.info('Running query #%s', query_count)
+        impala_query_log.write(impala_sql + ';\n')
+        result = query_result_comparator.compare_query_results(query)
+        if result.query_resulted_in_data:
+          queries_resulted_in_data_count += 1
+        if result.error:
+          # TODO: These first two come from psycopg2, the postgres driver. Maybe we should
+          # try a different driver? Or maybe the usage of the driver isn't correct.
+          # Anyhow, ignore these failures.
+ if 'division by zero' in result.error \ + or 'out of range' in result.error \ + or 'Too much data' in result.error: + LOG.debug('Ignoring error: %s', result.error) + query_count -= 1 + continue + + if result.is_known_error: + known_error_count += 1 + elif result.query_timed_out: + query_timeout_count += 1 + else: + mismatch_count += 1 + with closing(open_shelve(self.query_shelve_path)) as query_shelve: + query_shelve[str(query_count)] = query + + print('---Impala Query---\n') + print(impala_sql_writer.write_query(query, pretty=True) + '\n') + print('---Reference Query---\n') + print(reference_sql_writer.write_query(query, pretty=True) + '\n') + print('---Error---\n') + print(result.error + '\n') + print('------\n') + + if 'Could not connect' in result.error \ + or "Couldn't open transport for" in result.error: + # if stop_on_crash: + # break + # Assume Impala crashed and try restarting + impala_crash_count += 1 + LOG.info('Restarting Impala') + call([join(getenv('IMPALA_HOME'), 'bin/start-impala-cluster.py'), + '--log_dir=%s' % getenv('LOG_DIR', "/tmp/")]) + self.impala_connection.reconnect() + query_result_comparator.impala_cursor = self.impala_connection.create_cursor() + result = query_result_comparator.compare_query_results(query) + if result.error: + LOG.info('Restarting Impala') + call([join(getenv('IMPALA_HOME'), 'bin/start-impala-cluster.py'), + '--log_dir=%s' % getenv('LOG_DIR', "/tmp/")]) + self.impala_connection.reconnect() + query_result_comparator.impala_cursor = self.impala_connection.create_cursor() + else: + break + + if stop_on_result_mismatch and \ + not (result.is_known_error or result.query_timed_out): + break + + if last_error == result.error \ + and not (result.is_known_error or result.query_timed_out): + repeat_error_count += 1 + if repeat_error_count == self.ABORT_ON_REPEAT_ERROR_COUNT: + break + else: + last_error = result.error + repeat_error_count = 0 + else: + if result.query_resulted_in_data: + LOG.info('Results matched (%s rows)', 
result.impala_row_count) + else: + LOG.info('Query did not produce meaningful data') + last_error = None + repeat_error_count = 0 + + return SearchResults( + query_count, + queries_resulted_in_data_count, + mismatch_count, + query_timeout_count, + known_error_count, + impala_crash_count, + time() - start_time) + + +class SearchResults(object): + '''This class holds information about the outcome of a search run.''' + + def __init__(self, + query_count, + queries_resulted_in_data_count, + mismatch_count, + query_timeout_count, + known_error_count, + impala_crash_count, + run_time_in_seconds): + # Approx number of queries run, some queries may have been ignored + self.query_count = query_count + self.queries_resulted_in_data_count = queries_resulted_in_data_count + # Number of queries that had an error or result mismatch + self.mismatch_count = mismatch_count + self.query_timeout_count = query_timeout_count + self.known_error_count = known_error_count + self.impala_crash_count = impala_crash_count + self.run_time_in_seconds = run_time_in_seconds + + def get_summary_text(self): + mins, secs = divmod(self.run_time_in_seconds, 60) + hours, mins = divmod(mins, 60) + hours = int(hours) + mins = int(mins) + if hours: + run_time = '%s hour and %s minutes' % (hours, mins) + else: + secs = int(secs) + run_time = '%s seconds' % secs + if mins: + run_time = '%s mins and ' % mins + run_time + summary_params = self.__dict__ + summary_params['run_time'] = run_time + return ( + '%(mismatch_count)s mismatches found after running %(query_count)s queries in ' + '%(run_time)s.\n' + '%(queries_resulted_in_data_count)s of %(query_count)s queries produced results.' 
+        '\n'
+        '%(impala_crash_count)s Impala crashes occurred.\n'
+        '%(known_error_count)s queries were excluded from the mismatch count because '
+        'they are known errors.\n'
+        '%(query_timeout_count)s queries timed out and were excluded from all counts.') \
+        % summary_params
+
+
+if __name__ == '__main__':
+  import sys
+  from optparse import NO_DEFAULT, OptionGroup, OptionParser
+
+  parser = OptionParser()
+  parser.add_option('--log-level', default='INFO',
+      help='The log level to use.', choices=('DEBUG', 'INFO', 'WARN', 'ERROR'))
+  parser.add_option('--db-name', default='randomness',
+      help='The name of the database to use. Ex: funcal.')
+
+  parser.add_option('--reference-db-type', default=MYSQL, choices=(MYSQL, POSTGRESQL),
+      help='The type of the reference database to use. Ex: MYSQL.')
+  parser.add_option('--stop-on-mismatch', default=False, action='store_true',
+      help='Exit immediately upon finding a discrepancy in a query result.')
+  parser.add_option('--stop-on-crash', default=False, action='store_true',
+      help='Exit immediately if Impala crashes.')
+  parser.add_option('--query-count', default=1000, type=int,
+      help='Exit after running the given number of queries.')
+  parser.add_option('--exclude-types', default='Double,Float,TinyInt',
+      help='A comma-separated list of data types to exclude while generating queries.')
+
+  group = OptionGroup(parser, 'MySQL Options')
+  group.add_option('--mysql-host', default='localhost',
+      help='The name of the host running the MySQL database.')
+  group.add_option('--mysql-port', default=3306, type=int,
+      help='The port of the host running the MySQL database.')
+  group.add_option('--mysql-user', default='root',
+      help='The user name to use when connecting to the MySQL database.')
+  group.add_option('--mysql-password',
+      help='The password to use when connecting to the MySQL database.')
+  parser.add_option_group(group)
+
+  group = OptionGroup(parser, 'Postgresql Options')
+  group.add_option('--postgresql-host', default='localhost',
+ help='The name of the host running the Postgresql database.') + group.add_option('--postgresql-port', default=5432, type=int, + help='The port of the host running the Postgresql database.') + group.add_option('--postgresql-user', default='postgres', + help='The user name to use when connecting to the Postgresql database.') + group.add_option('--postgresql-password', + help='The password to use when connecting to the Postgresql database.') + parser.add_option_group(group) + + for group in parser.option_groups + [parser]: + for option in group.option_list: + if option.default != NO_DEFAULT: + option.help += " [default: %default]" + + options, args = parser.parse_args() + + basicConfig(level=options.log_level) + + impala_connection = DbConnector(IMPALA).create_connection(options.db_name) + db_connector_param_key = options.reference_db_type.lower() + reference_connection = DbConnector(options.reference_db_type, + user_name=getattr(options, db_connector_param_key + '_user'), + password=getattr(options, db_connector_param_key + '_password'), + host_name=getattr(options, db_connector_param_key + '_host'), + port=getattr(options, db_connector_param_key + '_port')) \ + .create_connection(options.db_name) + if options.exclude_types: + exclude_types = set(type_name.lower() for type_name + in options.exclude_types.split(',')) + filter_col_types = [type_ for type_ in TYPES + if type_.__name__.lower() in exclude_types] + else: + filter_col_types = [] + diff_searcher = QueryResultDiffSearcher( + impala_connection, reference_connection, filter_col_types=filter_col_types) + search_results = diff_searcher.search( + options.query_count, options.stop_on_mismatch, options.stop_on_crash) + print(search_results.get_summary_text()) + sys.exit(search_results.mismatch_count) diff --git a/tests/comparison/model.py b/tests/comparison/model.py new file mode 100644 index 000000000..7249a2017 --- /dev/null +++ b/tests/comparison/model.py @@ -0,0 +1,741 @@ +# Copyright (c) 2014 Cloudera, Inc. 
All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+
+class Query(object):
+  '''A representation of the structure of a SQL query. Only the select_clause and
+     from_clause are required for a valid query.
+
+  '''
+
+  def __init__(self, select_clause, from_clause):
+    self.with_clause = None
+    self.select_clause = select_clause
+    self.from_clause = from_clause
+    self.where_clause = None
+    self.group_by_clause = None
+    self.having_clause = None
+    self.union_clause = None
+
+  @property
+  def table_exprs(self):
+    '''Provides a list of all table_exprs that are declared by this query. This
+       includes table_exprs in the WITH and FROM sections.
+    '''
+    table_exprs = self.from_clause.table_exprs
+    if self.with_clause:
+      table_exprs += self.with_clause.table_exprs
+    return table_exprs
+
+
+class SelectClause(object):
+  '''This encapsulates the SELECT part of a query. It is convenient to separate
+     non-agg items from agg items so that it is simple to know if the query
+     is an agg query or not.
+
+  '''
+
+  def __init__(self, non_agg_items=None, agg_items=None):
+    self.non_agg_items = non_agg_items or list()
+    self.agg_items = agg_items or list()
+    self.distinct = False
+
+  @property
+  def select_items(self):
+    '''Provides a consolidated view of all select items.'''
+    return self.non_agg_items + self.agg_items
+
+
+class SelectItem(object):
+  '''A representation of any possible expr that would be valid in
+
+       SELECT <val expr> [, ...]
FROM ...
+
+     Each SelectItem contains a ValExpr which will either be an instance of a
+     DataType (representing a constant), a Column, or a Func.
+
+     Ex: "SELECT int_col + smallint_col FROM alltypes" would have a val_expr of
+         Plus(Column(), Column()).
+
+  '''
+
+  def __init__(self, val_expr, alias=None):
+    self.val_expr = val_expr
+    self.alias = alias
+
+  @property
+  def type(self):
+    '''Returns the DataType of this item.'''
+    return self.val_expr.type
+
+  @property
+  def is_agg(self):
+    '''Evaluates to True if this item contains an aggregate expression.'''
+    return self.val_expr.is_agg
+
+
+class ValExpr(object):
+  '''This is an abstract class that represents a generic expr that results in a
+     scalar. The abc module was not used because it caused problems for the pickle
+     module.
+
+  '''
+
+  @property
+  def type(self):
+    '''This is declared for documentation purposes, subclasses should override this to
+       return the DataType that this expr represents.
+    '''
+    pass
+
+  @property
+  def base_type(self):
+    '''Return the most fundamental data type that the expr evaluates to. Only
+       numeric types will result in a different val than would be returned by self.type.
+
+     Ex:
+       if self.type == BigInt:
+         assert self.base_type == Int
+       if self.type == Double:
+         assert self.base_type == Float
+       if self.type == String:
+         assert self.base_type == self.type
+    '''
+    if self.returns_int:
+      return Int
+    if self.returns_float:
+      return Float
+    return self.type
+
+  @property
+  def is_func(self):
+    return isinstance(self, Func)
+
+  @property
+  def is_agg(self):
+    '''Evaluates to True if this expression contains an aggregate function.'''
+    if isinstance(self, AggFunc):
+      return True
+    if self.is_func:
+      for arg in self.args:
+        if arg.is_agg:
+          return True
+
+  @property
+  def is_col(self):
+    return isinstance(self, Column)
+
+  @property
+  def is_constant(self):
+    return isinstance(self, DataType)
+
+  @property
+  def returns_boolean(self):
+    return issubclass(self.type, Boolean)
+
+  @property
+  def returns_number(self):
+    return issubclass(self.type, Number)
+
+  @property
+  def returns_int(self):
+    return issubclass(self.type, Int)
+
+  @property
+  def returns_float(self):
+    return issubclass(self.type, Float)
+
+  @property
+  def returns_string(self):
+    return issubclass(self.type, String)
+
+  @property
+  def returns_timestamp(self):
+    return issubclass(self.type, Timestamp)
+
+
+class Column(ValExpr):
+  '''A representation of a col. All TableExprs will have Columns. So a Column
+     may belong to an InlineView as well as a standard Table.
+
+     This class is used in two ways:
+
+     1) As a piece of metadata in a table definition. In this usage the col isn't
+        intended to represent a val.
+
+     2) As an expr in a query, for example an item being selected or as part of
+        a join condition. In this usage the col is more like a val, which is why
+        it implements/extends ValExpr.
+
+  '''
+
+  def __init__(self, owner, name, type_):
+    self.owner = owner
+    self.name = name
+    self._type = type_
+
+  @property
+  def type(self):
+    return self._type
+
+  def __hash__(self):
+    return hash(self.name)
+
+  def __eq__(self, other):
+    if not isinstance(other, Column):
+      return False
+    if self is other:
+      return True
+    return self.name == other.name and self.owner.identifier == other.owner.identifier
+
+  def __repr__(self):
+    return '%s<name: %s, type: %s>' % (
+        type(self).__name__, self.name, self._type.__name__)
+
+
+class FromClause(object):
+  '''A representation of a FROM clause. The member variable join_clauses may optionally
+     contain JoinClause items.
+
+  '''
+
+  def __init__(self, table_expr, join_clauses=None):
+    self.table_expr = table_expr
+    self.join_clauses = join_clauses or list()
+
+  @property
+  def table_exprs(self):
+    '''Provides a list of all table_exprs that are declared within this FROM
+       block.
+    '''
+    table_exprs = [join_clause.table_expr for join_clause in self.join_clauses]
+    table_exprs.append(self.table_expr)
+    return table_exprs
+
+
+class TableExpr(object):
+  '''This is an abstract class that represents something that a query may use to select
+     from or join on. The abc module was not used because it caused problems for the
+     pickle module.
+
+  '''
+
+  def identifier(self):
+    '''Returns either a table name or alias if one has been declared.'''
+    pass
+
+  def cols(self):
+    pass
+
+  @property
+  def cols_by_base_type(self):
+    '''Group cols by their basic data type and return a dict of the results.
+
+       As an example, a "BigInt" would be considered as an "Int".
+ ''' + return DataType.group_by_base_type(self.cols) + + @property + def is_table(self): + return isinstance(self, Table) + + @property + def is_inline_view(self): + return isinstance(self, InlineView) + + @property + def is_with_clause_inline_view(self): + return isinstance(self, WithClauseInlineView) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + return self.identifier == other.identifier + + +class Table(TableExpr): + '''Represents a standard database table.''' + + def __init__(self, name): + self.name = name + self._cols = [] + self.alias = None + + @property + def identifier(self): + return self.alias or self.name + + @property + def cols(self): + return self._cols + + @cols.setter + def cols(self, cols): + self._cols = cols + + +class InlineView(TableExpr): + '''Represents an inline view. + + Ex: In the query "SELECT * FROM (SELECT * FROM foo) AS bar", + "(SELECT * FROM foo) AS bar" would be an inline view. + + ''' + + def __init__(self, query): + self.query = query + self.alias = None + + @property + def identifier(self): + return self.alias + + @property + def cols(self): + return [Column(self, item.alias, item.type) for item in + self.query.select_clause.non_agg_items + self.query.select_clause.agg_items] + + +class WithClause(object): + '''Represents a WITH clause. + + Ex: In the query "WITH bar AS (SELECT * FROM foo) SELECT * FROM bar", + "WITH bar AS (SELECT * FROM foo)" would be the with clause. + + ''' + + def __init__(self, with_clause_inline_views): + self.with_clause_inline_views = with_clause_inline_views + + @property + def table_exprs(self): + return self.with_clause_inline_views + + +class WithClauseInlineView(InlineView): + '''Represents the entries in a WITH clause. These are very similar to InlineViews but + may have an additional alias. + + Ex: WITH bar AS (SELECT * FROM foo) + SELECT * + FROM bar as r + JOIN (SELECT * FROM baz) AS z ON ... 
+
+     The WithClauseInlineView has aliases "bar" and "r" while the InlineView has
+     only the alias "z".
+
+  '''
+
+  def __init__(self, query, with_clause_alias):
+    self.query = query
+    self.with_clause_alias = with_clause_alias
+    self.alias = None
+
+  @property
+  def identifier(self):
+    return self.alias or self.with_clause_alias
+
+
+class JoinClause(object):
+  '''A representation of a JOIN clause.
+
+     Ex: SELECT * FROM foo <join type> JOIN <table expr> [ON <boolean expr>]
+
+     The member variable boolean_expr will be an instance of a boolean func
+     defined below.
+
+  '''
+
+  JOINS_TYPES = ['INNER', 'LEFT', 'RIGHT', 'FULL OUTER', 'CROSS']
+
+  def __init__(self, join_type, table_expr, boolean_expr=None):
+    self.join_type = join_type
+    self.table_expr = table_expr
+    self.boolean_expr = boolean_expr
+
+
+class WhereClause(object):
+  '''The member variable boolean_expr will be an instance of a boolean func
+     defined below.
+
+  '''
+
+  def __init__(self, boolean_expr):
+    self.boolean_expr = boolean_expr
+
+
+class GroupByClause(object):
+
+  def __init__(self, select_items):
+    self.group_by_items = select_items
+
+
+class HavingClause(object):
+  '''The member variable boolean_expr will be an instance of a boolean func
+     defined below.
+
+  '''
+
+  def __init__(self, boolean_expr):
+    self.boolean_expr = boolean_expr
+
+
+class UnionClause(object):
+  '''A representation of a UNION clause.
+
+     If the member variable "all" is True, the instance represents a "UNION ALL".
+
+  '''
+
+  def __init__(self, query):
+    self.query = query
+    self.all = False
+
+  @property
+  def queries(self):
+    queries = list()
+    query = self.query
+    while True:
+      queries.append(query)
+      if not query.union_clause:
+        break
+      query = query.union_clause.query
+    return queries
+
+
+class DataTypeMetaclass(type):
+  '''Provides sorting of classes used to determine upcasting.'''
+
+  def __cmp__(cls, other):
+    return cmp(
+        getattr(cls, 'CMP_VALUE', cls.__name__),
+        getattr(other, 'CMP_VALUE', other.__name__))
+
+
+class DataType(ValExpr):
+  '''Base class for data types.
+
+     Data types are represented as classes so inheritance can be used.
+
+  '''
+
+  __metaclass__ = DataTypeMetaclass
+
+  def __init__(self, val):
+    self.val = val
+
+  @property
+  def type(self):
+    return type(self)
+
+  @staticmethod
+  def group_by_base_type(vals):
+    '''Group cols by their basic data type and return a dict of the results.
+
+       As an example, a "BigInt" would be considered as an "Int".
+    '''
+    vals_by_type = defaultdict(list)
+    for val in vals:
+      type_ = val.type
+      if issubclass(type_, Int):
+        type_ = Int
+      elif issubclass(type_, Float):
+        type_ = Float
+      vals_by_type[type_].append(val)
+    return vals_by_type
+
+
+class Boolean(DataType):
+  pass
+
+
+class Number(DataType):
+  pass
+
+
+class Int(Number):
+
+  # Used to compare with other numbers for determining upcasting
+  CMP_VALUE = 2
+
+  # Used during data generation to keep vals in range
+  MIN = -2 ** 31
+  MAX = -MIN - 1
+
+  # Aliases used when reading and writing table definitions
+  POSTGRESQL = ['INTEGER']
+
+
+class TinyInt(Int):
+
+  CMP_VALUE = 0
+
+  MIN = -2 ** 7
+  MAX = -MIN - 1
+
+  POSTGRESQL = ['SMALLINT']
+
+
+class SmallInt(Int):
+
+  CMP_VALUE = 1
+
+  MIN = -2 ** 15
+  MAX = -MIN - 1
+
+
+class BigInt(Int):
+
+  CMP_VALUE = 3
+
+  MIN = -2 ** 63
+  MAX = -MIN - 1
+
+
+class Float(Number):
+
+  CMP_VALUE = 4
+
+  POSTGRESQL = ['REAL']
+
+
+class Double(Float):
+
+  CMP_VALUE = 5
+
+  MYSQL = ['DOUBLE', 'DECIMAL'] # Use double by
default but add decimal synonym + POSTGRESQL = ['DOUBLE PRECISION'] + + +class String(DataType): + + MIN = 0 + # The Impala limit is 32,767 but MySQL has a row size limit of 65,535. To allow 3+ + # String cols per table, the limit will be lowered to 1,000. That should be fine + # for testing anyhow. + MAX = 1000 + + MYSQL = ['VARCHAR(%s)' % MAX] + POSTGRESQL = MYSQL + ['CHARACTER VARYING'] + + +class Timestamp(DataType): + + MYSQL = ['DATETIME'] + POSTGRESQL = ['TIMESTAMP WITHOUT TIME ZONE'] + + +NUMBER_TYPES = [Int, TinyInt, SmallInt, BigInt, Float, Double] +TYPES = NUMBER_TYPES + [Boolean, String, Timestamp] + +class Func(ValExpr): + '''Base class for funcs''' + + def __init__(self, *args): + self.args = list(args) + + def __hash__(self): + return hash(type(self)) + hash(tuple(self.args)) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + if self is other: + return True + return self.args == other.args + + +class UnaryFunc(Func): + + def __init__(self, arg): + Func.__init__(self, arg) + + +class BinaryFunc(Func): + + def __init__(self, left, right): + Func.__init__(self, left, right) + + @property + def left(self): + return self.args[0] + + @left.setter + def left(self, left): + self.args[0] = left + + @property + def right(self): + return self.args[1] + + @right.setter + def right(self, right): + self.args[1] = right + + +class BooleanFunc(Func): + + @property + def type(self): + return Boolean + + +class IntFunc(Func): + + @property + def type(self): + return Int + + +class DoubleFunc(Func): + + @property + def type(self): + return Double + + +class StringFunc(Func): + + @property + def type(self): + return String + + +class UpcastingFunc(Func): + + @property + def type(self): + return max(arg.type for arg in self.args) + + +class AggFunc(Func): + + # Avoid having a self.distinct because it would need to be __init__'d explictly, + # which none of the AggFunc subclasses do (ex: Avg doesn't have it's + # own __init__). 
+ + @property + def distinct(self): + return getattr(self, '_distinct', False) + + @distinct.setter + def distinct(self, val): + return setattr(self, '_distinct', val) + +# The classes below diverge from above by including the SQL representation. It's a lot +# easier this way because there are a lot of funcs but they all have the same +# structure. Non-standard funcs, such as string concatenation, would need to have +# their representation information elsewhere (like classes above). + +Parentheses = type('Parentheses', (UnaryFunc, UpcastingFunc), {'FORMAT': '({})'}) + +IsNull = type('IsNull', (UnaryFunc, BooleanFunc), {'FORMAT': '{} IS NULL'}) +IsNotNull = type('IsNotNull', (UnaryFunc, BooleanFunc), {'FORMAT': '{} IS NOT NULL'}) +And = type('And', (BinaryFunc, BooleanFunc), {'FORMAT': '{} AND {}'}) +Or = type('Or', (BinaryFunc, BooleanFunc), {'FORMAT': '{} OR {}'}) +Equals = type('Equals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} = {}'}) +NotEquals = type('NotEquals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} != {}'}) +GreaterThan = type('GreaterThan', (BinaryFunc, BooleanFunc), {'FORMAT': '{} > {}'}) +LessThan = type('LessThan', (BinaryFunc, BooleanFunc), {'FORMAT': '{} < {}'}) +GreaterThanOrEquals = type( + 'GreaterThanOrEquals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} >= {}'}) +LessThanOrEquals = type( + 'LessThanOrEquals', (BinaryFunc, BooleanFunc), {'FORMAT': '{} <= {}'}) + +Plus = type('Plus', (BinaryFunc, UpcastingFunc), {'FORMAT': '{} + {}'}) +Minus = type('Minus', (BinaryFunc, UpcastingFunc), {'FORMAT': '{} - {}'}) +Multiply = type('Multiply', (BinaryFunc, UpcastingFunc), {'FORMAT': '{} * {}'}) +Divide = type('Divide', (BinaryFunc, DoubleFunc), {'FORMAT': '{} / {}'}) + +Floor = type('Floor', (UnaryFunc, IntFunc), {'FORMAT': 'FLOOR({})'}) + +Concat = type('Concat', (BinaryFunc, StringFunc), {'FORMAT': 'CONCAT({}, {})'}) +Length = type('Length', (UnaryFunc, IntFunc), {'FORMAT': 'LENGTH({})'}) + +ExtractYear = type( + 'ExtractYear', (UnaryFunc, IntFunc), 
+    {'FORMAT': "EXTRACT('YEAR' FROM {})"})
+
+# Formatting of agg funcs is a little trickier since they may have a distinct
+Avg = type('Avg', (UnaryFunc, DoubleFunc, AggFunc), {})
+Count = type('Count', (UnaryFunc, IntFunc, AggFunc), {})
+Max = type('Max', (UnaryFunc, UpcastingFunc, AggFunc), {})
+Min = type('Min', (UnaryFunc, UpcastingFunc, AggFunc), {})
+Sum = type('Sum', (UnaryFunc, UpcastingFunc, AggFunc), {})
+
+UNARY_BOOLEAN_FUNCS = [IsNull, IsNotNull]
+BINARY_BOOLEAN_FUNCS = [And, Or]
+RELATIONAL_OPERATORS = [
+    Equals, NotEquals, GreaterThan, LessThan, GreaterThanOrEquals, LessThanOrEquals]
+MATH_OPERATORS = [Plus, Minus, Multiply]  # Leaving out Divide
+BINARY_STRING_FUNCS = [Concat]
+AGG_FUNCS = [Avg, Count, Max, Min, Sum]
+
+
+class If(Func):
+
+  FORMAT = 'CASE WHEN {} THEN {} ELSE {} END'
+
+  def __init__(self, boolean_expr, consequent_expr, alternative_expr):
+    Func.__init__(
+        self, boolean_expr, consequent_expr, alternative_expr)
+
+  @property
+  def boolean_expr(self):
+    return self.args[0]
+
+  @property
+  def consequent_expr(self):
+    return self.args[1]
+
+  @property
+  def alternative_expr(self):
+    return self.args[2]
+
+  @property
+  def type(self):
+    # Compare the data types (ordered by DataTypeMetaclass), not the expr
+    # instances themselves.
+    return max(self.consequent_expr.type, self.alternative_expr.type)
+
+
+class Greatest(BinaryFunc, UpcastingFunc, If):
+
+  def __init__(self, left, right):
+    BinaryFunc.__init__(self, left, right)
+    If.__init__(self, GreaterThan(left, right), left, right)
+
+  @property
+  def type(self):
+    return UpcastingFunc.type.fget(self)
+
+
+class Cast(Func):
+
+  FORMAT = 'CAST({} AS {})'
+
+  def __init__(self, val_expr, resulting_type):
+    if resulting_type not in TYPES:
+      raise Exception('Unexpected type: {}'.format(resulting_type))
+    Func.__init__(self, val_expr, resulting_type)
+
+  @property
+  def val_expr(self):
+    return self.args[0]
+
+  @property
+  def resulting_type(self):
+    return self.args[1]
+
+  @property
+  def type(self):
+    return self.resulting_type
diff --git a/tests/comparison/model_translator.py
b/tests/comparison/model_translator.py new file mode 100644 index 000000000..0c9194ed4 --- /dev/null +++ b/tests/comparison/model_translator.py @@ -0,0 +1,367 @@ +# Copyright (c) 2014 Cloudera, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from inspect import getmro +from logging import getLogger +from re import sub +from sqlparse import format + +from tests.comparison.model import ( + Boolean, + Float, + Int, + Number, + Query, + String, + Timestamp) + +LOG = getLogger(__name__) + +class SqlWriter(object): + '''Subclasses of SQLWriter will take a Query and provide the SQL representation for a + specific database such as Impala or MySQL. The SqlWriter.create([dialect=]) + factory method may be used instead of specifying the concrete class. + + Another important function of this class is to ensure that CASTs produce the same + results across different databases. Sometimes the CASTs implemented here produce odd + results. For example, the result of CAST(date_col AS INT) in MySQL may be an int of + YYYYMMDD whereas in Impala it may be seconds since the epoch. For comparison purposes + the CAST could be transformed into EXTRACT(DAY from date_col). + ''' + + @staticmethod + def create(dialect='impala'): + '''Create and return a new SqlWriter appropriate for the given sql dialect. "dialect" + refers to database specific deviations of sql, and the val should be one of + "IMPALA", "MYSQL", or "POSTGRESQL". 
+ ''' + dialect = dialect.upper() + if dialect == 'IMPALA': + return SqlWriter() + if dialect == 'POSTGRESQL': + return PostgresqlSqlWriter() + if dialect == 'MYSQL': + return MySQLSqlWriter() + raise Exception('Unknown dialect: %s' % dialect) + + def write_query(self, query, pretty=False): + '''Return SQL as a string for the given query.''' + sql = list() + # Write out each section in the proper order + for clause in ( + query.with_clause, + query.select_clause, + query.from_clause, + query.where_clause, + query.group_by_clause, + query.having_clause, + query.union_clause): + if clause: + sql.append(self._write(clause)) + sql = '\n'.join(sql) + if pretty: + sql = self.make_pretty_sql(sql) + return sql + + def make_pretty_sql(self, sql): + try: + sql = format(sql, reindent=True) + except Exception as e: + LOG.warn('Unable to format sql: %s', e) + return sql + + def _write_with_clause(self, with_clause): + return 'WITH ' + ',\n'.join('%s AS (%s)' % (view.identifier, self._write(view.query)) + for view in with_clause.with_clause_inline_views) + + def _write_select_clause(self, select_clause): + items = select_clause.non_agg_items + select_clause.agg_items + sql = 'SELECT' + if select_clause.distinct: + sql += ' DISTINCT' + sql += '\n' + ',\n'.join(self._write(item) for item in items) + return sql + + def _write_select_item(self, select_item): + # If the query is nested, the items will have aliases so that the outer query can + # easily reference them. 
+    if not select_item.alias:
+      raise Exception('An alias is required')
+    return '%s AS %s' % (self._write(select_item.val_expr), select_item.alias)
+
+  def _write_column(self, col):
+    return '%s.%s' % (col.owner.identifier, col.name)
+
+  def _write_from_clause(self, from_clause):
+    sql = 'FROM %s' % self._write(from_clause.table_expr)
+    if from_clause.join_clauses:
+      sql += '\n' + '\n'.join(self._write(join) for join in from_clause.join_clauses)
+    return sql
+
+  def _write_table(self, table):
+    if table.alias:
+      return '%s AS %s' % (table.name, table.identifier)
+    return table.name
+
+  def _write_inline_view(self, inline_view):
+    if not inline_view.identifier:
+      raise Exception('An inline view requires an identifier')
+    return '(\n%s\n) AS %s' % (self._write(inline_view.query), inline_view.identifier)
+
+  def _write_with_clause_inline_view(self, with_clause_inline_view):
+    if not with_clause_inline_view.with_clause_alias:
+      raise Exception('A with clause entry requires an identifier')
+    sql = with_clause_inline_view.with_clause_alias
+    if with_clause_inline_view.alias:
+      sql += ' AS ' + with_clause_inline_view.alias
+    return sql
+
+  def _write_join_clause(self, join_clause):
+    sql = '%s JOIN %s' % (join_clause.join_type, self._write(join_clause.table_expr))
+    if join_clause.boolean_expr:
+      sql += ' ON ' + self._write(join_clause.boolean_expr)
+    return sql
+
+  def _write_where_clause(self, where_clause):
+    return 'WHERE\n' + self._write(where_clause.boolean_expr)
+
+  def _write_group_by_clause(self, group_by_clause):
+    return 'GROUP BY\n' + ',\n'.join(self._write(item.val_expr)
+        for item in group_by_clause.group_by_items)
+
+  def _write_having_clause(self, having_clause):
+    return 'HAVING\n' + self._write(having_clause.boolean_expr)
+
+  def _write_union_clause(self, union_clause):
+    sql = 'UNION'
+    if union_clause.all:
+      sql += ' ALL'
+    sql += '\n' + self._write(union_clause.query)
+    return sql
+
+  def _write_data_type(self, data_type):
+    '''Write a literal
value.''' + if data_type.returns_string: + return "'{}'".format(data_type.val) + if data_type.returns_timestamp: + return "CAST('{}' AS TIMESTAMP)".format(data_type.val) + return str(data_type.val) + + def _write_func(self, func): + return func.FORMAT.format(*[self._write(arg) for arg in func.args]) + + def _write_cast(self, cast): + # Handle casts that produce different results across database types or just don't + # make sense like casting a DATE as a BOOLEAN.... + if cast.val_expr.returns_boolean: + if issubclass(cast.resulting_type, Timestamp): + return "CAST(CASE WHEN {} THEN '2000-01-01' ELSE '1999-01-01' END AS TIMESTAMP)"\ + .format(self._write(cast.val_expr)) + elif cast.val_expr.returns_number: + if issubclass(cast.resulting_type, Timestamp): + return ("CAST(CONCAT('2000-01-', " + "LPAD(CAST(ABS(FLOOR({})) % 31 + 1 AS STRING), 2, '0')) " + "AS TIMESTAMP)").format(self._write(cast.val_expr)) + elif cast.val_expr.returns_string: + if issubclass(cast.resulting_type, Boolean): + return "(LENGTH({}) > 2)".format(self._write(cast.val_expr)) + if issubclass(cast.resulting_type, Timestamp): + return ("CAST(CONCAT('2000-01-', LPAD(CAST(LENGTH({}) % 31 + 1 AS STRING), " + "2, '0')) AS TIMESTAMP)").format(self._write(cast.val_expr)) + elif cast.val_expr.returns_timestamp: + if issubclass(cast.resulting_type, Boolean): + return '(DAY({0}) > MONTH({0}))'.format(self._write(cast.val_expr)) + if issubclass(cast.resulting_type, Number): + return ('(DAY({0}) + 100 * MONTH({0}) + 100 * 100 * YEAR({0}))').format( + self._write(cast.val_expr)) + return self._write_func(cast) + + def _write_agg_func(self, agg_func): + sql = type(agg_func).__name__.upper() + '(' + if agg_func.distinct: + sql += 'DISTINCT ' + # All agg funcs only have a single arg + sql += self._write(agg_func.args[0]) + ')' + return sql + + def _write_data_type_metaclass(self, data_type_class): + '''Write a data type class such as Int or Boolean.''' + return data_type_class.__name__.upper() + + def 
_write(self, object_): + '''Return a sql string representation of the given object.''' + # What's below is effectively a giant switch statement. It works based on a func + # naming and signature convention. It should match the incoming object with the + # corresponding func defined, then call the func and return the result. + # + # Ex: + # a = model.And(...) + # _write(a) should call _write_func(a) because "And" is a subclass of "Func" and no + # other _writer_ methods have been defined higher up the method + # resolution order (MRO). If _write_and(...) were to be defined, it would be called + # instead. + for type_ in getmro(type(object_)): + writer_func_name = '_write' + sub('([A-Z])', r'_\1', type_.__name__).lower() + writer_func = getattr(self, writer_func_name, None) + if writer_func: + return writer_func(object_) + + # Handle any remaining cases + if isinstance(object_, Query): + return self.write_query(object_) + + raise Exception('Unsupported object: %s<%s>' % (type(object_).__name__, object_)) + + +class PostgresqlSqlWriter(SqlWriter): + # TODO: This class is out of date since switching to MySQL. This is left here as is + # in case there is a desire to switch back in the future (it should be better than + # starting from nothing). + + def _write_divide(self, divide): + # For ints, Postgresql does int division but Impala does float division. + return 'CAST({} AS REAL) / {}' \ + .format(*[self._write(arg) for arg in divide.args]) + + def _write_data_type_metaclass(self, data_type_class): + '''Write a data type class such as Int or Boolean.''' + if hasattr(data_type_class, 'POSTGRESQL'): + return data_type_class.POSTGRESQL[0] + return data_type_class.__name__.upper() + + def _write_cast(self, cast): + # Handle casts that produce different results across database types or just don't + # make sense like casting a DATE as a BOOLEAN.... 
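The naming-convention dispatch in `_write` can be sketched in isolation. The toy classes and `ToyWriter` below are hypothetical stand-ins (not part of the model), but the MRO walk and the CamelCase-to-snake_case conversion follow the same convention:

```python
import re

# Toy stand-ins for the model's expression classes; names are invented.
class Func(object):
  pass

class BinaryFunc(Func):
  pass

class And(BinaryFunc):
  pass

class ToyWriter(object):
  def write(self, obj):
    # Same convention as SqlWriter._write: walk the object's MRO, convert
    # each class name from CamelCase to snake_case, and call the first
    # matching "_write_*" method found.
    for type_ in type(obj).__mro__:
      name = '_write' + re.sub('([A-Z])', r'_\1', type_.__name__).lower()
      method = getattr(self, name, None)
      if method:
        return method(obj)
    raise Exception('Unsupported object: %s' % type(obj).__name__)

  def _write_binary_func(self, func):
    return 'handled by _write_binary_func'

  def _write_func(self, func):
    return 'handled by _write_func'
```

Because `And` defines no `_write_and` handler, the lookup falls through to `_write_binary_func`, the first match in `And.__mro__`; a plain `Func` falls all the way to `_write_func`.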
+ if cast.val_expr.returns_boolean: + if issubclass(cast.resulting_type, Float): + return "CASE {} WHEN TRUE THEN 1.0 WHEN FALSE THEN 0.0 END".format( + self._write(cast.val_expr)) + if issubclass(cast.resulting_type, Timestamp): + return "CASE WHEN {} THEN '2000-01-01' ELSE '1999-01-01' END".format( + self._write(cast.val_expr)) + if issubclass(cast.resulting_type, String): + return "CASE {} WHEN TRUE THEN '1' WHEN FALSE THEN '0' END".format( + self._write(cast.val_expr)) + elif cast.val_expr.returns_number: + if issubclass(cast.resulting_type, Boolean): + return 'CASE WHEN ({0}) != 0 THEN TRUE WHEN ({0}) = 0 THEN FALSE END'.format( + self._write(cast.val_expr)) + if issubclass(cast.resulting_type, Timestamp): + return "CASE WHEN ({}) > 0 THEN '2000-01-01' ELSE '1999-01-01' END".format( + self._write(cast.val_expr)) + elif cast.val_expr.returns_string: + if issubclass(cast.resulting_type, Boolean): + return "(LENGTH({}) > 2)".format(self._write(cast.val_expr)) + elif cast.val_expr.returns_timestamp: + if issubclass(cast.resulting_type, Boolean): + return '(EXTRACT(DAY FROM {0}) > EXTRACT(MONTH FROM {0}))'.format( + self._write(cast.val_expr)) + if issubclass(cast.resulting_type, Number): + return ('(EXTRACT(DAY FROM {0}) ' + '+ 100 * EXTRACT(MONTH FROM {0}) ' + '+ 100 * 100 * EXTRACT(YEAR FROM {0}))').format( + self._write(cast.val_expr)) + return self._write_func(cast) + + +class MySQLSqlWriter(SqlWriter): + + def write_query(self, query, pretty=False): + # MySQL doesn't support WITH clauses so they need to be converted into inline views. + # We are going to cheat by making use of the fact that the query generator creates + # with clause entries with unique aliases even considering nested queries. 
+    sql = list()
+    for clause in (
+        query.select_clause,
+        query.from_clause,
+        query.where_clause,
+        query.group_by_clause,
+        query.having_clause,
+        query.union_clause):
+      if clause:
+        sql.append(self._write(clause))
+    sql = '\n'.join(sql)
+    if query.with_clause:
+      # Just replace the named references with inline views. Go in reverse order because
+      # entries at the bottom of the WITH clause definition may reference entries above.
+      for with_clause_inline_view in reversed(query.with_clause.with_clause_inline_views):
+        replacement_sql = '(' + self.write_query(with_clause_inline_view.query) + ')'
+        sql = sql.replace(with_clause_inline_view.identifier, replacement_sql)
+    if pretty:
+      sql = self.make_pretty_sql(sql)
+    return sql
+
+  def _write_data_type_metaclass(self, data_type_class):
+    '''Write a data type class such as Int or Boolean.'''
+    if issubclass(data_type_class, Int):
+      return 'INTEGER'
+    if issubclass(data_type_class, Float):
+      return 'DECIMAL(65, 15)'
+    if issubclass(data_type_class, String):
+      return 'CHAR'
+    if hasattr(data_type_class, 'MYSQL'):
+      return data_type_class.MYSQL[0]
+    return data_type_class.__name__.upper()
+
+  def _write_data_type(self, data_type):
+    '''Write a literal value.'''
+    if data_type.returns_timestamp:
+      return "CAST('{}' AS DATETIME)".format(data_type.val)
+    if data_type.returns_boolean:
+      # MySQL will error if a data_type "FALSE" is used as a GROUP BY field
+      return '(0 = 0)' if data_type.val else '(1 = 0)'
+    return SqlWriter._write_data_type(self, data_type)
+
+  def _write_cast(self, cast):
+    # Handle casts that produce different results across database types or just don't
+    # make sense like casting a DATE as a BOOLEAN....
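The WITH-clause rewrite used for MySQL is plain textual substitution, which is only safe because the generator gives every WITH entry a unique alias. A minimal self-contained sketch of the idea, with invented aliases and queries:

```python
# WITH entries in definition order; aliases follow the generator's
# 'with_<n>_<random>' convention. Entries defined later may reference
# earlier ones, so substitution runs in reverse definition order, as in
# MySQLSqlWriter.write_query.
with_views = [
  ('with_1_42', 'SELECT a FROM base_table'),
  ('with_2_77', 'SELECT a FROM with_1_42'),
]
sql = 'SELECT a FROM with_2_77'
for alias, subquery in reversed(with_views):
  sql = sql.replace(alias, '(%s)' % subquery)
# sql is now 'SELECT a FROM (SELECT a FROM (SELECT a FROM base_table))'
```

Substituting `with_2_77` first leaves a dangling `with_1_42` reference in the result, which the next (earlier-defined) entry then expands; a forward pass would leave it unresolved.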
+ if cast.val_expr.returns_boolean: + if issubclass(cast.resulting_type, Timestamp): + return "CAST(CASE WHEN {} THEN '2000-01-01' ELSE '1999-01-01' END AS DATETIME)"\ + .format(self._write(cast.val_expr)) + elif cast.val_expr.returns_number: + if issubclass(cast.resulting_type, Boolean): + return ("CASE WHEN ({0}) != 0 THEN TRUE WHEN ({0}) = 0 THEN FALSE END").format( + self._write(cast.val_expr)) + if issubclass(cast.resulting_type, Timestamp): + return "CAST(CONCAT('2000-01-', ABS(FLOOR({})) % 31 + 1) AS DATETIME)"\ + .format(self._write(cast.val_expr)) + elif cast.val_expr.returns_string: + if issubclass(cast.resulting_type, Boolean): + return "(LENGTH({}) > 2)".format(self._write(cast.val_expr)) + if issubclass(cast.resulting_type, Timestamp): + return ("CAST(CONCAT('2000-01-', LENGTH({}) % 31 + 1) AS DATETIME)").format( + self._write(cast.val_expr)) + elif cast.val_expr.returns_timestamp: + if issubclass(cast.resulting_type, Number): + return ('(EXTRACT(DAY FROM {0}) ' + '+ 100 * EXTRACT(MONTH FROM {0}) ' + '+ 100 * 100 * EXTRACT(YEAR FROM {0}))').format( + self._write(cast.val_expr)) + if issubclass(cast.resulting_type, Boolean): + return '(EXTRACT(DAY FROM {0}) > EXTRACT(MONTH FROM {0}))'.format( + self._write(cast.val_expr)) + + # MySQL uses different type names when casting... + if issubclass(cast.resulting_type, Boolean): + data_type = 'UNSIGNED' + elif issubclass(cast.resulting_type, Float): + data_type = 'DECIMAL(65, 15)' + elif issubclass(cast.resulting_type, Int): + data_type = 'SIGNED' + elif issubclass(cast.resulting_type, String): + data_type = 'CHAR' + elif issubclass(cast.resulting_type, Timestamp): + data_type = 'DATETIME' + return cast.FORMAT.format(self._write(cast.val_expr), data_type) diff --git a/tests/comparison/query_generator.py b/tests/comparison/query_generator.py new file mode 100644 index 000000000..9d238a419 --- /dev/null +++ b/tests/comparison/query_generator.py @@ -0,0 +1,556 @@ +# Copyright (c) 2014 Cloudera, Inc. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from copy import deepcopy +from itertools import chain +from random import choice, randint, shuffle + +from tests.comparison.model import ( + AGG_FUNCS, + AggFunc, + And, + BINARY_STRING_FUNCS, + BigInt, + Boolean, + Cast, + Column, + Count, + DataType, + Double, + Equals, + Float, + Floor, + FromClause, + Func, + Greatest, + GroupByClause, + HavingClause, + InlineView, + Int, + JoinClause, + Length, + MATH_OPERATORS, + Number, + Query, + RELATIONAL_OPERATORS, + SelectClause, + SelectItem, + String, + Table, + Timestamp, + TYPES, + UNARY_BOOLEAN_FUNCS, + UnionClause, + WhereClause, + WithClause, + WithClauseInlineView) + +def random_boolean(): + '''Return a val that evaluates to True 50% of the time''' + return randint(0, 1) + + +def zero_or_more(): + '''The chance of the return val of n is 1 / 2 ^ (n + 1)''' + val = 0 + while random_boolean(): + val += 1 + return val + + +def one_or_more(): + return zero_or_more() + 1 + + +def random_non_empty_split(iterable): + '''Return two non-empty lists''' + if len(iterable) < 2: + raise Exception('The iterable must contain at least two items') + split_index = randint(1, len(iterable) - 1) + left, right = list(), list() + for idx, item in enumerate(iterable): + if idx < split_index: + left.append(item) + else: + right.append(item) + return left, right + + +class QueryGenerator(object): + + def create_query(self, + 
table_exprs, + allow_with_clause=True, + select_item_data_types=None): + '''Create a random query using various language features. + + The initial call to this method should only use tables in the table_exprs + parameter, and not inline views or "with" definitions. The other types of + table exprs may be added as part of the query generation. + + If select_item_data_types is specified it must be a sequence or iterable of + DataType. The generated query.select_clause.select_items will have data + types suitable for use in a UNION. + + ''' + # Make a copy so tables can be added if a "with" clause is used + table_exprs = list(table_exprs) + + with_clause = None + if allow_with_clause and randint(1, 10) == 1: + with_clause = self._create_with_clause(table_exprs) + table_exprs.extend(with_clause.table_exprs) + + from_clause = self._create_from_clause(table_exprs) + + select_clause = self._create_select_clause( + from_clause.table_exprs, + select_item_data_types=select_item_data_types) + + query = Query(select_clause, from_clause) + + if with_clause: + query.with_clause = with_clause + + if random_boolean(): + query.where_clause = self._create_where_clause(from_clause.table_exprs) + + if select_clause.agg_items and select_clause.non_agg_items: + query.group_by_clause = GroupByClause(list(select_clause.non_agg_items)) + + if randint(1, 10) == 1: + if select_clause.agg_items: + self._enable_distinct_on_random_agg_items(select_clause.agg_items) + else: + select_clause.distinct = True + + if random_boolean() and (query.group_by_clause or select_clause.agg_items): + query.having_clause = self._create_having_clause(from_clause.table_exprs) + + if randint(1, 10) == 1: + select_item_data_types = list() + for select_item in select_clause.select_items: + # For numbers, choose the largest possible data type in case a CAST is needed. 
+        if select_item.val_expr.returns_float:
+          select_item_data_types.append(Double)
+        elif select_item.val_expr.returns_int:
+          select_item_data_types.append(BigInt)
+        else:
+          select_item_data_types.append(select_item.val_expr.type)
+      query.union_clause = UnionClause(self.create_query(
+          table_exprs,
+          allow_with_clause=False,
+          select_item_data_types=select_item_data_types))
+      query.union_clause.all = random_boolean()
+
+    return query
+
+  def _create_with_clause(self, table_exprs):
+    # Make a copy so newly created tables can be added and made available for use in
+    # future table definitions.
+    table_exprs = list(table_exprs)
+    with_clause_inline_views = list()
+    for with_clause_inline_view_idx in xrange(one_or_more()):
+      query = self.create_query(table_exprs)
+      # To help prevent nested WITH clauses from having entries with the same alias,
+      # choose a random alias. Of course it would be much better to know which aliases
+      # were already chosen but that information isn't easy to get from here.
+ with_clause_alias = 'with_%s_%s' % \ + (with_clause_inline_view_idx + 1, randint(1, 1000)) + with_clause_inline_view = WithClauseInlineView(query, with_clause_alias) + table_exprs.append(with_clause_inline_view) + with_clause_inline_views.append(with_clause_inline_view) + return WithClause(with_clause_inline_views) + + def _create_select_clause(self, table_exprs, select_item_data_types=None): + while True: + non_agg_items = [self._create_non_agg_select_item(table_exprs) + for _ in xrange(zero_or_more())] + agg_items = [self._create_agg_select_item(table_exprs) + for _ in xrange(zero_or_more())] + if non_agg_items or agg_items: + if select_item_data_types: + if len(select_item_data_types) > len(non_agg_items) + len(agg_items): + # Not enough items generated, try again + continue + while len(select_item_data_types) < len(non_agg_items) + len(agg_items): + items = choice([non_agg_items, agg_items]) + if items: + items.pop() + for data_type_idx, data_type in enumerate(select_item_data_types): + if data_type_idx < len(non_agg_items): + item = non_agg_items[data_type_idx] + else: + item = agg_items[data_type_idx - len(non_agg_items)] + if not issubclass(item.type, data_type): + item.val_expr = self.convert_val_expr_to_type(item.val_expr, data_type) + for idx, item in enumerate(chain(non_agg_items, agg_items)): + item.alias = '%s_col_%s' % (item.type.__name__.lower(), idx + 1) + return SelectClause(non_agg_items=non_agg_items, agg_items=agg_items) + + def _choose_col(self, table_exprs): + table_expr = choice(table_exprs) + return choice(table_expr.cols) + + def _create_non_agg_select_item(self, table_exprs): + return SelectItem(self._create_val_expr(table_exprs)) + + def _create_val_expr(self, table_exprs): + vals = [self._choose_col(table_exprs) for _ in xrange(one_or_more())] + return self._combine_val_exprs(vals) + + def _create_agg_select_item(self, table_exprs): + vals = [self._create_agg_val_expr(table_exprs) for _ in xrange(one_or_more())] + return 
SelectItem(self._combine_val_exprs(vals)) + + def _create_agg_val_expr(self, table_exprs): + val = self._create_val_expr(table_exprs) + if issubclass(val.type, Number): + funcs = list(AGG_FUNCS) + else: + funcs = [Count] + return choice(funcs)(val) + + def _create_from_clause(self, table_exprs): + table_expr = self._create_table_expr(table_exprs) + table_expr_count = 1 + table_expr.alias = 't%s' % table_expr_count + from_clause = FromClause(table_expr) + for join_idx in xrange(zero_or_more()): + join_clause = self._create_join_clause(from_clause, table_exprs) + table_expr_count += 1 + join_clause.table_expr.alias = 't%s' % table_expr_count + from_clause.join_clauses.append(join_clause) + return from_clause + + def _create_table_expr(self, table_exprs): + if randint(1, 10) == 1: + return self._create_inline_view(table_exprs) + return self._choose_table(table_exprs) + + def _choose_table(self, table_exprs): + return deepcopy(choice(table_exprs)) + + def _create_inline_view(self, table_exprs): + return InlineView(self.create_query(table_exprs)) + + def _create_join_clause(self, from_clause, table_exprs): + table_expr = self._create_table_expr(table_exprs) + # Increase the chance of using the first join type which is INNER + join_type_idx = (zero_or_more() / 2) % len(JoinClause.JOINS_TYPES) + join_type = JoinClause.JOINS_TYPES[join_type_idx] + join_clause = JoinClause(join_type, table_expr) + + # Prefer non-boolean cols for the first condition. Boolean cols produce too + # many results so it's unlikely that someone would want to join tables only using + # boolean cols. 
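The join-type weighting in `join_type_idx = (zero_or_more() / 2) % len(JOINS_TYPES)` can be checked numerically from the documented distribution of `zero_or_more()` (the chance of n is 1 / 2^(n + 1)). The helper below is a made-up illustration, not part of the generator:

```python
# Probability that the join-type index comes out as `idx`, given that
# zero_or_more() returns n with probability 1 / 2 ** (n + 1) and the
# generator maps n to (n // 2) % num_types. Truncating the sum at n = 60
# loses only a ~2**-60 tail.
def p_join_index(idx, num_types=5, limit=60):
  return sum(1.0 / 2 ** (n + 1)
             for n in range(limit)
             if (n // 2) % num_types == idx)
```

n = 0 and n = 1 both map to index 0, so INNER (the first entry of `JOINS_TYPES`) gets roughly a 75% share, and the remaining join types fall off geometrically.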
+ non_boolean_types = set(type_ for type_ in TYPES if not issubclass(type_, Boolean)) + + if join_type != 'CROSS': + join_clause.boolean_expr = self._combine_val_exprs( + [self._create_relational_join_condition( + table_expr, + choice(from_clause.table_exprs), + prefered_data_types=(non_boolean_types if idx == 0 else set())) + for idx in xrange(one_or_more())], + resulting_type=Boolean) + return join_clause + + def _create_relational_join_condition(self, + left_table_expr, + right_table_expr, + prefered_data_types): + # "base type" means condense all int types into just int, same for floats + left_cols_by_base_type = left_table_expr.cols_by_base_type + right_cols_by_base_type = right_table_expr.cols_by_base_type + common_col_types = set(left_cols_by_base_type) & set(right_cols_by_base_type) + if prefered_data_types: + common_col_types &= prefered_data_types + if common_col_types: + col_type = choice(list(common_col_types)) + left = choice(left_cols_by_base_type[col_type]) + right = choice(right_cols_by_base_type[col_type]) + else: + col_type = None + if prefered_data_types: + for available_col_types in (left_cols_by_base_type, right_cols_by_base_type): + prefered_available_col_types = set(available_col_types) & prefered_data_types + if prefered_available_col_types: + col_type = choice(list(prefered_available_col_types)) + break + if not col_type: + col_type = choice(left_cols_by_base_type.keys()) + + if col_type in left_cols_by_base_type: + left = choice(left_cols_by_base_type[col_type]) + else: + left = choice(choice(left_cols_by_base_type.values())) + left = self.convert_val_expr_to_type(left, col_type) + if col_type in right_cols_by_base_type: + right = choice(right_cols_by_base_type[col_type]) + else: + right = choice(choice(right_cols_by_base_type.values())) + right = self.convert_val_expr_to_type(right, col_type) + return Equals(left, right) + + def _create_where_clause(self, table_exprs): + boolean_exprs = list() + # Create one boolean expr per iteration... 
+    for _ in xrange(one_or_more()):
+      col_type = None
+      cols = list()
+      # ...using one or more cols...
+      for _ in xrange(one_or_more()):
+        # ...from any random table, inline view, etc.
+        table_expr = choice(table_exprs)
+        if not col_type:
+          col_type = choice(list(table_expr.cols_by_base_type))
+        if col_type in table_expr.cols_by_base_type:
+          col = choice(table_expr.cols_by_base_type[col_type])
+        else:
+          col = choice(table_expr.cols)
+        cols.append(col)
+      boolean_exprs.append(self._combine_val_exprs(cols, resulting_type=Boolean))
+    return WhereClause(self._combine_val_exprs(boolean_exprs))
+
+  def _combine_val_exprs(self, vals, resulting_type=None):
+    '''Combine the given vals into a single val.
+
+       If resulting_type is specified, the returned val will be of that type. If
+       the resulting data type was not specified, it will be randomly chosen from the
+       types of the input vals.
+
+    '''
+    if not vals:
+      raise Exception('At least one val is required')
+
+    types_to_vals = DataType.group_by_base_type(vals)
+
+    if not resulting_type:
+      resulting_type = choice(types_to_vals.keys())
+
+    vals_of_resulting_type = list()
+
+    for val_type, vals in types_to_vals.iteritems():
+      if issubclass(val_type, resulting_type):
+        vals_of_resulting_type.extend(vals)
+      elif resulting_type == Boolean:
+        # To produce other result types, the vals will be aggregated into a single val
+        # then converted into the desired type. However to make a boolean, relational
+        # operators can be used on the vals to make a more realistic query.
+        val = self._create_boolean_expr_from_vals_of_same_type(vals)
+        vals_of_resulting_type.append(val)
+      else:
+        val = self._combine_vals_of_same_type(vals)
+        if not (issubclass(val.type, Number) and issubclass(resulting_type, Number)):
+          val = self.convert_val_expr_to_type(val, resulting_type)
+        vals_of_resulting_type.append(val)
+
+    return self._combine_vals_of_same_type(vals_of_resulting_type)
+
+  def _create_boolean_expr_from_vals_of_same_type(self, vals):
+    '''Create a boolean val by comparing or combining the given vals. The vals
+    must all be of the same base type (numbers of any type may be mixed).
+
+    '''
+    if not vals:
+      raise Exception('At least one val is required')
+
+    if len(vals) == 1:
+      val = vals[0]
+      if Boolean == val.type:
+        return val
+      # Convert a single non-boolean val into a boolean using a func like
+      # IsNull or IsNotNull.
+      return choice(UNARY_BOOLEAN_FUNCS)(val)
+    if len(vals) == 2:
+      left, right = vals
+      if left.type == right.type:
+        if left.type == String:
+          # Databases may vary in how string comparisons are done. Results may
+          # differ when using operators like > or <, so just always use =.
+          return Equals(left, right)
+        if left.type == Boolean:
+          # TODO: Enable "OR" at some frequency; using OR 50% of the time will
+          #       probably produce too many slow queries.
+          return And(left, right)
+        # At this point we've got two vals of the same type, so any relational
+        # operator will produce a boolean.
+        return choice(RELATIONAL_OPERATORS)(left, right)
+      elif issubclass(left.type, Number) and issubclass(right.type, Number):
+        # Numbers need not be of the same type. SmallInt, BigInt, etc. can all be
+        # compared.
+        # Note: For now ints are the only numbers enabled and division is disabled
+        # though AVG() is in use. If floats are enabled this will likely need to be
+        # updated to do some rounding-based comparison.
+        return choice(RELATIONAL_OPERATORS)(left, right)
+      raise Exception('Vals are not of the same type: %s<%s> vs %s<%s>'
+                      % (left, left.type, right, right.type))
+    # Reduce the number of inputs and try again...
+    left_subset, right_subset = random_non_empty_split(vals)
+    return self._create_boolean_expr_from_vals_of_same_type([
+        self._combine_vals_of_same_type(left_subset),
+        self._combine_vals_of_same_type(right_subset)])
+
+  def _combine_vals_of_same_type(self, vals):
+    '''Combine the given vals into a single expr of the same type. The input
+    vals must be of the same base data type. For example Ints must not be mixed
+    with Strings.
+
+    '''
+    if not vals:
+      raise Exception('At least one val is required')
+
+    val_type = None
+    for val in vals:
+      if not val_type:
+        if issubclass(val.type, Number):
+          val_type = Number
+        else:
+          val_type = val.type
+      elif not issubclass(val.type, val_type):
+        raise Exception('Incompatible types %s and %s' % (val_type, val.type))
+
+    if len(vals) == 1:
+      return vals[0]
+
+    if val_type == Number:
+      funcs = MATH_OPERATORS
+    elif val_type == Boolean:
+      # TODO: Enable "OR" at some frequency
+      funcs = [And]
+    elif val_type == String:
+      # TODO: String combination is disabled for now; just use the first val.
+      #       Switch to BINARY_STRING_FUNCS once enabled.
+      return vals[0]
+    elif val_type == Timestamp:
+      funcs = [Greatest]
+    else:
+      raise Exception('Unexpected val type: %s' % val_type)
+
+    vals = list(vals)
+    shuffle(vals)
+    left = vals.pop()
+    right = vals.pop()
+    while True:
+      func = choice(funcs)
+      left = func(left, right)
+      if not vals:
+        return left
+      right = vals.pop()
+
+  def convert_val_expr_to_type(self, val_expr, resulting_type):
+    if resulting_type not in TYPES:
+      raise Exception('Unexpected type: {}'.format(resulting_type))
+    val_type = val_expr.type
+    if issubclass(val_type, resulting_type):
+      return val_expr
+
+    if issubclass(resulting_type, Int):
+      if val_expr.returns_float:
+        # Impala will FLOOR while Postgresql will ROUND. Use FLOOR to be consistent.
+        return Floor(val_expr)
+    if issubclass(resulting_type, Number):
+      if val_expr.returns_string:
+        return Length(val_expr)
+    if issubclass(resulting_type, String):
+      if val_expr.returns_float:
+        # Different databases may use different precision.
+        return Cast(Floor(val_expr), resulting_type)
+
+    return Cast(val_expr, resulting_type)
+
+  def _create_having_clause(self, table_exprs):
+    boolean_exprs = list()
+    # Create one boolean expr per iteration...
+    for _ in xrange(one_or_more()):
+      agg_items = list()
+      # ...using one or more agg exprs...
+      for _ in xrange(one_or_more()):
+        vals = [self._create_agg_val_expr(table_exprs) for _ in xrange(one_or_more())]
+        agg_items.append(self._combine_val_exprs(vals))
+      boolean_exprs.append(self._combine_val_exprs(agg_items, resulting_type=Boolean))
+    return HavingClause(self._combine_val_exprs(boolean_exprs))
+
+  def _enable_distinct_on_random_agg_items(self, agg_items):
+    '''Randomly choose an inner expr and enable DISTINCT on some of the agg
+    funcs that use it.
+    '''
+    # Impala has a limitation where 'DISTINCT' may only be applied to one agg
+    # expr. If an agg expr is used more than once, each usage may
+    # or may not include DISTINCT.
+    #
+    # Examples:
+    #   OK: SELECT COUNT(DISTINCT a) + SUM(DISTINCT a) + MAX(a)...
+    #   Not OK: SELECT COUNT(DISTINCT a) + COUNT(DISTINCT b)...
+    #
+    # Given a select list like:
+    #   COUNT(a), SUM(a), MAX(b)
+    #
+    # We want to output one of:
+    #   COUNT(DISTINCT a), SUM(DISTINCT a), MAX(b)
+    #   COUNT(DISTINCT a), SUM(a), MAX(b)
+    #   COUNT(a), SUM(a), MAX(DISTINCT b)
+    #
+    # This will be done by first grouping all agg funcs by their inner
+    # expr:
+    #   {a: [COUNT(a), SUM(a)],
+    #    b: [MAX(b)]}
+    #
+    # then choosing a random val (which is a list of aggs) in the above dict, and
+    # finally randomly adding DISTINCT to items in the list.
+    exprs_to_funcs = defaultdict(list)
+    for item in agg_items:
+      for expr, funcs in self._group_agg_funcs_by_expr(item.val_expr).iteritems():
+        exprs_to_funcs[expr].extend(funcs)
+    funcs = choice(exprs_to_funcs.values())
+    for func in funcs:
+      if random_boolean():
+        func.distinct = True
+
+  def _group_agg_funcs_by_expr(self, val_expr):
+    '''Group exprs and return a dict mapping each expr to the agg funcs
+    it is used in.
+
+    Example: COUNT(a) * SUM(a) - MAX(b) + MIN(c) -> {a: [COUNT(a), SUM(a)],
+                                                     b: [MAX(b)],
+                                                     c: [MIN(c)]}
+
+    '''
+    exprs_to_funcs = defaultdict(list)
+    if isinstance(val_expr, AggFunc):
+      exprs_to_funcs[tuple(val_expr.args)].append(val_expr)
+    elif isinstance(val_expr, Func):
+      for arg in val_expr.args:
+        for expr, funcs in self._group_agg_funcs_by_expr(arg).iteritems():
+          exprs_to_funcs[expr].extend(funcs)
+    # else: The remaining case could happen if the original expr was something like
+    #       "SUM(a) + b + 1" where b is a GROUP BY field.
+    return exprs_to_funcs
+
+
+if __name__ == '__main__':
+  '''Generate some queries for manual inspection. The queries won't run anywhere
+  because the tables used are fake. To make real queries, we'd need to connect to
+  a database and read the table metadata and such.
+  '''
+  tables = list()
+  # Copy TYPES so the removals below don't mutate the module-level list.
+  data_types = list(TYPES)
+  data_types.remove(Float)
+  data_types.remove(Double)
+  for table_idx in xrange(5):
+    table = Table('table_%s' % table_idx)
+    tables.append(table)
+    for col_idx in xrange(3):
+      col_type = choice(data_types)
+      col = Column(table, '%s_col_%s' % (col_type.__name__.lower(), col_idx), col_type)
+      table.cols.append(col)
+
+  query_generator = QueryGenerator()
+  from model_translator import SqlWriter
+  sql_writer = SqlWriter.create()
+  for _ in range(3000):
+    query = query_generator.create_query(tables)
+    print(sql_writer.write_query(query) + '\n')
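For reviewers, the combining loop at the end of _combine_vals_of_same_type is a pairwise left fold over shuffled inputs. A minimal standalone sketch, using plain integers and a hypothetical fold_vals helper (not part of this patch):

```python
from random import choice, shuffle

def fold_vals(vals, funcs):
  # Mirror of the loop in _combine_vals_of_same_type: shuffle the inputs, then
  # repeatedly combine the running result with the next val using a randomly
  # chosen binary func until only one val remains.
  vals = list(vals)
  shuffle(vals)
  left = vals.pop()
  while vals:
    left = choice(funcs)(left, vals.pop())
  return left

# With a single commutative, associative func the result is order-independent.
add = lambda a, b: a + b
print(fold_vals([1, 2, 3, 4], [add]))  # → 10
```

With several funcs in play (as with MATH_OPERATORS), both the shuffle and the per-step choice contribute randomness, which is why a seed is needed for reproducibility.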
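Similarly, the grouping performed by _group_agg_funcs_by_expr can be sketched with plain stand-in objects. The Agg class below is a hypothetical substitute for the module's AggFunc nodes, not part of this patch:

```python
from collections import defaultdict

class Agg(object):
  # Hypothetical stand-in for an AggFunc node: a func name wrapping an inner expr.
  def __init__(self, name, arg):
    self.name, self.arg = name, arg

def group_by_arg(aggs):
  # Sketch of _group_agg_funcs_by_expr: map each inner expr (here a plain
  # string) to the list of agg funcs that wrap it.
  groups = defaultdict(list)
  for agg in aggs:
    groups[agg.arg].append(agg)
  return groups

aggs = [Agg('COUNT', 'a'), Agg('SUM', 'a'), Agg('MAX', 'b')]
groups = group_by_arg(aggs)
print(sorted((arg, [a.name for a in fs]) for arg, fs in groups.items()))
# → [('a', ['COUNT', 'SUM']), ('b', ['MAX'])]
```

_enable_distinct_on_random_agg_items then picks one value of such a dict at random and flips DISTINCT on a random subset of that list, which respects the one-DISTINCT-expr limitation described in the comments.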