mirror of
https://github.com/apache/impala.git
synced 2025-12-21 10:58:31 -05:00
This commit only upgrades sqlparse. The upgrade was done by downloading and extracting the tarball; nothing else (such as patching) was done. The older version of sqlparse would parse `SELECT ' ; ' ;` into two statements. Neither statement is complete because of the open quote, and this would cause an infinite loop. The bug is already fixed in the newest version of sqlparse. Change-Id: I7ce7c269769ae0cde3dc8ca386d0b0e11bea71c1 Reviewed-on: http://gerrit.cloudera.org:8080/102 Reviewed-by: Casey Ching <casey@cloudera.com> Tested-by: Internal Jenkins
141 lines
4.9 KiB
Python
141 lines
4.9 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
# Tests splitting functions.
|
|
|
|
import unittest
|
|
|
|
from tests.utils import load_file, TestCaseBase
|
|
|
|
import sqlparse
|
|
|
|
|
|
class SQLSplitTest(TestCaseBase):
    """Statement-splitting tests for sqlparse.

    Each test feeds SQL text to sqlparse.parse(), sqlparse.split() or
    sqlparse.parsestream() and checks how many statements come back;
    most also check that the statements reproduce the input text.
    """

    _sql1 = 'select * from foo;'
    _sql2 = 'select * from bar;'

    def _parse_expecting(self, sql, count):
        # Shared step: parse *sql*, fail unless exactly *count* statements
        # were produced, and hand the statements back for further checks.
        parsed = sqlparse.parse(sql)
        self.assertEqual(len(parsed), count)
        return parsed

    def test_split_semicolon(self):
        # The second query carries a semicolon inside a quoted literal.
        quoted = 'select * from foo where bar = \'foo;bar\';'
        parsed = self._parse_expecting(self._sql1 + quoted, 2)
        self.ndiffAssertEqual(unicode(parsed[0]), self._sql1)
        self.ndiffAssertEqual(unicode(parsed[1]), quoted)

    def test_create_function(self):
        # The whole fixture file is expected to be one statement.
        source = load_file('function.sql')
        parsed = self._parse_expecting(source, 1)
        self.ndiffAssertEqual(unicode(parsed[0]), source)

    def test_create_function_psql(self):
        # The whole fixture file is expected to be one statement.
        source = load_file('function_psql.sql')
        parsed = self._parse_expecting(source, 1)
        self.ndiffAssertEqual(unicode(parsed[0]), source)

    def test_create_function_psql3(self):
        # The whole fixture file is expected to be one statement.
        source = load_file('function_psql3.sql')
        parsed = self._parse_expecting(source, 1)
        self.ndiffAssertEqual(unicode(parsed[0]), source)

    def test_create_function_psql2(self):
        # The whole fixture file is expected to be one statement.
        source = load_file('function_psql2.sql')
        parsed = self._parse_expecting(source, 1)
        self.ndiffAssertEqual(unicode(parsed[0]), source)

    def test_dashcomments(self):
        # Three statements expected; joined back together they must equal
        # the fixture text.
        source = load_file('dashcomment.sql')
        parsed = self._parse_expecting(source, 3)
        self.ndiffAssertEqual(''.join(unicode(p) for p in parsed), source)

    def test_dashcomments_eol(self):
        # A trailing dash comment must not create an extra statement,
        # whether the input ends in LF, CR, CRLF or no line ending at all.
        for text in ('select foo; -- comment\n',
                     'select foo; -- comment\r',
                     'select foo; -- comment\r\n',
                     'select foo; -- comment'):
            self._parse_expecting(text, 1)

    def test_begintag(self):
        # Three statements expected; joined back together they must equal
        # the fixture text.
        source = load_file('begintag.sql')
        parsed = self._parse_expecting(source, 3)
        self.ndiffAssertEqual(''.join(unicode(p) for p in parsed), source)

    def test_begintag_2(self):
        # One statement expected; it must reproduce the fixture text.
        source = load_file('begintag_2.sql')
        parsed = self._parse_expecting(source, 1)
        self.ndiffAssertEqual(''.join(unicode(p) for p in parsed), source)

    def test_dropif(self):
        # DROP ... IF EXISTS followed by a SELECT: two statements.
        source = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;'
        parsed = self._parse_expecting(source, 2)
        self.ndiffAssertEqual(''.join(unicode(p) for p in parsed), source)

    def test_comment_with_umlaut(self):
        # Unicode input with a non-ASCII character inside a dash comment.
        source = (u'select * from foo;\n'
                  u'-- Testing an umlaut: ä\n'
                  u'select * from bar;')
        parsed = self._parse_expecting(source, 2)
        self.ndiffAssertEqual(''.join(unicode(p) for p in parsed), source)

    def test_comment_end_of_line(self):
        source = ('select * from foo; -- foo\n'
                  'select * from bar;')
        parsed = self._parse_expecting(source, 2)
        self.ndiffAssertEqual(''.join(unicode(p) for p in parsed), source)
        # The end-of-line comment has to stay attached to the first query.
        self.ndiffAssertEqual(unicode(parsed[0]),
                              'select * from foo; -- foo\n')

    def test_casewhen(self):
        # CASE ... END followed by a COMMENT ON statement: two pieces.
        source = ('SELECT case when val = 1 then 2 else null end as foo;\n'
                  'comment on table actor is \'The actor table.\';')
        self.assertEqual(len(sqlparse.split(source)), 2)

    def test_cursor_declare(self):
        # DECLARE CURSOR ... AS SELECT followed by a SELECT: two pieces.
        source = ('DECLARE CURSOR "foo" AS SELECT 1;\n'
                  'SELECT 2;')
        self.assertEqual(len(sqlparse.split(source)), 2)

    def test_if_function(self):
        # See issue 33: IF used as a function call must not confuse the
        # splitter.
        source = ('CREATE TEMPORARY TABLE tmp '
                  'SELECT IF(a=1, a, b) AS o FROM one; '
                  'SELECT t FROM two')
        self.assertEqual(len(sqlparse.split(source)), 2)

    def test_split_stream(self):
        from cStringIO import StringIO
        import types

        # parsestream() is expected to return a generator that yields one
        # item per statement.
        result = sqlparse.parsestream(StringIO("SELECT 1; SELECT 2;"))
        self.assertEqual(type(result), types.GeneratorType)
        collected = list(result)
        self.assertEqual(len(collected), 2)

    def test_encoding_parsestream(self):
        from cStringIO import StringIO

        # Token values coming out of a byte stream must be unicode objects.
        parsed = list(sqlparse.parsestream(StringIO("SELECT 1; SELECT 2;")))
        leading_token = parsed[0].tokens[0]
        self.assertEqual(type(leading_token.value), unicode)
|
|
|
|
|
|
def test_split_simple():
    """Check that two queries on one line split into two strings."""
    expected = ('select * from foo;', 'select * from bar;')
    pieces = sqlparse.split('select * from foo; select * from bar;')
    assert len(pieces) == 2
    assert pieces[0] == expected[0]
    assert pieces[1] == expected[1]
|