mirror of
https://github.com/apache/impala.git
synced 2026-01-06 15:01:43 -05:00
sqlparse-0.1.19 is the last version of sqlparse that supports Python 2.6. Testing: - Ran all end-to-end tests Change-Id: Ide51ef3ac52d25a96b0fa832e29b6535197d23cb Reviewed-on: http://gerrit.cloudera.org:8080/10354 Reviewed-by: David Knupp <dknupp@cloudera.com> Tested-by: Impala Public Jenkins <impala-public-jenkins@cloudera.com>
347 lines
16 KiB
Python
347 lines
16 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
import pytest
|
|
|
|
from tests.utils import TestCaseBase
|
|
|
|
import sqlparse
|
|
from sqlparse.exceptions import SQLParseError
|
|
|
|
|
|
class TestFormat(TestCaseBase):
    """Tests for ``sqlparse.format`` options other than ``reindent``.

    Covers keyword/identifier case conversion, comment and whitespace
    stripping, handling of line endings, and rejection of invalid option
    values with ``SQLParseError``.
    """

    def test_keywordcase(self):
        # keyword_case only changes keywords: the identifier and the trailing
        # comment keep their original case.
        sql = 'select * from bar; -- select foo\n'
        res = sqlparse.format(sql, keyword_case='upper')
        self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- select foo\n')
        res = sqlparse.format(sql, keyword_case='capitalize')
        self.ndiffAssertEqual(res, 'Select * From bar; -- select foo\n')
        res = sqlparse.format(sql.upper(), keyword_case='lower')
        self.ndiffAssertEqual(res, 'select * from BAR; -- SELECT FOO\n')
        # An unknown case name is rejected.
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          keyword_case='foo')

    def test_identifiercase(self):
        # identifier_case only changes unquoted identifiers: keywords and the
        # trailing comment keep their original case.
        sql = 'select * from bar; -- select foo\n'
        res = sqlparse.format(sql, identifier_case='upper')
        self.ndiffAssertEqual(res, 'select * from BAR; -- select foo\n')
        res = sqlparse.format(sql, identifier_case='capitalize')
        self.ndiffAssertEqual(res, 'select * from Bar; -- select foo\n')
        res = sqlparse.format(sql.upper(), identifier_case='lower')
        self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- SELECT FOO\n')
        # An unknown case name is rejected.
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          identifier_case='foo')
        # Quoted identifiers are left exactly as written.
        sql = 'select * from "foo"."bar"'
        res = sqlparse.format(sql, identifier_case="upper")
        self.ndiffAssertEqual(res, 'select * from "foo"."bar"')

    def test_strip_comments_single(self):
        # Removing a "--" comment (and its newline) leaves a single space
        # between the surrounding tokens.
        sql = 'select *-- statement starts here\nfrom foo'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select * from foo')
        sql = 'select * -- statement starts here\nfrom foo'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select * from foo')
        sql = 'select-- foo\nfrom -- bar\nwhere'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select from where')
        # strip_comments=None is rejected.
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          strip_comments=None)

    def test_strip_comments_multi(self):
        # "/* ... */" comments are removed together with the whitespace
        # that follows them.
        sql = '/* sql starts here */\nselect'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select')
        sql = '/* sql starts here */ select'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select')
        sql = '/*\n * sql starts here\n */\nselect'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select')
        sql = 'select (/* sql starts here */ select 2)'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select (select 2)')
        # A "/*" inside a comment does not open a nested comment.
        sql = 'select (/* sql /* starts here */ select 2)'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select (select 2)')

    def test_strip_ws(self):
        # Newlines/tabs between tokens become single spaces, spaces just
        # inside parentheses and trailing whitespace are dropped, but the
        # newline terminating a "--" comment is kept.
        s = 'select\n* from foo\n\twhere ( 1 = 2 )\n'
        self.ndiffAssertEqual(sqlparse.format(s, strip_whitespace=True),
                              'select * from foo where (1 = 2)')
        s = 'select -- foo\nfrom bar\n'
        self.ndiffAssertEqual(sqlparse.format(s, strip_whitespace=True),
                              'select -- foo\nfrom bar')
        # strip_whitespace=None is rejected.
        self.assertRaises(SQLParseError, sqlparse.format, s,
                          strip_whitespace=None)

    def test_preserve_ws(self):
        # preserve at least one whitespace after subgroups
        s = 'select\n* /* foo */ from bar '
        self.ndiffAssertEqual(sqlparse.format(s, strip_whitespace=True),
                              'select * /* foo */ from bar')

    def test_notransform_of_quoted_crlf(self):
        # Make sure that CR/CR+LF characters inside string literals don't get
        # affected by the formatter.
        s1 = "SELECT some_column LIKE 'value\r'"
        s2 = "SELECT some_column LIKE 'value\r'\r\nWHERE id = 1\n"
        s3 = "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\r"
        s4 = "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\r\n"

        # Line endings outside string literals ("\r\n" and a lone "\r") are
        # normalized to "\n", while the "\r" inside each quoted literal is
        # kept verbatim -- also when it follows an escaped quote (s3) or an
        # escaped backslash plus escaped quote (s4).
        self.ndiffAssertEqual(sqlparse.format(s1),
                              "SELECT some_column LIKE 'value\r'")
        self.ndiffAssertEqual(sqlparse.format(s2),
                              "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n")
        self.ndiffAssertEqual(sqlparse.format(s3),
                              "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n")
        self.ndiffAssertEqual(sqlparse.format(s4),
                              "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n")

    def test_outputformat(self):
        # An unknown output_format is rejected.
        sql = 'select * from foo;'
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          output_format='foo')
|
|
|
|
|
|
class TestFormatReindent(TestCaseBase):
    """Tests for ``sqlparse.format(..., reindent=True)``.

    In the expected outputs, continuation lines of a list are padded with
    spaces so that each item lines up under the first one.  For example, a
    select list aligns after ``select `` (7 columns), a from list after
    ``from `` (5 columns) and a parenthesized list after ``(``.
    """

    @staticmethod
    def _reindent(sql):
        """Return *sql* formatted with ``reindent=True``."""
        return sqlparse.format(sql, reindent=True)

    def test_option(self):
        # Invalid values for the reindent-related options are rejected:
        # non-boolean reindent/indent_tabs, and a non-integer or negative
        # indent_width.
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          reindent=2)
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          indent_tabs=2)
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          reindent=True, indent_width='foo')
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          reindent=True, indent_width=-12)

    def test_stmts(self):
        # Consecutive statements are separated by an empty line.
        s = 'select foo; select bar'
        self.ndiffAssertEqual(self._reindent(s), 'select foo;\n\nselect bar')
        s = 'select foo'
        self.ndiffAssertEqual(self._reindent(s), 'select foo')
        s = 'select foo; -- test\n select bar'
        self.ndiffAssertEqual(self._reindent(s),
                              'select foo; -- test\n\nselect bar')

    def test_keywords(self):
        s = 'select * from foo union select * from bar;'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select *',
            'from foo',
            'union',
            'select *',
            'from bar;']))

    def test_keywords_between(self):  # issue 14
        # don't break AND after BETWEEN
        s = 'and foo between 1 and 2 and bar = 3'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            '',
            'and foo between 1 and 2',
            'and bar = 3']))

    def test_parenthesis(self):
        s = 'select count(*) from (select * from foo);'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select count(*)',
            'from',
            '  (select *',
            '   from foo);']))

    def test_where(self):
        # AND/OR conditions go on their own indented lines; inside a
        # parenthesis they line up with the first condition in it.
        s = 'select * from foo where bar = 1 and baz = 2 or bzz = 3;'
        self.ndiffAssertEqual(self._reindent(s), ('select *\nfrom foo\n'
                                                  'where bar = 1\n'
                                                  '  and baz = 2\n'
                                                  '  or bzz = 3;'))
        s = 'select * from foo where bar = 1 and (baz = 2 or bzz = 3);'
        self.ndiffAssertEqual(self._reindent(s), ('select *\nfrom foo\n'
                                                  'where bar = 1\n'
                                                  '  and (baz = 2\n'
                                                  '       or bzz = 3);'))

    def test_join(self):
        # Every JOIN flavour starts a new line, keeping its ON clause.
        s = 'select * from foo join bar on 1 = 2'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select *',
            'from foo',
            'join bar on 1 = 2']))
        s = 'select * from foo inner join bar on 1 = 2'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select *',
            'from foo',
            'inner join bar on 1 = 2']))
        s = 'select * from foo left outer join bar on 1 = 2'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select *',
            'from foo',
            'left outer join bar on 1 = 2']))
        s = 'select * from foo straight_join bar on 1 = 2'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select *',
            'from foo',
            'straight_join bar on 1 = 2']))

    def test_identifier_list(self):
        s = 'select foo, bar, baz from table1, table2 where 1 = 2'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select foo,',
            '       bar,',
            '       baz',
            'from table1,',
            '     table2',
            'where 1 = 2']))
        s = 'select a.*, b.id from a, b'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select a.*,',
            '       b.id',
            'from a,',
            '     b']))

    def test_identifier_list_with_functions(self):
        s = ("select 'abc' as foo, coalesce(col1, col2)||col3 as bar,"
             "col3 from my_table")
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            "select 'abc' as foo,",
            "       coalesce(col1, col2)||col3 as bar,",
            "       col3",
            "from my_table"]))

    def test_case(self):
        # NOTE(review): the 4-space indent of the WHEN/ELSE lines was
        # restored after whitespace loss in this copy of the file; confirm
        # it against the formatter's actual output.
        s = 'case when foo = 1 then 2 when foo = 3 then 4 else 5 end'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'case',
            '    when foo = 1 then 2',
            '    when foo = 3 then 4',
            '    else 5',
            'end']))

    def test_case2(self):
        # NOTE(review): WHEN/ELSE indent restored as in test_case; confirm.
        s = 'case(foo) when bar = 1 then 2 else 3 end'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'case(foo)',
            '    when bar = 1 then 2',
            '    else 3',
            'end']))

    def test_nested_identifier_list(self):  # issue4
        s = '(foo as bar, bar1, bar2 as bar3, b4 as b5)'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            '(foo as bar,',
            ' bar1,',
            ' bar2 as bar3,',
            ' b4 as b5)']))

    def test_duplicate_linebreaks(self):  # issue3
        # A "--" comment's own newline must not produce an extra blank line.
        s = 'select c1 -- column1\nfrom foo'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select c1 -- column1',
            'from foo']))
        # Same input with the comment stripped as well.
        r = sqlparse.format(s, reindent=True, strip_comments=True)
        self.ndiffAssertEqual(r, '\n'.join([
            'select c1',
            'from foo']))
        s = 'select c1\nfrom foo\norder by c1'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select c1',
            'from foo',
            'order by c1']))
        s = 'select c1 from t1 where (c1 = 1) order by c1'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select c1',
            'from t1',
            'where (c1 = 1)',
            'order by c1']))

    def test_keywordfunctions(self):  # issue36
        s = 'select max(a) b, foo, bar'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select max(a) b,',
            '       foo,',
            '       bar']))

    def test_identifier_and_functions(self):  # issue45
        s = 'select foo.bar, nvl(1) from dual'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select foo.bar,',
            '       nvl(1)',
            'from dual']))
