mirror of
https://github.com/apache/impala.git
synced 2025-12-19 18:12:08 -05:00
This takes steps to make Python 2 behave like Python 3 as
a way to flush out issues with running on Python 3. Specifically,
it handles two main differences:
1. Python 3 requires absolute imports within packages. This
can be emulated via "from __future__ import absolute_import"
2. Python 3 changed division to "true" division that doesn't
round to an integer. This can be emulated via
"from __future__ import division"
This changes all Python files to add imports for absolute_import
and division. For completeness, this also includes print_function in the
import.
I scrutinized each old-division location and converted those that need an
integer result (e.g. indices, counts of records, etc.) to use the integer
division '//' operator. Some code was also using
relative imports and needed to be adjusted to handle absolute_import.
This fixes all Pylint warnings about no-absolute-import and old-division,
and these warnings are now banned.
Testing:
- Ran core tests
Change-Id: Idb0fcbd11f3e8791f5951c4944be44fb580e576b
Reviewed-on: http://gerrit.cloudera.org:8080/19588
Reviewed-by: Joe McDonnell <joemcdonnell@cloudera.com>
Tested-by: Joe McDonnell <joemcdonnell@cloudera.com>
846 lines
30 KiB
Python
846 lines
30 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
from __future__ import absolute_import, division, print_function
|
|
from inspect import getmro
|
|
from logging import getLogger
|
|
from re import sub
|
|
from sqlparse import format
|
|
|
|
from tests.comparison.common import StructColumn, CollectionColumn
|
|
from tests.comparison.db_types import (
|
|
Char,
|
|
Decimal,
|
|
Double,
|
|
Float,
|
|
Int,
|
|
String,
|
|
Timestamp,
|
|
TinyInt,
|
|
VarChar)
|
|
from tests.comparison.query import InsertClause, Query
|
|
from tests.comparison.query_flattener import QueryFlattener
|
|
|
|
LOG = getLogger(__name__)
|
|
|
|
|
|
class SqlWriter(object):
  '''Base class that renders query model objects as SQL text.

  Subclasses of SqlWriter take a Query (or insert statement) and provide the SQL
  representation for a specific database such as Impala or MySQL. The
  SqlWriter.create([dialect=]) factory method may be used instead of specifying the
  concrete class.

  Rendering is driven by _write(), which picks a "_write_<snake_case_class_name>"
  method based on the class (or a base class) of the object being written.
  '''

  @staticmethod
  def create(dialect='impala', nulls_order_asc='DEFAULT'):
    '''Create and return a new SqlWriter appropriate for the given sql dialect.

    "dialect" refers to database specific deviations of sql, and the val should be
    one of "IMPALA", "MYSQL", "ORACLE", "POSTGRESQL", or "HIVE" (case-insensitive).
    "nulls_order_asc" is passed through to the writer's constructor.

    Raises an Exception if the dialect is not recognized.
    '''
    dialect = dialect.upper()
    if dialect == 'IMPALA':
      return ImpalaSqlWriter(nulls_order_asc)
    if dialect == 'MYSQL':
      return MySQLSqlWriter(nulls_order_asc)
    if dialect == 'ORACLE':
      return OracleSqlWriter(nulls_order_asc)
    if dialect == 'POSTGRESQL':
      return PostgresqlSqlWriter(nulls_order_asc)
    if dialect == 'HIVE':
      return HiveSqlWriter(nulls_order_asc)
    raise Exception('Unknown dialect: %s' % dialect)

  def __init__(self, nulls_order_asc):
    '''Initialize the writer.

    "nulls_order_asc" must be one of 'BEFORE', 'AFTER' or 'DEFAULT'; it controls the
    NULLS FIRST/LAST text produced by get_nulls_order(). Any other value raises an
    Exception.
    '''
    if nulls_order_asc not in ('BEFORE', 'AFTER', 'DEFAULT'):
      raise Exception('Unknown nulls order: %s' % nulls_order_asc)
    self.nulls_order_asc = nulls_order_asc

    # Functions that don't follow the usual call syntax of foo(bar, baz) can be listed
    # here. Parentheses were added everywhere to avoid problems with operator
    # precedence.
    # TODO: Account for operator precedence...
    self.operator_funcs = {
        'And': '({0}) AND ({1})',
        'Or': '({0}) OR ({1})',
        'Plus': '({0}) + ({1})',
        'Minus': '({0}) - ({1})',
        'Multiply': '({0}) * ({1})',
        'Divide': '({0}) / ({1})',
        'Equals': '({0}) = ({1})',
        'NotEquals': '({0}) != ({1})',
        'IsNotDistinctFrom': '({0}) IS NOT DISTINCT FROM ({1})',
        # If a database supports the operator version of 'IS NOT DISTINCT FROM', it
        # should overwrite the value of 'IsNotDistinctFromOp'.
        'IsNotDistinctFromOp': '({0}) IS NOT DISTINCT FROM ({1})',
        'IsDistinctFrom': '({0}) IS DISTINCT FROM ({1})',
        'LessThan': '({0}) < ({1})',
        'GreaterThan': '({0}) > ({1})',
        'LessThanOrEquals': '({0}) <= ({1})',
        'GreaterThanOrEquals': '({0}) >= ({1})',
        'IsNull': '({0}) IS NULL',
        'IsNotNull': '({0}) IS NOT NULL'}

  def write_query(self, statement, pretty=False):
    """
    Return SQL as a string for the given query.

    If "pretty" is True, the SQL will be formatted (though not very well) with new
    lines and indentation.
    """
    sql = self._write(statement)
    if pretty:
      sql = self.make_pretty_sql(sql)
    return sql

  def _write_query(self, query):
    """
    Taking in a Query object with some attributes set, return a string
    representation of the query in the correct dialect.

    This is just another dispatch destination of self._write(). When
    self._write(Query) is called, that's dispatched to self._write_query(Query)
    """
    # Write out each section in the proper order, skipping clauses that are unset.
    clauses = (
        query.with_clause,
        query.select_clause,
        query.from_clause,
        query.where_clause,
        query.group_by_clause,
        query.having_clause,
        query.union_clause,
        query.order_by_clause,
        query.limit_clause)
    return '\n'.join(self._write(clause) for clause in clauses if clause)

  def _write_insert_statement(self, insert_statement):
    """
    Taking in a InsertStatement object with some attributes set, return a string
    representation of the query in the correct dialect.

    Raises an Exception if the insert_clause is missing, or unless exactly one of
    select_query and values_clause is set.
    """
    sql = list()

    if insert_statement.with_clause:
      sql.append(self._write(insert_statement.with_clause))

    if insert_statement.insert_clause:
      sql.append(self._write(insert_statement.insert_clause))
    else:
      raise Exception('InsertStatement is missing insert_clause attribute')

    if insert_statement.select_query and not insert_statement.values_clause:
      sql.append(self._write(insert_statement.select_query))
    elif not insert_statement.select_query and insert_statement.values_clause:
      sql.append(self._write(insert_statement.values_clause))
    else:
      raise Exception('InsertStatement must have a select_query xor a values clause')

    return '\n'.join(sql)

  def make_pretty_sql(self, sql):
    '''Return "sql" reformatted with indentation; on any failure, log a warning and
    return the text unchanged (pretty printing is best effort).
    '''
    try:
      sql = format(sql, reindent=True)
    except Exception as e:
      # Logger.warn() is a deprecated alias; warning() is the supported spelling.
      LOG.warning('Unable to format sql: %s', e)
    return sql

  def write_create_table_as(self, query, name, pretty=False):
    '''Return a CREATE TABLE <name> AS <query> statement.'''
    return 'CREATE TABLE %s AS %s' % (name, self.write_query(query, pretty=pretty))

  def write_create_view(self, query, name, pretty=False):
    '''Return a CREATE VIEW <name> AS <query> statement.'''
    return 'CREATE VIEW %s AS %s' % (name, self.write_query(query, pretty=pretty))

  def _write_with_clause(self, with_clause):
    '''Return "WITH <identifier> AS (<query>), ..." for every inline view.'''
    return 'WITH ' + ',\n'.join(
        '%s AS (%s)' % (view.identifier, self._write(view.query))
        for view in with_clause.with_clause_inline_views)

  def _write_select_clause(self, select_clause):
    '''Return the SELECT [DISTINCT] list, one item per line.'''
    if hasattr(select_clause, 'star_prefix'):
      # This is a little hack to get query flattening to work (look at
      # query_flattener.py).
      # TODO: Add proper support for SELECT *, and SELECT table_name.*
      return 'SELECT %s.*' % select_clause.star_prefix
    sql = 'SELECT'
    if select_clause.distinct:
      sql += ' DISTINCT'
    sql += '\n' + ',\n'.join(self._write(item) for item in select_clause.items)
    return sql

  def _write_select_item(self, select_item):
    '''Return the item's expression, followed by "AS <alias>" when it has an alias.'''
    if select_item.alias:
      return '{0} AS {1}'.format(self._write(select_item.val_expr), select_item.alias)
    else:
      return self._write(select_item.val_expr)

  def _write_struct_column(self, struct_col):
    '''Return "<owner>.<name>" for a struct column.

    The owner is written recursively when it is a struct, or a collection without an
    alias; otherwise the owner's identifier is used.
    '''
    if isinstance(struct_col.owner, StructColumn) or \
        (isinstance(struct_col.owner, CollectionColumn) and not struct_col.owner.alias):
      return '%s.%s' % (self._write(struct_col.owner), struct_col.name)
    else:
      return '%s.%s' % (struct_col.owner.identifier, struct_col.name)

  def _write_collection_column(self, collection_col):
    '''Return "<owner>.<name>[ <alias>]" for a collection column.

    The owner is written recursively when it is an un-aliased struct or collection;
    otherwise the owner's identifier is used.
    '''
    owner = collection_col.owner
    if isinstance(owner, (StructColumn, CollectionColumn)) and not owner.alias:
      owner_sql = self._write(owner)
    else:
      owner_sql = owner.identifier
    sql = '%s.%s' % (owner_sql, collection_col.name)
    if collection_col.alias:
      sql = '%s %s' % (sql, collection_col.alias)
    return sql

  def _write_column(self, col):
    '''Return "<owner>.<name>"; a struct owner is written recursively.'''
    if isinstance(col.owner, StructColumn):
      return '%s.%s' % (self._write(col.owner), col.name)
    return '%s.%s' % (col.owner.identifier, col.name)

  def _write_from_clause(self, from_clause):
    '''Return "FROM <table expr>" followed by one line per join clause.'''
    sql = 'FROM %s' % self._write(from_clause.table_expr)
    if from_clause.join_clauses:
      sql += '\n' + '\n'.join(self._write(join) for join in from_clause.join_clauses)
    return sql

  def _write_table(self, table):
    '''Return the table name, followed by its identifier when the table is aliased.'''
    if table.alias:
      return '%s %s' % (table.name, table.identifier)
    return table.name

  def _write_inline_view(self, inline_view):
    '''Return "(<query>) <identifier>"; raises if the view has no identifier.'''
    if not inline_view.identifier:
      raise Exception('An inline view requires an identifier')
    return '(\n%s\n) %s' % (self._write(inline_view.query), inline_view.identifier)

  def _write_with_clause_inline_view(self, with_clause_inline_view):
    '''Return a reference to a WITH clause entry: its with_clause_alias followed by
    the optional alias. Raises if the with_clause_alias is missing.
    '''
    if not with_clause_inline_view.with_clause_alias:
      raise Exception('An with clause entry requires an identifier')
    sql = with_clause_inline_view.with_clause_alias
    if with_clause_inline_view.alias:
      sql += ' ' + with_clause_inline_view.alias
    return sql

  def _write_join_clause(self, join_clause):
    '''Return "<join type> JOIN <table expr>[ ON <condition>]".'''
    sql = '%s JOIN %s' % (join_clause.join_type, self._write(join_clause.table_expr))
    if join_clause.boolean_expr:
      sql += ' ON ' + self._write(join_clause.boolean_expr)
    return sql

  def _write_where_clause(self, where_clause):
    '''Return the WHERE clause.'''
    return 'WHERE\n' + self._write(where_clause.boolean_expr)

  def _write_group_by_clause(self, group_by_clause):
    '''Return the GROUP BY clause, one expression per line.'''
    return 'GROUP BY\n' + ',\n'.join(
        self._write(item.val_expr) for item in group_by_clause.group_by_items)

  def _write_having_clause(self, having_clause):
    '''Return the HAVING clause.'''
    return 'HAVING\n' + self._write(having_clause.boolean_expr)

  def _write_union_clause(self, union_clause):
    '''Return "UNION[ ALL]" followed by the unioned query.'''
    sql = 'UNION'
    if union_clause.all:
      sql += ' ALL'
    sql += '\n' + self._write(union_clause.query)
    return sql

  def _write_data_type(self, data_type):
    '''Write a literal value.'''
    if data_type.val is None:
      return 'NULL'
    if data_type.returns_char:
      return "'{0}'".format(data_type.val)
    if data_type.returns_timestamp:
      return "CAST('{0}' AS {1})".format(
          data_type.val, self._write_data_type_metaclass(Timestamp))
    return str(data_type.val)

  def _write_func(self, func):
    '''Write a function call.

    Functions listed in self.operator_funcs are written with their operator
    template; everything else is written as NAME(arg, ...).
    '''
    if func.name() in self.operator_funcs:
      sql = self.operator_funcs[func.name()].format(
          *[self._write(arg) for arg in func.args])
    else:
      sql = '%s(%s)' % \
          (self._to_sql_name(func.name()), self._write_as_comma_list(func.args))
    return sql

  def _write_exists(self, func):
    return 'EXISTS ' + self._write(func.args[0])

  def _write_not_exists(self, func):
    return 'NOT EXISTS ' + self._write(func.args[0])

  def _write_in(self, func, use_not=False):
    '''Write "(<expr>) [NOT] IN ..." where the right side is either a subquery or a
    parenthesized list of the remaining args.
    '''
    sql = '(%s) ' % self._write(func.args[0])
    if use_not:
      sql += 'NOT '
    sql += 'IN '
    if func.signature.args[1].is_subquery:
      sql += self._write(func.args[1])
    else:
      sql += '(' + self._write_as_comma_list(func.args[1:]) + ')'
    return sql

  def _write_not_in(self, func):
    return self._write_in(func, use_not=True)

  def _write_as_comma_list(self, items):
    '''Write each item and join the results with ", ".'''
    return ', '.join([self._write(item) for item in items])

  def _write_cast_as_char(self, func):
    return 'CAST(%s AS %s)' % (self._write(func.args[0]), self._write(String))

  def _write_cast(self, arg, type):
    '''Return "CAST(<arg> AS <type>)"; "type" is inserted as-is (already SQL text).'''
    return 'CAST(%s AS %s)' % (self._write(arg), type)

  def _write_cast_func(self, func):
    '''Write a cast whose target type is the second arg of "func".'''
    val_expr = func.args[0]
    type_ = func.args[1]
    return 'CAST({val_expr} AS {type_})'.format(
        val_expr=self._write(val_expr), type_=self._write(type_))

  def _write_date_add_year(self, func):
    return "%s + INTERVAL %s YEAR" \
        % (self._write(func.args[0]), self._write(func.args[1]))

  def _write_date_add_month(self, func):
    return "%s + INTERVAL %s MONTH" \
        % (self._write(func.args[0]), self._write(func.args[1]))

  def _write_date_add_day(self, func):
    return "%s + INTERVAL %s DAY" \
        % (self._write(func.args[0]), self._write(func.args[1]))

  def _write_date_add_hour(self, func):
    return "%s + INTERVAL %s HOUR" \
        % (self._write(func.args[0]), self._write(func.args[1]))

  def _write_date_add_minute(self, func):
    return "%s + INTERVAL %s MINUTE" \
        % (self._write(func.args[0]), self._write(func.args[1]))

  def _write_date_add_second(self, func):
    return "%s + INTERVAL %s SECOND" \
        % (self._write(func.args[0]), self._write(func.args[1]))

  def _write_extract_year(self, func):
    return 'EXTRACT(YEAR FROM %s)' % self._write(func.args[0])

  def _write_extract_month(self, func):
    return 'EXTRACT(MONTH FROM %s)' % self._write(func.args[0])

  def _write_extract_day(self, func):
    return 'EXTRACT(DAY FROM %s)' % self._write(func.args[0])

  def _write_extract_hour(self, func):
    return 'EXTRACT(HOUR FROM %s)' % self._write(func.args[0])

  def _write_extract_minute(self, func):
    return 'EXTRACT(MINUTE FROM %s)' % self._write(func.args[0])

  def _write_extract_second(self, func):
    return 'EXTRACT(SECOND FROM %s)' % self._write(func.args[0])

  def _write_implicit_cast(self, func):
    # An implicit cast has no SQL text of its own; only the wrapped arg is written.
    return self._write(func.args[0])

  def _write_analytic_func(self, func):
    '''Write "NAME(args) OVER (<partition by> <order by> <window>)", omitting any of
    the three OVER options that are not set.
    '''
    sql = self._to_sql_name(func.name()) \
        + '(' + self._write_as_comma_list(func.args) \
        + ') OVER ('
    options = []
    if func.partition_by_clause:
      options.append(self._write(func.partition_by_clause))
    if func.order_by_clause:
      options.append(self._write(func.order_by_clause))
    if func.window_clause:
      options.append(self._write(func.window_clause))
    return sql + ' '.join(options) + ')'

  def _write_partition_by_clause(self, partition_by_clause):
    return 'PARTITION BY ' + \
        ', '.join(self._write(expr) for expr in partition_by_clause.val_exprs)

  def _write_window_clause(self, window_clause):
    '''Write "<RANGE|ROWS> [BETWEEN ]<start boundary>[ AND <end boundary>]".'''
    sql = window_clause.range_or_rows.upper() + ' '
    if window_clause.end_boundary:
      sql += 'BETWEEN '
    if window_clause.start_boundary.val_expr:
      sql += self._write(window_clause.start_boundary.val_expr) + ' '
    sql += window_clause.start_boundary.boundary_type.upper() + ' '
    if window_clause.end_boundary:
      sql += 'AND '
      if window_clause.end_boundary.val_expr:
        sql += self._write(window_clause.end_boundary.val_expr) + ' '
      sql += window_clause.end_boundary.boundary_type.upper()
    return sql

  def _write_agg_func(self, func):
    '''Write "NAME([DISTINCT ]<arg>)".'''
    sql = self._to_sql_name(func.name()) + '('
    if func.distinct:
      sql += 'DISTINCT '
    # All agg funcs only have a single arg
    sql += self._write(func.args[0]) + ')'
    return sql

  def _write_data_type_metaclass(self, data_type_class):
    '''Write a data type class such as Int, Boolean, or Decimal(4, 2).'''
    if data_type_class == Char:
      return 'CHAR({0})'.format(data_type_class.MAX)
    elif data_type_class == VarChar:
      return 'VARCHAR({0})'.format(data_type_class.MAX)
    elif data_type_class == Decimal:
      # SQL order is DECIMAL(precision, scale): total digits, then fractional digits.
      return 'DECIMAL({precision},{scale})'.format(
          precision=data_type_class.MAX_DIGITS,
          scale=data_type_class.MAX_FRACTIONAL_DIGITS)
    else:
      return data_type_class.__name__.upper()

  def _write_subquery(self, subquery):
    return '({0})'.format(self._write(subquery.query))

  def _write_order_by_clause(self, order_by_clause):
    '''Write "ORDER BY <expr>[ <order>][ <nulls order>], ...".'''
    items = []
    for expr, order in order_by_clause.exprs_to_order:
      item = self._write(expr)
      if order:
        item += ' ' + order
      nulls_order = self.get_nulls_order(order)
      if nulls_order is not None:
        item += ' ' + nulls_order
      items.append(item)
    return 'ORDER BY ' + ', '.join(items)

  def _write_limit_clause(self, limit_clause):
    return 'LIMIT {0}'.format(limit_clause.limit)

  def _write_insert_clause(self, insert_clause):
    """
    Given an InsertClause, return a string representing that portion of the query. The
    InsertClause object may have the column_list attribute set, which is a
    sequence of columns.
    """
    if insert_clause.column_list is None:
      column_list = ''
    else:
      column_list = ' ({column_list})'.format(
          column_list=', '.join([col.name for col in insert_clause.column_list]))
    return 'INSERT INTO {table_name}{column_list}'.format(
        table_name=insert_clause.table.name, column_list=column_list)

  def _write_values_row(self, values_row):
    """
    Return a string representing 1 row of a VALUES clause.
    """
    return '({values_row})'.format(
        values_row=', '.join([self._write(item) for item in values_row.items]))

  def _write_values_clause(self, values_clause):
    """
    Return a string representing the VALUES clause of an INSERT query.
    """
    return 'VALUES\n{values_rows}'.format(
        values_rows=',\n'.join([self._write(values_row)
                                for values_row in values_clause.values_rows]))

  def _write(self, object_):
    '''Return a sql string representation of the given object.'''
    # What's below is effectively a giant switch statement. It works based on a func
    # naming and signature convention. It should match the incoming object with the
    # corresponding func defined, then call the func and return the result.
    #
    # Ex:
    #   a = model.And(...)
    #   _write(a) should call _write_func(a) because "And" is a subclass of "Func" and
    #   no other _write_<class name> methods have been defined higher up the method
    #   resolution order (MRO). If _write_and(...) were to be defined, it would be
    #   called instead.
    for type_ in getmro(type(object_)):
      writer_func_name = '_write_' + self._to_py_name(type_.__name__)
      writer_func = getattr(self, writer_func_name, None)
      if writer_func:
        return writer_func(object_)

    # Handle any remaining cases
    if isinstance(object_, Query):
      return self.write_query(object_)

    raise Exception('Unsupported object: %s<%s>' % (type(object_).__name__, object_))

  def get_nulls_order(self, order):
    '''Return 'NULLS FIRST' or 'NULLS LAST' for the given sort order, or None.

    self.nulls_order_asc says where NULLs go for an ascending sort ('BEFORE' or
    'AFTER'); the answer is mirrored for 'DESC'. None is returned when
    nulls_order_asc is neither of those (e.g. 'DEFAULT') or when "order" is neither
    'ASC' nor 'DESC'.
    '''
    if self.nulls_order_asc is None:
      return None
    nulls_order_asc = self.nulls_order_asc
    if order == 'ASC':
      if nulls_order_asc == 'BEFORE':
        return 'NULLS FIRST'
      if nulls_order_asc == 'AFTER':
        return 'NULLS LAST'
    if order == 'DESC':
      if nulls_order_asc == 'BEFORE':
        return 'NULLS LAST'
      if nulls_order_asc == 'AFTER':
        return 'NULLS FIRST'
    return None

  def _to_py_name(self, name):
    '''Convert a CamelCase name to snake_case, e.g. "IsNull" -> "is_null".'''
    return sub('([A-Z])', r'_\1', name).lower().lstrip('_')

  def _to_sql_name(self, name):
    '''Convert a CamelCase name to UPPER_SNAKE_CASE, e.g. "IsNull" -> "IS_NULL".'''
    return self._to_py_name(name).upper()
