diff --git a/redash/query_runner/snowflake.py b/redash/query_runner/snowflake.py index bb67c20d6..e43743ce9 100644 --- a/redash/query_runner/snowflake.py +++ b/redash/query_runner/snowflake.py @@ -1,11 +1,14 @@ try: import snowflake.connector + from cryptography.hazmat.primitives.serialization import load_pem_private_key enabled = True except ImportError: enabled = False +from base64 import b64decode + from redash import __version__ from redash.query_runner import ( TYPE_BOOLEAN, @@ -43,6 +46,8 @@ class Snowflake(BaseSQLQueryRunner): "account": {"type": "string"}, "user": {"type": "string"}, "password": {"type": "string"}, + "private_key_File": {"type": "string"}, + "private_key_pwd": {"type": "string"}, "warehouse": {"type": "string"}, "database": {"type": "string"}, "region": {"type": "string", "default": "us-west"}, @@ -57,13 +62,15 @@ class Snowflake(BaseSQLQueryRunner): "account", "user", "password", + "private_key_File", + "private_key_pwd", "warehouse", "database", "region", "host", ], - "required": ["user", "password", "account", "database", "warehouse"], - "secret": ["password"], + "required": ["user", "account", "database", "warehouse"], + "secret": ["password", "private_key_File", "private_key_pwd"], "extra_options": [ "host", ], @@ -88,7 +95,7 @@ class Snowflake(BaseSQLQueryRunner): if region == "us-west": region = None - if self.configuration.__contains__("host"): + if self.configuration.get("host"): host = self.configuration.get("host") else: if region: @@ -96,14 +103,29 @@ class Snowflake(BaseSQLQueryRunner): else: host = "{}.snowflakecomputing.com".format(account) - connection = snowflake.connector.connect( - user=self.configuration["user"], - password=self.configuration["password"], - account=account, - region=region, - host=host, - application="Redash/{} (Snowflake)".format(__version__.split("-")[0]), - ) + params = { + "user": self.configuration["user"], + "account": account, + "region": region, + "host": host, + "application": "Redash/{} (Snowflake)".format(__version__.split("-")[0]), + } + + if self.configuration.get("password"): + params["password"] = self.configuration["password"] + elif self.configuration.get("private_key_File"): + private_key_b64 = self.configuration.get("private_key_File") + private_key_bytes = b64decode(private_key_b64) + if self.configuration.get("private_key_pwd"): + private_key_pwd = self.configuration.get("private_key_pwd").encode() + else: + private_key_pwd = None + private_key_pem = load_pem_private_key(private_key_bytes, private_key_pwd) + params["private_key"] = private_key_pem + else: + raise Exception("Neither password nor private_key_b64 is set.") + + connection = snowflake.connector.connect(**params) return connection