|
|
|
|
|
|
class TestOutputFormat(TestCaseBase):
    """Tests for the ``output_format`` option of ``sqlparse.format``."""

    def test_python(self):
        sql = 'select * from foo;'
        res = sqlparse.format(sql, output_format='python')
        self.ndiffAssertEqual(res, "sql = 'select * from foo;'")
        # With reindent, each output line becomes its own string literal;
        # continuation literals line up with the first one, right after
        # "sql = (" (7 columns).
        res = sqlparse.format(sql, output_format='python', reindent=True)
        self.ndiffAssertEqual(res, ("sql = ('select * '\n"
                                    "       'from foo;')"))

    def test_php(self):
        sql = 'select * from foo;'
        res = sqlparse.format(sql, output_format='php')
        self.ndiffAssertEqual(res, '$sql = "select * from foo;";')
        # With reindent, later lines are appended with ".=".
        res = sqlparse.format(sql, output_format='php', reindent=True)
        self.ndiffAssertEqual(res, ('$sql = "select * ";\n'
                                    '$sql .= "from foo;";'))

    def test_sql(self):  # "sql" is an allowed option but has no effect
        sql = 'select * from foo;'
        res = sqlparse.format(sql, output_format='sql')
        self.ndiffAssertEqual(res, 'select * from foo;')