|
|
|
|
|
|
class ImpalaSqlWriter(SqlWriter):
  '''SqlWriter for the Impala dialect.'''

  DIALECT = 'IMPALA'

  def __init__(self, *args, **kwargs):
    super(ImpalaSqlWriter, self).__init__(*args, **kwargs)
    # Use the operator spelling of IS NOT DISTINCT FROM.
    self.operator_funcs['IsNotDistinctFromOp'] = '({0}) <=> ({1})'

  def _write_column(self, col):
    '''Write a column reference, wrapping columns whose exact type is Char in TRIM().'''
    written = super(ImpalaSqlWriter, self)._write_column(col)
    # TRIM is a temporary workaround for IMPALA-1652
    return 'TRIM(%s)' % written if col.exact_type == Char else written

  def _write_insert_clause(self, insert_clause):
    '''Write the INSERT clause, spelled UPSERT when the conflict action is "update".'''
    insert_sql = super(ImpalaSqlWriter, self)._write_insert_clause(insert_clause)
    is_upsert = insert_clause.conflict_action == InsertClause.CONFLICT_ACTION_UPDATE
    if not is_upsert:
      return insert_sql
    # The value of insert_sql at this point would be something like:
    #
    #   INSERT INTO <table name> [(column list)]
    #
    # If it happens that the table name or column list contains the text INSERT in an
    # identifier, we want to ensure that the replace() call below does not alter their
    # names but instead only modifies the leading INSERT keyword to UPSERT, hence the
    # replacement count of 1.
    return insert_sql.replace('INSERT', 'UPSERT', 1)
