Protect against SQL injections by using tree comparisons (#3109)

* add SQLQuery class with tests for safe queries and non-safe tautology attacks

* add test for union query injections

* split .apply calls to newline

* add tests for comment attacks

* remove double underscore

* extract complex children check to variable

* inherit from object because I'm not a lamer

Co-Authored-By: rauchy <omer@rauchy.net>

* simplify cognitive complexity

* check that additional columns are not injected

* detect appended queries

* inline .apply calls

* move SQLQuery to it's own module

* move SQLQuery tests to their own module

* serialize SQLQuery instances

* raise an exception when attempting to serialize an unsafe query

* queries without parameters are safe

* remove redundant parentheses

* use cached properties

* rename SQLInjectionException to SQLInjectionError

* support multiple word params and param negations

* refactor out methods that don't involve any state

* don't cache text()

* reduce cognitive complexity
This commit is contained in:
Omer Lachish
2018-12-02 21:51:06 +02:00
committed by Arik Fraimovich
parent 463d4ce518
commit 9579f12a83
3 changed files with 165 additions and 0 deletions

71
redash/utils/sql_query.py Normal file
View File

@@ -0,0 +1,71 @@
import re
import sqlparse
from redash.utils import mustache_render
def _replace_params(template):
return re.sub('-?{{.+?}}', 'param', template)
def _inside_a_where_clause(a):
if a is None:
return False
else:
return type(a.parent) is sqlparse.sql.Where or _inside_a_where_clause(a.parent)
def _populating_an_in_operator(a, b):
if type(a) is sqlparse.sql.Identifier and \
type(b) is sqlparse.sql.IdentifierList and \
_inside_a_where_clause(a):
return True
def _equivalent_leaves(a, b):
return type(a) == type(b) or \
(type(a) is sqlparse.sql.Identifier and type(b) is sqlparse.sql.Token)
def _filter_noise(tokens):
skippable_tokens = [sqlparse.tokens.Error, sqlparse.tokens.Whitespace]
return [t for t in tokens if t.ttype not in skippable_tokens]
def _same_type(a, b):
if _populating_an_in_operator(a, b):
return True
elif type(a) in (list, tuple):
children_are_same = [_same_type(child_a, child_b) for (child_a, child_b) in zip(a, b)]
return len(a) == len(b) and all(children_are_same)
elif (hasattr(a, 'tokens') and hasattr(b, 'tokens')):
return _same_type(_filter_noise(a.tokens), _filter_noise(b.tokens))
else:
return _equivalent_leaves(a, b)
class SQLQuery(object):
def __init__(self, template):
self.template = template
self.query = template
def apply(self, parameters):
self.query = mustache_render(self.template, parameters)
return self
def is_safe(self):
template_tree = sqlparse.parse(_replace_params(self.template))
query_tree = sqlparse.parse(self.query)
return _same_type(template_tree, query_tree)
@property
def text(self):
if not self.is_safe():
raise SQLInjectionError()
else:
return self.query
class SQLInjectionError(Exception):
pass

0
tests/utils/__init__.py Normal file
View File

View File

@@ -0,0 +1,94 @@
from unittest import TestCase
from redash.utils.sql_query import SQLInjectionError, SQLQuery
class TestSQLQuery(TestCase):
def test_serializes(self):
query = SQLQuery("SELECT * FROM users WHERE userid='{{userid}}'").apply({
"userid": 22
})
self.assertEqual(query.text, "SELECT * FROM users WHERE userid='22'")
def test_raises_when_serializing_unsafe_queries(self):
query = SQLQuery("SELECT * FROM users WHERE userid={{userid}}").apply({
"userid": "22 OR 1==1"
})
self.assertRaises(SQLInjectionError, getattr, query, 'text')
def test_marks_queries_without_params_as_safe(self):
query = SQLQuery("SELECT * FROM users")
self.assertTrue(query.is_safe())
def test_marks_simple_queries_with_where_params_as_safe(self):
query = SQLQuery("SELECT * FROM users WHERE userid='{{userid}}'").apply({
"userid": 22
})
self.assertTrue(query.is_safe())
def test_marks_simple_queries_with_column_params_as_safe(self):
query = SQLQuery("SELECT {{this_column}} FROM users").apply({
"this_column": "username"
})
self.assertTrue(query.is_safe())
def test_marks_multiple_simple_queries_as_safe(self):
query = SQLQuery("SELECT * FROM users WHERE userid='{{userid}}' ; SELECT * FROM profiles").apply({
"userid": 22
})
self.assertTrue(query.is_safe())
def test_marks_tautologies_as_not_safe(self):
query = SQLQuery("SELECT * FROM users WHERE userid={{userid}}").apply({
"userid": "22 OR 1==1"
})
self.assertFalse(query.is_safe())
def test_marks_union_queries_as_not_safe(self):
query = SQLQuery("SELECT * FROM users WHERE userid={{userid}}").apply({
"userid": "22 UNION SELECT body, results, 1 FROM reports"
})
self.assertFalse(query.is_safe())
def test_marks_comment_attacks_as_not_safe(self):
query = SQLQuery("SELECT * FROM users WHERE username='{{username}}' AND password='{{password}}'").apply({
"username": "admin' --"
})
self.assertFalse(query.is_safe())
def test_marks_additional_columns_as_not_safe(self):
query = SQLQuery("SELECT {{this_column}} FROM users").apply({
"this_column": "username, password"
})
self.assertFalse(query.is_safe())
def test_marks_query_additions_as_not_safe(self):
query = SQLQuery("SELECT * FROM users ORDER BY {{this_column}}").apply({
"this_column": "id ; DROP TABLE midgets"
})
self.assertFalse(query.is_safe())
def test_marks_multiple_word_params_as_safe(self):
query = SQLQuery("SELECT {{why would you do this}} FROM users").apply({
"why would you do this": "shrug"
})
self.assertTrue(query.is_safe())
def test_marks_param_negations_as_safe(self):
query = SQLQuery("SELECT date_add(some_column, INTERVAL -{{days}} DAY) FROM events").apply({
"days": 7
})
self.assertTrue(query.is_safe())