|
|
|
|
|
|
def test_format_column_ordering():  # issue89
    """ORDER BY columns go one per line, aligned after ``order by ``."""
    sql = 'select * from foo order by c1 desc, c2, c3;'
    formatted = sqlparse.format(sql, reindent=True)
    # "order by " is 9 characters wide, so the continuation columns are
    # padded with 9 spaces to line up under "c1".
    expected = '\n'.join(['select *',
                          'from foo',
                          'order by c1 desc,',
                          '         c2,',
                          '         c3;'])
    assert formatted == expected
|
|
|
|
|
|
def test_truncate_strings():
    """String literals longer than ``truncate_strings`` are shortened.

    Only the first N characters of the literal survive, followed by a
    marker: "[...]" when no ``truncate_char`` is given, otherwise the
    supplied ``truncate_char``.
    """
    statement = "update foo set value = '" + 'x' * 1000 + "';"

    shortened = sqlparse.format(statement, truncate_strings=10)
    assert shortened == "update foo set value = 'xxxxxxxxxx[...]';"

    custom_marker = sqlparse.format(statement, truncate_strings=3,
                                    truncate_char='YYY')
    assert custom_marker == "update foo set value = 'xxxYYY';"
|
|
|
|
|
|
def test_truncate_strings_invalid_option():
    """Non-integer, negative and zero ``truncate_strings`` are rejected."""
    for bad_value in ('bar', -1, 0):
        with pytest.raises(SQLParseError):
            sqlparse.format('foo', truncate_strings=bad_value)
|
|
|
|
|
|
@pytest.mark.parametrize('sql', [
    'select verrrylongcolumn from foo',
    'select "verrrylongcolumn" from "foo"',
])
def test_truncate_strings_doesnt_truncate_identifiers(sql):
    """truncate_strings leaves identifiers alone, quoted or not."""
    assert sqlparse.format(sql, truncate_strings=2) == sql
|
|
|
|
|
|
def test_having_produces_newline():
    """HAVING starts its own line when reindenting."""
    sql = (
        'select * from foo, bar where bar.id = foo.bar_id'
        ' having sum(bar.value) > 100')
    formatted = sqlparse.format(sql, reindent=True)
    # "from " is 5 characters wide, so the second table is padded with
    # 5 spaces to line up under "foo".
    expected = [
        'select *',
        'from foo,',
        '     bar',
        'where bar.id = foo.bar_id',
        'having sum(bar.value) > 100',
    ]
    assert formatted == '\n'.join(expected)
|