|
|
|
|
|
|
class OracleSqlWriter(SqlWriter):
  '''SqlWriter for the Oracle dialect; it inherits all rendering from SqlWriter.'''

  DIALECT = 'ORACLE'
|
|
|
|
|
|
class HiveSqlWriter(SqlWriter):
  '''SqlWriter for the Hive dialect.'''

  DIALECT = 'HIVE'

  def __init__(self, *args, **kwargs):
    super(HiveSqlWriter, self).__init__(*args, **kwargs)

    # Write the distinct-from family with the <=> operator.
    self.operator_funcs.update({
        'IsNotDistinctFrom': '({0}) <=> ({1})',
        'IsNotDistinctFromOp': '({0}) <=> ({1})',
        'IsDistinctFrom': 'NOT(({0}) <=> ({1}))'
    })

  # Hive greatest UDF is strict on type equality
  # Hive Profile already restricts to signatures with the same types,
  # but sometimes expression with UDF's like 'count'
  # return an unpredictable type like 'bigint' unlike
  # the query model, so cast is still necessary.
  def _write_greatest(self, func):
    return self._write_func_casting_args(func)

  # Hive least UDF is strict on type equality; see the comment on _write_greatest.
  def _write_least(self, func):
    return self._write_func_casting_args(func)

  def _write_func_casting_args(self, func):
    '''Write "NAME(CAST(arg AS t), ...)" where t is the lower-cased type name of the
    first arg, when that type is Int, Decimal or Float. Every arg is cast, so no arg
    is dropped if the function is given more than two. Other types are written as a
    regular function call.
    '''
    args = func.args
    if args[0].type in (Int, Decimal, Float):
      argtype = args[0].type.__name__.lower()
      return '%s(%s)' % (
          self._to_sql_name(func.name()),
          ", ".join(self._write_cast(arg, argtype) for arg in args))
    return self._write_func(func)

  # Workaround for tinyint casting issues when run against RefDb
  # that might only have larger integers.
  # This does all the arithmetic operations in terms of bigints.
  def _write_plus(self, func):
    return self.arithmetic_cast(func, '+')

  def _write_minus(self, func):
    return self.arithmetic_cast(func, '-')

  def _write_multiply(self, func):
    return self.arithmetic_cast(func, '*')

  def arithmetic_cast(self, func, symbol):
    '''Write a binary arithmetic operation, casting both operands to BIGINT when
    both are of type Int; otherwise fall back to the regular function writer.
    '''
    args = func.args
    if args[0].type is Int and args[1].type is Int:
      return 'CAST (%s AS BIGINT) %s CAST (%s AS BIGINT)' % (
          self._write(args[0]), symbol, self._write(args[1]))
    else:
      return self._write_func(func)

  # Hive partition by clause throws exception if sorted by more than one key, unless
  # 'rows unbounded preceding' added.
  def _write_analytic_func(self, func):
    '''Write "NAME(args) OVER (...)", appending 'rows unbounded preceding' when there
    is a partition by clause, more than one order by key, the function supports
    windowing, and no window clause was given.
    '''
    sql = self._to_sql_name(func.name()) \
        + '(' + self._write_as_comma_list(func.args) \
        + ') OVER ('
    options = []
    if func.partition_by_clause:
      options.append(self._write(func.partition_by_clause))
    if func.order_by_clause:
      options.append(self._write(func.order_by_clause))
    if func.window_clause:
      options.append(self._write(func.window_clause))
    if func.partition_by_clause and func.order_by_clause:
      if len(func.order_by_clause.exprs_to_order) > 1:
        if func.SUPPORTS_WINDOWING and func.window_clause is None:
          options.append('rows unbounded preceding')
    return sql + ' '.join(options) + ')'

  # The date-part extractions are written with Hive's single-argument functions
  # instead of EXTRACT(<part> FROM ...).
  def _write_extract_year(self, func):
    return 'YEAR(%s)' % self._write(func.args[0])

  def _write_extract_month(self, func):
    return 'MONTH(%s)' % self._write(func.args[0])

  def _write_extract_day(self, func):
    return 'DAY(%s)' % self._write(func.args[0])

  def _write_extract_hour(self, func):
    return 'HOUR(%s)' % self._write(func.args[0])

  def _write_extract_minute(self, func):
    return 'MINUTE(%s)' % self._write(func.args[0])

  def _write_extract_second(self, func):
    return 'SECOND(%s)' % self._write(func.args[0])
|
|
|
|
|
|
class PostgresqlSqlWriter(SqlWriter):
  '''SqlWriter for the PostgreSQL dialect.'''

  DIALECT = 'POSTGRESQL'

  def _write_insert_statement(self, insert_statement):
    '''Write the insert statement, appending an ON CONFLICT clause that matches the
    statement's conflict_action. Raises on an unsupported conflict_action.
    '''
    sql = SqlWriter._write_insert_statement(self, insert_statement)
    action = insert_statement.conflict_action
    if action == InsertClause.CONFLICT_ACTION_DEFAULT:
      return sql
    if action == InsertClause.CONFLICT_ACTION_IGNORE:
      return sql + '\nON CONFLICT DO NOTHING'
    if action == InsertClause.CONFLICT_ACTION_UPDATE:
      updatable_names = insert_statement.updatable_column_names
      if not updatable_names:
        # With nothing to update, the conflicting row is simply skipped.
        return sql + '\nON CONFLICT DO NOTHING'
      primary_keys = insert_statement.primary_key_string
      assignments = ',\n'.join(
          '{name} = EXCLUDED.{name}'.format(name=name) for name in updatable_names)
      return sql + '\nON CONFLICT {primary_keys}\nDO UPDATE SET\n{columns}'.format(
          primary_keys=primary_keys, columns=assignments)
    raise Exception(
        'InsertStatement has unsupported conflict_action: {0}'.format(action))

  # The date_add family multiplies the amount by a one-unit interval.
  def _write_date_add_year(self, func):
    start, amount = func.args[0], func.args[1]
    return "%s + (%s) * INTERVAL '1' YEAR" % (self._write(start), self._write(amount))

  def _write_date_add_month(self, func):
    start, amount = func.args[0], func.args[1]
    return "%s + (%s) * INTERVAL '1' MONTH" % (self._write(start), self._write(amount))

  def _write_date_add_day(self, func):
    start, amount = func.args[0], func.args[1]
    return "%s + (%s) * INTERVAL '1' DAY" % (self._write(start), self._write(amount))

  def _write_date_add_hour(self, func):
    start, amount = func.args[0], func.args[1]
    return "%s + (%s) * INTERVAL '1' HOUR" % (self._write(start), self._write(amount))

  def _write_join_clause(self, join_clause):
    '''Write "<join type> JOIN <LATERAL or empty> <table expr>[ ON <condition>]".'''
    lateral = 'LATERAL' if join_clause.is_lateral_join else ''
    pieces = (join_clause.join_type, lateral, self._write(join_clause.table_expr))
    sql = '%s JOIN %s %s' % pieces
    condition = join_clause.boolean_expr
    if condition:
      sql += ' ON ' + self._write(condition)
    return sql

  def _write_date_add_minute(self, func):
    start, amount = func.args[0], func.args[1]
    return "%s + (%s) * INTERVAL '1' MINUTE" % (self._write(start), self._write(amount))

  def _write_date_add_second(self, func):
    start, amount = func.args[0], func.args[1]
    return "%s + (%s) * INTERVAL '1' SECOND" % (self._write(start), self._write(amount))

  def _write_column(self, col):
    '''Write "<identifier>.<flat column name>", where the identifier belongs to the
    nearest owner of "col" that is not a StructColumn.
    '''
    ancestor = col.owner
    while isinstance(ancestor, StructColumn):
      ancestor = ancestor.owner
    return '%s.%s' % (ancestor.identifier, QueryFlattener.flat_column_name(col))

  def _write_collection_column(self, collection_col):
    '''Write "<flat collection name> <identifier>".'''
    flat_name = QueryFlattener.flat_collection_name(collection_col)
    return '%s %s' % (flat_name, collection_col.identifier)

  def _write_extract_second(self, func):
    # For some reason Postgresql decided that extracting second should return a FLOAT...
    operand = self._write(func.args[0])
    return 'FLOOR(EXTRACT(SECOND FROM %s))' % operand

  def _write_if(self, func):
    '''Write an IF function call as a CASE WHEN ... THEN ... ELSE ... END expression.'''
    written_args = [self._write(arg) for arg in func.args]
    return 'CASE WHEN {0} THEN {1} ELSE {2} END'.format(*written_args)

  def _write_data_type_metaclass(self, data_type_class):
    '''Write a data type class such as Int or Boolean.'''
    if data_type_class == Double:
      return 'DOUBLE PRECISION'
    if data_type_class == Float:
      return 'REAL'
    if data_type_class == String:
      return 'VARCHAR(%s)' % data_type_class.MAX
    if data_type_class == Timestamp:
      return 'TIMESTAMP WITHOUT TIME ZONE'
    if data_type_class == TinyInt:
      return 'SMALLINT'
    # Everything else is spelled the same way as in the base writer.
    return super(PostgresqlSqlWriter, self)._write_data_type_metaclass(data_type_class)

  def _write_order_by_clause(self, order_by_clause):
    '''Write the ORDER BY clause; expressions that return chars are cast to BYTEA.'''
    rendered = []
    for expr, order in order_by_clause.exprs_to_order:
      if expr.returns_char:
        item = 'CAST({0} AS BYTEA)'.format(self._write(expr))
      else:
        item = self._write(expr)
      if order:
        item += ' ' + order
      nulls_order = self.get_nulls_order(order)
      if nulls_order is not None:
        item += ' ' + nulls_order
      rendered.append(item)
    return 'ORDER BY ' + ', '.join(rendered)

  def _write_data_type(self, data_type):
    '''Write a literal value.'''
    if data_type.val is None:
      return 'NULL'
    if not data_type.returns_char:
      return SqlWriter._write_data_type(self, data_type)
    # Literals sometimes produce an error 'could not determine polymorphic type
    # because input has type "unknown"', adding an "|| ''" avoids the problem.
    return "'%s' || ''" % data_type.val

  def _write_concat(self, func):
    # PostgreSQL CONCAT() doesn't behave like Impala CONCAT(). PostgreSQL || does.
    wrapped = ['({written_item})'.format(written_item=self._write(item))
               for item in func.args]
    return '({concat_list})'.format(concat_list=' || '.join(wrapped))
|
|
|
|
|
|
class MySQLSqlWriter(SqlWriter):
  '''SqlWriter for the MySQL dialect.'''

  DIALECT = 'MYSQL'

  def write_query(self, query, pretty=False):
    '''Return SQL as a string for the given query.

    Any WITH clause is removed by substituting each reference to a WITH clause entry
    with that entry's query as an inline view. If "pretty" is True, the SQL is passed
    through make_pretty_sql().
    '''
    # MySQL doesn't support WITH clauses so they need to be converted into inline views.
    # We are going to cheat by making use of the fact that the query generator creates
    # with clause entries with unique aliases even considering nested queries.
    sql = list()
    for clause in (
        query.select_clause,
        query.from_clause,
        query.where_clause,
        query.group_by_clause,
        query.having_clause,
        query.union_clause,
        query.order_by_clause,
        query.limit_clause
    ):
      if clause:
        sql.append(self._write(clause))
    sql = '\n'.join(sql)
    if query.with_clause:
      # Just replace the named references with inline views. Go in reverse order because
      # entries at the bottom of the WITH clause definition may reference entries above.
      for with_clause_inline_view in reversed(query.with_clause.with_clause_inline_views):
        replacement_sql = '(' + self.write_query(with_clause_inline_view.query) + ')'
        sql = sql.replace(with_clause_inline_view.identifier, replacement_sql)
    if pretty:
      sql = self.make_pretty_sql(sql)
    return sql

  def _write_data_type_metaclass(self, data_type_class):
    '''Write a data type class such as Int or Boolean.'''
    if issubclass(data_type_class, Int):
      return 'INTEGER'
    if issubclass(data_type_class, Float):
      return 'DECIMAL(65, 15)'
    if issubclass(data_type_class, VarChar):
      return 'CHAR'
    if hasattr(data_type_class, 'MYSQL'):
      return data_type_class.MYSQL[0]
    return data_type_class.__name__.upper()

  def _write_data_type(self, data_type):
    '''Write a literal value.'''
    # NULL must be handled before the type-specific branches; otherwise a NULL
    # timestamp would be written as CAST('None' AS DATETIME) and a NULL boolean as
    # the FALSE expression (1 = 0).
    if data_type.val is None:
      return 'NULL'
    if data_type.returns_timestamp:
      return "CAST('{0}' AS DATETIME)".format(data_type.val)
    if data_type.returns_boolean:
      # MySQL will error if a data_type "FALSE" is used as a GROUP BY field
      return '(0 = 0)' if data_type.val else '(1 = 0)'
    return SqlWriter._write_data_type(self, data_type)

  def _write_order_by_clause(self, order_by_clause):
    '''Write the ORDER BY clause; each expression is preceded by an ISNULL(<expr>)
    sort key.
    '''
    items = []
    for expr, order in order_by_clause.exprs_to_order:
      item = 'ISNULL({0}), {0}'.format(self._write(expr))
      if order:
        item += ' ' + order
      nulls_order = self.get_nulls_order(order)
      if nulls_order is not None:
        item += ' ' + nulls_order
      items.append(item)
    return 'ORDER BY ' + ', '.join(items)
|