format code by black and isort (#6167)

Signed-off-by: Ye Sijun <junnplus@gmail.com>
Author: Jun
Date: 2023-07-11 18:13:54 +09:00
Committed by: GitHub
Parent commit: 7f40837d3f
Commit: 9b2f635692

219 changed files with 2553 additions and 4222 deletions
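Note: the black/isort configuration itself is not visible in this excerpt. A migration like this typically pins the formatter settings in pyproject.toml so local runs match the new CI checks; the following is a hedged sketch, not the file from this commit — the line length of 120 is a guess inferred from how wide the reformatted lines in the hunks below are (black's default is 88), and isort's black profile is the standard way to keep the two tools from disagreeing over import ordering and wrapping.

    # Hypothetical pyproject.toml excerpt (an assumption, not part of this commit's diff)
    [tool.black]
    line-length = 120  # guessed from the wrapped line widths in the hunks below

    [tool.isort]
    profile = "black"  # align isort's ordering/wrapping rules with black
    line_length = 120

With something like this in place, the new CI steps below (black --check ., isort --check-only --diff .) can be reproduced locally before pushing.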


@@ -16,8 +16,10 @@ jobs:
       - uses: actions/setup-python@v4
         with:
           python-version: '3.8'
-      - run: sudo pip install flake8
-      - run: ./bin/flake8_tests.sh
+      - run: sudo pip install flake8 black isort
+      - run: flake8 .
+      - run: black --check .
+      - run: isort --check-only --diff .

   backend-unit-tests:
     runs-on: ubuntu-22.04


@@ -5,5 +5,5 @@ set -o errexit  # fail the build if any task fails
 flake8 --version ; pip --version
 # stop the build if there are Python syntax errors or undefined names
 flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
 # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
 flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics


@@ -1,35 +1,44 @@
 #!/bin/env python3
-import sys
 import re
 import subprocess
+import sys


 def get_change_log(previous_sha):
-    args = ['git', '--no-pager', 'log', '--merges', '--grep', 'Merge pull request', '--pretty=format:"%h|%s|%b|%p"', 'master...{}'.format(previous_sha)]
+    args = [
+        "git",
+        "--no-pager",
+        "log",
+        "--merges",
+        "--grep",
+        "Merge pull request",
+        '--pretty=format:"%h|%s|%b|%p"',
+        "master...{}".format(previous_sha),
+    ]
     log = subprocess.check_output(args)

     changes = []
-    for line in log.split('\n'):
+    for line in log.split("\n"):
         try:
-            sha, subject, body, parents = line[1:-1].split('|')
+            sha, subject, body, parents = line[1:-1].split("|")
         except ValueError:
             continue

         try:
-            pull_request = re.match("Merge pull request #(\d+)", subject).groups()[0]
+            pull_request = re.match(r"Merge pull request #(\d+)", subject).groups()[0]
             pull_request = " #{}".format(pull_request)
-        except Exception as ex:
+        except Exception:
             pull_request = ""

-        author = subprocess.check_output(['git', 'log', '-1', '--pretty=format:"%an"', parents.split(' ')[-1]])[1:-1]
+        author = subprocess.check_output(["git", "log", "-1", '--pretty=format:"%an"', parents.split(" ")[-1]])[1:-1]
         changes.append("{}{}: {} ({})".format(sha, pull_request, body.strip(), author))

     return changes


-if __name__ == '__main__':
+if __name__ == "__main__":
     previous_sha = sys.argv[1]
     changes = get_change_log(previous_sha)


@@ -1,17 +1,19 @@
 #!/usr/bin/env python3
 import os
-import sys
 import re
 import subprocess
+import sys

 import requests
 import simplejson

-github_token = os.environ['GITHUB_TOKEN']
-auth = (github_token, 'x-oauth-basic')
-repo = 'getredash/redash'
+github_token = os.environ["GITHUB_TOKEN"]
+auth = (github_token, "x-oauth-basic")
+repo = "getredash/redash"


 def _github_request(method, path, params=None, headers={}):
-    if not path.startswith('https://api.github.com'):
+    if not path.startswith("https://api.github.com"):
         url = "https://api.github.com/{}".format(path)
     else:
         url = path
@@ -22,15 +24,18 @@ def _github_request(method, path, params=None, headers={}):
     response = requests.request(method, url, data=params, auth=auth)
     return response


 def exception_from_error(message, response):
-    return Exception("({}) {}: {}".format(response.status_code, message, response.json().get('message', '?')))
+    return Exception("({}) {}: {}".format(response.status_code, message, response.json().get("message", "?")))


 def rc_tag_name(version):
     return "v{}-rc".format(version)


 def get_rc_release(version):
     tag = rc_tag_name(version)
-    response = _github_request('get', 'repos/{}/releases/tags/{}'.format(repo, tag))
+    response = _github_request("get", "repos/{}/releases/tags/{}".format(repo, tag))

     if response.status_code == 404:
         return None
@@ -39,84 +44,101 @@ def get_rc_release(version):
     raise exception_from_error("Unknown error while looking RC release: ", response)


 def create_release(version, commit_sha):
     tag = rc_tag_name(version)
     params = {
-        'tag_name': tag,
-        'name': "{} - RC".format(version),
-        'target_commitish': commit_sha,
-        'prerelease': True
+        "tag_name": tag,
+        "name": "{} - RC".format(version),
+        "target_commitish": commit_sha,
+        "prerelease": True,
     }
-    response = _github_request('post', 'repos/{}/releases'.format(repo), params)
+    response = _github_request("post", "repos/{}/releases".format(repo), params)

     if response.status_code != 201:
         raise exception_from_error("Failed creating new release", response)

     return response.json()


 def upload_asset(release, filepath):
-    upload_url = release['upload_url'].replace('{?name,label}', '')
-    filename = filepath.split('/')[-1]
+    upload_url = release["upload_url"].replace("{?name,label}", "")
+    filename = filepath.split("/")[-1]

     with open(filepath) as file_content:
-        headers = {'Content-Type': 'application/gzip'}
-        response = requests.post(upload_url, file_content, params={'name': filename}, headers=headers, auth=auth, verify=False)
+        headers = {"Content-Type": "application/gzip"}
+        response = requests.post(
+            upload_url, file_content, params={"name": filename}, headers=headers, auth=auth, verify=False
+        )

     if response.status_code != 201:  # not 200/201/...
-        raise exception_from_error('Failed uploading asset', response)
+        raise exception_from_error("Failed uploading asset", response)

     return response


 def remove_previous_builds(release):
-    for asset in release['assets']:
-        response = _github_request('delete', asset['url'])
+    for asset in release["assets"]:
+        response = _github_request("delete", asset["url"])
         if response.status_code != 204:
             raise exception_from_error("Failed deleting asset", response)


 def get_changelog(commit_sha):
-    latest_release = _github_request('get', 'repos/{}/releases/latest'.format(repo))
+    latest_release = _github_request("get", "repos/{}/releases/latest".format(repo))
     if latest_release.status_code != 200:
-        raise exception_from_error('Failed getting latest release', latest_release)
+        raise exception_from_error("Failed getting latest release", latest_release)

     latest_release = latest_release.json()
-    previous_sha = latest_release['target_commitish']
+    previous_sha = latest_release["target_commitish"]

-    args = ['git', '--no-pager', 'log', '--merges', '--grep', 'Merge pull request', '--pretty=format:"%h|%s|%b|%p"', '{}...{}'.format(previous_sha, commit_sha)]
+    args = [
+        "git",
+        "--no-pager",
+        "log",
+        "--merges",
+        "--grep",
+        "Merge pull request",
+        '--pretty=format:"%h|%s|%b|%p"',
+        "{}...{}".format(previous_sha, commit_sha),
+    ]
     log = subprocess.check_output(args)

-    changes = ["Changes since {}:".format(latest_release['name'])]
+    changes = ["Changes since {}:".format(latest_release["name"])]

-    for line in log.split('\n'):
+    for line in log.split("\n"):
         try:
-            sha, subject, body, parents = line[1:-1].split('|')
+            sha, subject, body, parents = line[1:-1].split("|")
         except ValueError:
             continue

         try:
-            pull_request = re.match("Merge pull request #(\d+)", subject).groups()[0]
+            pull_request = re.match(r"Merge pull request #(\d+)", subject).groups()[0]
             pull_request = " #{}".format(pull_request)
-        except Exception as ex:
+        except Exception:
             pull_request = ""

-        author = subprocess.check_output(['git', 'log', '-1', '--pretty=format:"%an"', parents.split(' ')[-1]])[1:-1]
+        author = subprocess.check_output(["git", "log", "-1", '--pretty=format:"%an"', parents.split(" ")[-1]])[1:-1]
         changes.append("{}{}: {} ({})".format(sha, pull_request, body.strip(), author))

     return "\n".join(changes)


 def update_release_commit_sha(release, commit_sha):
     params = {
-        'target_commitish': commit_sha,
+        "target_commitish": commit_sha,
     }
-    response = _github_request('patch', 'repos/{}/releases/{}'.format(repo, release['id']), params)
+    response = _github_request("patch", "repos/{}/releases/{}".format(repo, release["id"]), params)

     if response.status_code != 200:
         raise exception_from_error("Failed updating commit sha for existing release", response)

     return response.json()


 def update_release(version, build_filepath, commit_sha):
     try:
         release = get_rc_release(version)
@@ -125,21 +147,22 @@ def update_release(version, build_filepath, commit_sha):
         else:
             release = create_release(version, commit_sha)

-        print("Using release id: {}".format(release['id']))
+        print("Using release id: {}".format(release["id"]))

         remove_previous_builds(release)
         response = upload_asset(release, build_filepath)

         changelog = get_changelog(commit_sha)
-        response = _github_request('patch', release['url'], {'body': changelog})
+        response = _github_request("patch", release["url"], {"body": changelog})
         if response.status_code != 200:
             raise exception_from_error("Failed updating release description", response)
     except Exception as ex:
         print(ex)


-if __name__ == '__main__':
+if __name__ == "__main__":
     commit_sha = sys.argv[1]
     version = sys.argv[2]
     filepath = sys.argv[3]


@@ -1,9 +1,9 @@
 #!/usr/bin/env python3
-import urllib
 import argparse
 import os
 import subprocess
 import sys
+import urllib
 from collections import namedtuple
 from fnmatch import fnmatch
@@ -15,8 +15,8 @@ except ImportError:
     print("Missing required library: semver.")
     exit(1)

-REDASH_HOME = os.environ.get('REDASH_HOME', '/opt/redash')
-CURRENT_VERSION_PATH = '{}/current'.format(REDASH_HOME)
+REDASH_HOME = os.environ.get("REDASH_HOME", "/opt/redash")
+CURRENT_VERSION_PATH = "{}/current".format(REDASH_HOME)
@@ -27,11 +27,11 @@ def run(cmd, cwd=None):


 def confirm(question):
-    reply = str(input(question + ' (y/n): ')).lower().strip()
-    if reply[0] == 'y':
+    reply = str(input(question + " (y/n): ")).lower().strip()
+    if reply[0] == "y":
         return True
-    if reply[0] == 'n':
+    if reply[0] == "n":
         return False
     else:
         return confirm("Please use 'y' or 'n'")
@@ -40,7 +40,8 @@ def confirm(question):
 def version_path(version_name):
     return "{}/{}".format(REDASH_HOME, version_name)


-END_CODE = '\033[0m'
+
+END_CODE = "\033[0m"
@@ -51,60 +52,62 @@ def colored_string(text, color):
 def h1(text):
-    print(colored_string(text, '\033[4m\033[1m'))
+    print(colored_string(text, "\033[4m\033[1m"))


 def green(text):
-    print(colored_string(text, '\033[92m'))
+    print(colored_string(text, "\033[92m"))


 def red(text):
-    print(colored_string(text, '\033[91m'))
+    print(colored_string(text, "\033[91m"))


-class Release(namedtuple('Release', ('version', 'download_url', 'filename', 'description'))):
+class Release(namedtuple("Release", ("version", "download_url", "filename", "description"))):
     def v1_or_newer(self):
-        return semver.compare(self.version, '1.0.0-alpha') >= 0
+        return semver.compare(self.version, "1.0.0-alpha") >= 0

     def is_newer(self, version):
         return semver.compare(self.version, version) > 0

     @property
     def version_name(self):
-        return self.filename.replace('.tar.gz', '')
+        return self.filename.replace(".tar.gz", "")


 def get_latest_release_from_ci():
-    response = requests.get('https://circleci.com/api/v1.1/project/github/getredash/redash/latest/artifacts?branch=master')
+    response = requests.get(
+        "https://circleci.com/api/v1.1/project/github/getredash/redash/latest/artifacts?branch=master"
+    )
     if response.status_code != 200:
         exit("Failed getting releases (status code: %s)." % response.status_code)

-    tarball_asset = filter(lambda asset: asset['url'].endswith('.tar.gz'), response.json())[0]
-    filename = urllib.unquote(tarball_asset['pretty_path'].split('/')[-1])
-    version = filename.replace('redash.', '').replace('.tar.gz', '')
-    release = Release(version, tarball_asset['url'], filename, '')
+    tarball_asset = filter(lambda asset: asset["url"].endswith(".tar.gz"), response.json())[0]
+    filename = urllib.unquote(tarball_asset["pretty_path"].split("/")[-1])
+    version = filename.replace("redash.", "").replace(".tar.gz", "")
+    release = Release(version, tarball_asset["url"], filename, "")

     return release


 def get_release(channel):
-    if channel == 'ci':
+    if channel == "ci":
         return get_latest_release_from_ci()

-    response = requests.get('https://version.redash.io/api/releases?channel={}'.format(channel))
+    response = requests.get("https://version.redash.io/api/releases?channel={}".format(channel))
     release = response.json()[0]
-    filename = release['download_url'].split('/')[-1]
-    release = Release(release['version'], release['download_url'], filename, release['description'])
+    filename = release["download_url"].split("/")[-1]
+    release = Release(release["version"], release["download_url"], filename, release["description"])

     return release


 def link_to_current(version_name):
     green("Linking to current version...")
-    run('ln -nfs {} {}'.format(version_path(version_name), CURRENT_VERSION_PATH))
+    run("ln -nfs {} {}".format(version_path(version_name), CURRENT_VERSION_PATH))


 def restart_services():
@@ -113,25 +116,25 @@ def restart_services():
     # directory.
     green("Restarting...")
     try:
-        run('sudo /etc/init.d/redash_supervisord restart')
+        run("sudo /etc/init.d/redash_supervisord restart")
     except subprocess.CalledProcessError as e:
-        run('sudo service supervisor restart')
+        run("sudo service supervisor restart")


 def update_requirements(version_name):
     green("Installing new Python packages (if needed)...")
-    new_requirements_file = '{}/requirements.txt'.format(version_path(version_name))
+    new_requirements_file = "{}/requirements.txt".format(version_path(version_name))

     install_requirements = False
     try:
-        run('diff {}/requirements.txt {}'.format(CURRENT_VERSION_PATH, new_requirements_file)) != 0
+        run("diff {}/requirements.txt {}".format(CURRENT_VERSION_PATH, new_requirements_file)) != 0
     except subprocess.CalledProcessError as e:
         if e.returncode != 0:
             install_requirements = True

     if install_requirements:
-        run('sudo pip install -r {}'.format(new_requirements_file))
+        run("sudo pip install -r {}".format(new_requirements_file))


 def apply_migrations(release):
@@ -143,8 +146,12 @@ def apply_migrations(release):
 def find_migrations(version_name):
-    current_migrations = set([f for f in os.listdir("{}/migrations".format(CURRENT_VERSION_PATH)) if fnmatch(f, '*_*.py')])
-    new_migrations = sorted([f for f in os.listdir("{}/migrations".format(version_path(version_name))) if fnmatch(f, '*_*.py')])
+    current_migrations = set(
+        [f for f in os.listdir("{}/migrations".format(CURRENT_VERSION_PATH)) if fnmatch(f, "*_*.py")]
+    )
+    new_migrations = sorted(
+        [f for f in os.listdir("{}/migrations".format(version_path(version_name))) if fnmatch(f, "*_*.py")]
+    )

     return [m for m in new_migrations if m not in current_migrations]
@@ -154,40 +161,45 @@ def apply_migrations_pre_v1(version_name):
     if new_migrations:
         green("New migrations to run: ")
-        print(', '.join(new_migrations))
+        print(", ".join(new_migrations))
     else:
         print("No new migrations in this version.")

     if new_migrations and confirm("Apply new migrations? (make sure you have backup)"):
         for migration in new_migrations:
             print("Applying {}...".format(migration))
-            run("sudo sudo -u redash PYTHONPATH=. bin/run python migrations/{}".format(migration), cwd=version_path(version_name))
+            run(
+                "sudo sudo -u redash PYTHONPATH=. bin/run python migrations/{}".format(migration),
+                cwd=version_path(version_name),
+            )


 def download_and_unpack(release):
     directory_name = release.version_name

     green("Downloading release tarball...")
-    run('sudo wget --header="Accept: application/octet-stream" -O {} {}'.format(release.filename, release.download_url))
+    run(
+        'sudo wget --header="Accept: application/octet-stream" -O {} {}'.format(release.filename, release.download_url)
+    )

     green("Unpacking to: {}...".format(directory_name))
-    run('sudo mkdir -p {}'.format(directory_name))
-    run('sudo tar -C {} -xvf {}'.format(directory_name, release.filename))
+    run("sudo mkdir -p {}".format(directory_name))
+    run("sudo tar -C {} -xvf {}".format(directory_name, release.filename))

     green("Changing ownership to redash...")
-    run('sudo chown redash {}'.format(directory_name))
+    run("sudo chown redash {}".format(directory_name))

     green("Linking .env file...")
-    run('sudo ln -nfs {}/.env {}/.env'.format(REDASH_HOME, version_path(directory_name)))
+    run("sudo ln -nfs {}/.env {}/.env".format(REDASH_HOME, version_path(directory_name)))


 def current_version():
-    real_current_path = os.path.realpath(CURRENT_VERSION_PATH).replace('.b', '+b')
-    return real_current_path.replace(REDASH_HOME + '/', '').replace('redash.', '')
+    real_current_path = os.path.realpath(CURRENT_VERSION_PATH).replace(".b", "+b")
+    return real_current_path.replace(REDASH_HOME + "/", "").replace("redash.", "")


 def verify_minimum_version():
     green("Current version: " + current_version())

-    if semver.compare(current_version(), '0.12.0') < 0:
+    if semver.compare(current_version(), "0.12.0") < 0:
         red("You need to have Redash v0.12.0 or newer to upgrade to post v1.0.0 releases.")
         green("To upgrade to v0.12.0, run the upgrade script set to the legacy channel (--channel legacy).")
         exit(1)
@@ -234,9 +246,9 @@ def deploy_release(channel):
         red("Exit status: {}\nOutput:\n{}".format(e.returncode, e.output))


-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--channel", help="The channel to get release from (default: stable).", default='stable')
+    parser.add_argument("--channel", help="The channel to get release from (default: stable).", default="stable")
     args = parser.parse_args()

     deploy_release(args.channel)


@@ -5,5 +5,5 @@ CLI to manage redash.
 from redash.cli import manager

-if __name__ == '__main__':
+if __name__ == "__main__":
     manager()


@@ -1,19 +1,18 @@
-from __future__ import absolute_import
 import logging
 import os
 import sys

 import redis
-from flask_mail import Mail
 from flask_limiter import Limiter
 from flask_limiter.util import get_ipaddr
+from flask_mail import Mail
 from flask_migrate import Migrate
 from statsd import StatsClient

-from . import settings
-from .app import create_app  # noqa
-from .query_runner import import_query_runners
-from .destinations import import_destinations
+from redash import settings
+from redash.app import create_app  # noqa
+from redash.destinations import import_destinations
+from redash.query_runner import import_query_runners

 __version__ = "11.0.0-dev"
@@ -48,9 +47,7 @@ redis_connection = redis.from_url(settings.REDIS_URL)
 rq_redis_connection = redis.from_url(settings.RQ_REDIS_URL)
 mail = Mail()
 migrate = Migrate(compare_type=True)
-statsd_client = StatsClient(
-    host=settings.STATSD_HOST, port=settings.STATSD_PORT, prefix=settings.STATSD_PREFIX
-)
+statsd_client = StatsClient(host=settings.STATSD_HOST, port=settings.STATSD_PORT, prefix=settings.STATSD_PREFIX)
 limiter = Limiter(key_func=get_ipaddr, storage_uri=settings.LIMITER_STORAGE)

 import_query_runners(settings.QUERY_RUNNERS)


@@ -1,7 +1,7 @@
 from flask import Flask
 from werkzeug.middleware.proxy_fix import ProxyFix

-from . import settings
+from redash import settings


 class Redash(Flask):


@@ -21,12 +21,10 @@ logger = logging.getLogger("authentication")

 def get_login_url(external=False, next="/"):
-    if settings.MULTI_ORG and current_org == None:
+    if settings.MULTI_ORG and current_org == None:  # noqa: E711
         login_url = "/"
     elif settings.MULTI_ORG:
-        login_url = url_for(
-            "redash.login", org_slug=current_org.slug, next=next, _external=external
-        )
+        login_url = url_for("redash.login", org_slug=current_org.slug, next=next, _external=external)
     else:
         login_url = url_for("redash.login", next=next, _external=external)
@@ -69,11 +67,7 @@ def request_loader(request):
     elif settings.AUTH_TYPE == "api_key":
         user = api_key_load_user_from_request(request)
     else:
-        logger.warning(
-            "Unknown authentication type ({}). Using default (HMAC).".format(
-                settings.AUTH_TYPE
-            )
-        )
+        logger.warning("Unknown authentication type ({}). Using default (HMAC).".format(settings.AUTH_TYPE))
         user = hmac_load_user_from_request(request)

     if org_settings["auth_jwt_login_enabled"] and user is None:
@@ -229,7 +223,7 @@ def redirect_to_login():
 def logout_and_redirect_to_index():
     logout_user()

-    if settings.MULTI_ORG and current_org == None:
+    if settings.MULTI_ORG and current_org == None:  # noqa: E711
         index_url = "/"
     elif settings.MULTI_ORG:
         index_url = url_for("redash.index", org_slug=current_org.slug, _external=False)


@@ -1,13 +1,12 @@
 import logging

 from flask import render_template
+from itsdangerous import URLSafeTimedSerializer

 from redash import settings
 from redash.tasks import send_mail
 from redash.utils import base_url

-# noinspection PyUnresolvedReferences
-from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature

 logger = logging.getLogger(__name__)
 serializer = URLSafeTimedSerializer(settings.SECRET_KEY)


@@ -1,18 +1,17 @@
 import logging

 import requests
-from flask import redirect, url_for, Blueprint, flash, request, session
+from authlib.integrations.flask_client import OAuth
+from flask import Blueprint, flash, redirect, request, session, url_for

-from redash import models
+from redash import models, settings
 from redash.authentication import (
     create_and_login_user,
-    logout_and_redirect_to_index,
     get_next_path,
+    logout_and_redirect_to_index,
 )
 from redash.authentication.org_resolving import current_org
-from authlib.integrations.flask_client import OAuth


 def verify_profile(org, profile):
     if org.is_public:
@@ -46,9 +45,7 @@ def create_google_oauth_blueprint(app):
     def get_user_profile(access_token):
         headers = {"Authorization": "OAuth {}".format(access_token)}
-        response = requests.get(
-            "https://www.googleapis.com/oauth2/v1/userinfo", headers=headers
-        )
+        response = requests.get("https://www.googleapis.com/oauth2/v1/userinfo", headers=headers)

         if response.status_code == 401:
             logger.warning("Failed getting user profile (response code 401).")
@@ -63,12 +60,9 @@ def create_google_oauth_blueprint(app):
     @blueprint.route("/oauth/google", endpoint="authorize")
     def login():
         redirect_uri = url_for(".callback", _external=True)

-        next_path = request.args.get(
-            "next", url_for("redash.index", org_slug=session.get("org_slug"))
-        )
+        next_path = request.args.get("next", url_for("redash.index", org_slug=session.get("org_slug")))

         logger.debug("Callback url: %s", redirect_uri)
         logger.debug("Next is: %s", next_path)
@@ -78,7 +72,6 @@ def create_google_oauth_blueprint(app):
     @blueprint.route("/oauth/google_callback", endpoint="callback")
     def authorized():
-
         logger.debug("Authorized user inbound")

         resp = oauth.google.authorize_access_token()
@@ -109,21 +102,15 @@ def create_google_oauth_blueprint(app):
                 profile["email"],
                 org,
             )
-            flash(
-                "Your Google Apps account ({}) isn't allowed.".format(profile["email"])
-            )
+            flash("Your Google Apps account ({}) isn't allowed.".format(profile["email"]))
             return redirect(url_for("redash.login", org_slug=org.slug))

         picture_url = "%s?sz=40" % profile["picture"]
-        user = create_and_login_user(
-            org, profile["name"], profile["email"], picture_url
-        )
+        user = create_and_login_user(org, profile["name"], profile["email"], picture_url)
         if user is None:
             return logout_and_redirect_to_index()

-        unsafe_next_path = session.get("next_url") or url_for(
-            "redash.index", org_slug=org.slug
-        )
+        unsafe_next_path = session.get("next_url") or url_for("redash.index", org_slug=org.slug)
         next_path = get_next_path(unsafe_next_path)

         return redirect(next_path)


@@ -1,4 +1,5 @@
 import logging
+
 import jwt
 import requests
 import simplejson
@@ -21,9 +22,7 @@ def get_public_keys(url):
         if "keys" in data:
             public_keys = []
             for key_dict in data["keys"]:
-                public_key = jwt.algorithms.RSAAlgorithm.from_jwk(
-                    simplejson.dumps(key_dict)
-                )
+                public_key = jwt.algorithms.RSAAlgorithm.from_jwk(simplejson.dumps(key_dict))
                 public_keys.append(public_key)

             get_public_keys.key_cache[url] = public_keys
@@ -36,9 +35,7 @@ def get_public_keys(url):
 get_public_keys.key_cache = {}


-def verify_jwt_token(
-    jwt_token, expected_issuer, expected_audience, algorithms, public_certs_url
-):
+def verify_jwt_token(jwt_token, expected_issuer, expected_audience, algorithms, public_certs_url):
     # https://developers.cloudflare.com/access/setting-up-access/validate-jwt-tokens/
     # https://cloud.google.com/iap/docs/signed-headers-howto
     # Loop through the keys since we can't pass the key set to the decoder
@@ -53,9 +50,7 @@ def verify_jwt_token(
     for key in keys:
         try:
             # decode returns the claims which has the email if you need it
-            payload = jwt.decode(
-                jwt_token, key=key, audience=expected_audience, algorithms=algorithms
-            )
+            payload = jwt.decode(jwt_token, key=key, audience=expected_audience, algorithms=algorithms)
             issuer = payload["iss"]
             if issuer != expected_issuer:
                 raise Exception("Wrong issuer: {}".format(issuer))


@@ -1,13 +1,13 @@
 import logging
 import sys

-from redash import settings
-from flask import flash, redirect, render_template, request, url_for, Blueprint
+from flask import Blueprint, flash, redirect, render_template, request, url_for
 from flask_login import current_user

+from redash import settings
+
 try:
-    from ldap3 import Server, Connection
+    from ldap3 import Connection, Server
 except ImportError:
     if settings.LDAP_LOGIN_ENABLED:
         sys.exit(
@@ -16,8 +16,8 @@ except ImportError:
 from redash.authentication import (
     create_and_login_user,
-    logout_and_redirect_to_index,
     get_next_path,
+    logout_and_redirect_to_index,
 )
 from redash.authentication.org_resolving import current_org
 from redash.handlers.base import org_scoped_rule


@@ -1,13 +1,15 @@
 import logging

-from flask import redirect, url_for, Blueprint, request
+from flask import Blueprint, redirect, request, url_for
+
+from redash import settings
 from redash.authentication import (
     create_and_login_user,
-    logout_and_redirect_to_index,
     get_next_path,
+    logout_and_redirect_to_index,
 )
 from redash.authentication.org_resolving import current_org
 from redash.handlers.base import org_scoped_rule
-from redash import settings

 logger = logging.getLogger("remote_user_auth")
@@ -20,9 +22,7 @@ def login(org_slug=None):
     next_path = get_next_path(unsafe_next_path)
     if not settings.REMOTE_USER_LOGIN_ENABLED:
-        logger.error(
-            "Cannot use remote user for login without being enabled in settings"
-        )
+        logger.error("Cannot use remote user for login without being enabled in settings")
         return redirect(url_for("redash.index", next=next_path, org_slug=org_slug))

     email = request.headers.get(settings.REMOTE_USER_HEADER)


@@ -1,16 +1,20 @@
 import logging

-from flask import flash, redirect, url_for, Blueprint, request
-from redash import settings
-from redash.authentication import create_and_login_user, logout_and_redirect_to_index
-from redash.authentication.org_resolving import current_org
-from redash.handlers.base import org_scoped_rule
-from redash.utils import mustache_render
+from flask import Blueprint, flash, redirect, request, url_for
 from saml2 import BINDING_HTTP_POST, BINDING_HTTP_REDIRECT, entity
 from saml2.client import Saml2Client
 from saml2.config import Config as Saml2Config
 from saml2.saml import NAMEID_FORMAT_TRANSIENT
 from saml2.sigver import get_xmlsec_binary

+from redash import settings
+from redash.authentication import (
+    create_and_login_user,
+    logout_and_redirect_to_index,
+)
+from redash.authentication.org_resolving import current_org
+from redash.handlers.base import org_scoped_rule
+from redash.utils import mustache_render

 logger = logging.getLogger("saml_auth")
 blueprint = Blueprint("saml_auth", __name__)
@@ -91,6 +95,7 @@ def get_saml_client(org):
     if sp_settings:
         import json

         saml_settings["service"]["sp"].update(json.loads(sp_settings))
+
     sp_config = Saml2Config()


@@ -34,11 +34,7 @@ def list_command(organization=None):
         if i > 0:
             print("-" * 20)

-        print(
-            "Id: {}\nName: {}\nType: {}\nOptions: {}".format(
-                ds.id, ds.name, ds.type, ds.options.to_json()
-            )
-        )
+        print("Id: {}\nName: {}\nType: {}\nOptions: {}".format(ds.id, ds.name, ds.type, ds.options.to_json()))


 @manager.command(name="list_types")
@@ -76,9 +72,7 @@ def test(name, organization="default"):
     data_source = models.DataSource.query.filter(
         models.DataSource.name == name, models.DataSource.org == org
     ).one()
-    print(
-        "Testing connection to data source: {} (id={})".format(name, data_source.id)
-    )
+    print("Testing connection to data source: {} (id={})".format(name, data_source.id))
     try:
         data_source.query_runner.test_connection()
     except Exception as e:
@@ -165,11 +159,7 @@ def new(name=None, type=None, options=None, organization="default"):
         print("Error: invalid configuration.")
         exit(1)

-    print(
-        "Creating {} data source ({}) with options:\n{}".format(
-            type, name, options.to_json()
-        )
-    )
+    print("Creating {} data source ({}) with options:\n{}".format(type, name, options.to_json()))

     data_source = models.DataSource.create_with_group(
         name=name,


@@ -1,9 +1,9 @@
 import time

+import sqlalchemy
 from click import argument, option
 from flask.cli import AppGroup
 from flask_migrate import stamp
-import sqlalchemy
 from sqlalchemy.exc import DatabaseError
 from sqlalchemy.sql import select
 from sqlalchemy_utils.types.encrypted.encrypted_type import FernetEngine
@@ -93,9 +93,7 @@ def reencrypt(old_secret, new_secret, show_sql):
         Column("id", key_type(orm_name), primary_key=True),
         Column(
             "encrypted_options",
-            ConfigurationContainer.as_mutable(
-                EncryptedConfiguration(db.Text, old_secret, FernetEngine)
-            ),
+            ConfigurationContainer.as_mutable(EncryptedConfiguration(db.Text, old_secret, FernetEngine)),
         ),
     )
     table_for_update = sqlalchemy.Table(
@@ -104,9 +102,7 @@ def reencrypt(old_secret, new_secret, show_sql):
         Column("id", key_type(orm_name), primary_key=True),
         Column(
             "encrypted_options",
-            ConfigurationContainer.as_mutable(
-                EncryptedConfiguration(db.Text, new_secret, FernetEngine)
-            ),
+            ConfigurationContainer.as_mutable(EncryptedConfiguration(db.Text, new_secret, FernetEngine)),
         ),
     )


@@ -64,10 +64,7 @@ def change_permissions(group_id, permissions=None):
         exit(1)

     permissions = extract_permissions_string(permissions)
-    print(
-        "current permissions [%s] will be modify to [%s]"
-        % (",".join(group.permissions), ",".join(permissions))
-    )
+    print("current permissions [%s] will be modify to [%s]" % (",".join(group.permissions), ",".join(permissions)))

     group.permissions = permissions


@@ -17,21 +17,13 @@ def set_google_apps_domains(domains):
     organization.settings[k] = domains.split(",")
     models.db.session.add(organization)
     models.db.session.commit()
-    print(
-        "Updated list of allowed domains to: {}".format(
-            organization.google_apps_domains
-        )
-    )
+    print("Updated list of allowed domains to: {}".format(organization.google_apps_domains))


 @manager.command(name="show_google_apps_domains")
 def show_google_apps_domains():
     organization = models.Organization.query.first()
-    print(
-        "Current list of Google Apps domains: {}".format(
-            ", ".join(organization.google_apps_domains)
-        )
-    )
+    print("Current list of Google Apps domains: {}".format(", ".join(organization.google_apps_domains)))


 @manager.command(name="list")


@@ -1,7 +1,5 @@
-from __future__ import absolute_import
-import socket
-import sys
 import datetime
+import socket
 from itertools import chain

 from click import argument
@@ -14,11 +12,11 @@ from supervisor_checks.check_modules import base

 from redash import rq_redis_connection
 from redash.tasks import (
-    Worker,
+    periodic_job_definitions,
     rq_scheduler,
     schedule_periodic_jobs,
-    periodic_job_definitions,
 )
+from redash.tasks.worker import Worker
 from redash.worker import default_queues

 manager = AppGroup(help="RQ management commands.")
@@ -55,11 +53,7 @@ class WorkerHealthcheck(base.BaseCheck):
     def __call__(self, process_spec):
         pid = process_spec["pid"]
         all_workers = Worker.all(connection=rq_redis_connection)
-        workers = [
-            w
-            for w in all_workers
-            if w.hostname == socket.gethostname() and w.pid == pid
-        ]
+        workers = [w for w in all_workers if w.hostname == socket.gethostname() and w.pid == pid]

         if not workers:
             self._log(f"Cannot find worker for hostname {socket.gethostname()} and pid {pid}. ==> Is healthy? False")
@@ -96,6 +90,4 @@ class WorkerHealthcheck(base.BaseCheck):

 @manager.command()
 def healthcheck():
-    return check_runner.CheckRunner(
-        "worker_healthcheck", "worker", None, [(WorkerHealthcheck, {})]
-    ).run()
+    return check_runner.CheckRunner("worker_healthcheck", "worker", None, [(WorkerHealthcheck, {})]).run()


@@ -136,17 +136,13 @@ def create(
     "--password",
     "password",
     default=None,
-    help="Password for root user who don't use Google Auth "
-    "(leave blank for prompt).",
+    help="Password for root user who don't use Google Auth (leave blank for prompt).",
 )
 def create_root(email, name, google_auth=False, password=None, organization="default"):
     """
     Create root user.
     """
-    print(
-        "Creating root user (%s, %s) in organization %s..."
-        % (email, name, organization)
-    )
+    print("Creating root user (%s, %s) in organization %s..." % (email, name, organization))
     print("Login with Google Auth: %r\n" % google_auth)

     user = models.User.query.filter(models.User.email == email).first()
@@ -206,13 +202,9 @@ def delete(email, organization=None):
     """
     if organization:
         org = models.Organization.get_by_slug(organization)
-        deleted_count = models.User.query.filter(
-            models.User.email == email, models.User.org == org.id
-        ).delete()
+        deleted_count = models.User.query.filter(models.User.email == email, models.User.org == org.id).delete()
     else:
-        deleted_count = models.User.query.filter(models.User.email == email).delete(
-            synchronize_session=False
-        )
+        deleted_count = models.User.query.filter(models.User.email == email).delete(synchronize_session=False)
     models.db.session.commit()
     print("Deleted %d users." % deleted_count)
@@ -232,9 +224,7 @@ def password(email, password, organization=None):
     """
     if organization:
         org = models.Organization.get_by_slug(organization)
-        user = models.User.query.filter(
-            models.User.email == email, models.User.org == org
-        ).first()
+        user = models.User.query.filter(models.User.email == email, models.User.org == org).first()
     else:
         user = models.User.query.filter(models.User.email == email).first()


@@ -41,7 +41,7 @@ class BaseDestination(object):
             "type": cls.type(),
             "icon": cls.icon(),
             "configuration_schema": cls.configuration_schema(),
-            **({ "deprecated": True } if cls.deprecated else {})
+            **({"deprecated": True} if cls.deprecated else {}),
         }


@@ -1,13 +1,12 @@
 import logging

 import requests

-from redash.destinations import *
+from redash.destinations import BaseDestination, register


 class ChatWork(BaseDestination):
-    ALERTS_DEFAULT_MESSAGE_TEMPLATE = (
-        "{alert_name} changed state to {new_state}.\\n{alert_url}\\n{query_url}"
-    )
+    ALERTS_DEFAULT_MESSAGE_TEMPLATE = "{alert_name} changed state to {new_state}.\\n{alert_url}\\n{query_url}"

     @classmethod
     def configuration_schema(cls):
@@ -33,9 +32,7 @@ class ChatWork(BaseDestination):
     def notify(self, alert, query, user, new_state, app, host, options):
         try:
             # Documentation: http://developer.chatwork.com/ja/endpoint_rooms.html#POST-rooms-room_id-messages
-            url = "https://api.chatwork.com/v2/rooms/{room_id}/messages".format(
-                room_id=options.get("room_id")
-            )
+            url = "https://api.chatwork.com/v2/rooms/{room_id}/messages".format(room_id=options.get("room_id"))

             message = ""
             if alert.custom_subject:
@@ -43,15 +40,9 @@ class ChatWork(BaseDestination):
             if alert.custom_body:
                 message += alert.custom_body
             else:
-                alert_url = "{host}/alerts/{alert_id}".format(
-                    host=host, alert_id=alert.id
-                )
-                query_url = "{host}/queries/{query_id}".format(
-                    host=host, query_id=query.id
-                )
-                message_template = options.get(
-                    "message_template", ChatWork.ALERTS_DEFAULT_MESSAGE_TEMPLATE
-                )
+                alert_url = "{host}/alerts/{alert_id}".format(host=host, alert_id=alert.id)
+                query_url = "{host}/queries/{query_id}".format(host=host, query_id=query.id)
+                message_template = options.get("message_template", ChatWork.ALERTS_DEFAULT_MESSAGE_TEMPLATE)
                 message += message_template.replace("\\n", "\n").format(
                     alert_name=alert.name,
                     new_state=new_state.upper(),
@@ -65,11 +56,7 @@ class ChatWork(BaseDestination):
             resp = requests.post(url, headers=headers, data=payload, timeout=5.0)
             logging.warning(resp.text)
             if resp.status_code != 200:
-                logging.error(
-                    "ChatWork send ERROR. status_code => {status}".format(
-                        status=resp.status_code
-                    )
-                )
+                logging.error("ChatWork send ERROR. status_code => {status}".format(status=resp.status_code))
         except Exception:
             logging.exception("ChatWork send ERROR.")


@@ -1,8 +1,9 @@
 import logging
+
 from flask_mail import Message

 from redash import mail, settings
-from redash.destinations import *
+from redash.destinations import BaseDestination, register


 class Email(BaseDestination):
@@ -27,9 +28,7 @@ class Email(BaseDestination):
         return "fa-envelope"

     def notify(self, alert, query, user, new_state, app, host, options):
-        recipients = [
-            email for email in options.get("addresses", "").split(",") if email
-        ]
+        recipients = [email for email in options.get("addresses", "").split(",") if email]

         if not recipients:
             logging.warning("No emails given. Skipping send.")
@@ -50,9 +49,7 @@ class Email(BaseDestination):
         if alert.custom_subject:
             subject = alert.custom_subject
         else:
-            subject_template = options.get(
-                "subject_template", settings.ALERTS_DEFAULT_MAIL_SUBJECT_TEMPLATE
-            )
+            subject_template = options.get("subject_template", settings.ALERTS_DEFAULT_MAIL_SUBJECT_TEMPLATE)
             subject = subject_template.format(alert_name=alert.name, state=state)

         message = Message(recipients=recipients, subject=subject, html=html)


@@ -1,7 +1,8 @@
 import logging
+
 import requests

-from redash.destinations import *
+from redash.destinations import BaseDestination, register
 from redash.utils import json_dumps
@@ -43,9 +44,7 @@ class HangoutsChat(BaseDestination):
             elif new_state == "ok":
                 message = '<font color="#27ae60">Went back to normal</font>'
             else:
-                message = (
-                    "Unable to determine status. Check Query and Alert configuration."
-                )
+                message = "Unable to determine status. Check Query and Alert configuration."

             if alert.custom_subject:
                 title = alert.custom_subject
@@ -56,17 +55,13 @@ class HangoutsChat(BaseDestination):
                 "cards": [
                     {
                         "header": {"title": title},
-                        "sections": [
-                            {"widgets": [{"textParagraph": {"text": message}}]}
-                        ],
+                        "sections": [{"widgets": [{"textParagraph": {"text": message}}]}],
                     }
                 ]
             }

             if alert.custom_body:
-                data["cards"][0]["sections"].append(
-                    {"widgets": [{"textParagraph": {"text": alert.custom_body}}]}
-                )
+                data["cards"][0]["sections"].append({"widgets": [{"textParagraph": {"text": alert.custom_body}}]})

             if options.get("icon_url"):
                 data["cards"][0]["header"]["imageUrl"] = options.get("icon_url")
@@ -81,9 +76,7 @@ class HangoutsChat(BaseDestination):
                                 "text": "OPEN QUERY",
                                 "onClick": {
                                     "openLink": {
-                                        "url": "{host}/queries/{query_id}".format(
-                                            host=host, query_id=query.id
-                                        )
+                                        "url": "{host}/queries/{query_id}".format(host=host, query_id=query.id)
                                     }
                                 },
                             }
@@ -93,15 +86,9 @@ class HangoutsChat(BaseDestination):
             )

             headers = {"Content-Type": "application/json; charset=UTF-8"}
-            resp = requests.post(
-                options.get("url"), data=json_dumps(data), headers=headers, timeout=5.0
-            )
+            resp = requests.post(options.get("url"), data=json_dumps(data), headers=headers, timeout=5.0)
             if resp.status_code != 200:
-                logging.error(
-                    "webhook send ERROR. status_code => {status}".format(
-                        status=resp.status_code
-                    )
-                )
+                logging.error("webhook send ERROR. status_code => {status}".format(status=resp.status_code))
         except Exception:
             logging.exception("webhook send ERROR.")


@@ -1,10 +1,10 @@
 import logging
+
 import requests

-from redash.destinations import *
+from redash.destinations import BaseDestination, register
 from redash.models import Alert
-from redash.utils import json_dumps, deprecated
+from redash.utils import deprecated, json_dumps

 colors = {
     Alert.OK_STATE: "green",
@@ -47,14 +47,10 @@ class HipChat(BaseDestination):
             data = {"message": message, "color": colors.get(new_state, "green")}
             headers = {"Content-Type": "application/json"}
-            response = requests.post(
-                options["url"], data=json_dumps(data), headers=headers, timeout=5.0
-            )
+            response = requests.post(options["url"], data=json_dumps(data), headers=headers, timeout=5.0)

             if response.status_code != 204:
-                logging.error(
-                    "Bad status code received from HipChat: %d", response.status_code
-                )
+                logging.error("Bad status code received from HipChat: %d", response.status_code)
         except Exception:
             logging.exception("HipChat Send ERROR.")


@@ -1,7 +1,8 @@
 import logging
+
 import requests

-from redash.destinations import *
+from redash.destinations import BaseDestination, register
 from redash.utils import json_dumps
@@ -16,7 +17,7 @@ class Mattermost(BaseDestination):
                 "icon_url": {"type": "string", "title": "Icon (URL)"},
                 "channel": {"type": "string", "title": "Channel"},
             },
-            "secret": "url"
+            "secret": "url",
         }

     @classmethod
@@ -33,9 +34,7 @@ class Mattermost(BaseDestination):
         payload = {"text": text}

         if alert.custom_body:
-            payload["attachments"] = [
-                {"fields": [{"title": "Description", "value": alert.custom_body}]}
-            ]
+            payload["attachments"] = [{"fields": [{"title": "Description", "value": alert.custom_body}]}]

         if options.get("username"):
             payload["username"] = options.get("username")
@@ -45,17 +44,11 @@ class Mattermost(BaseDestination):
             payload["channel"] = options.get("channel")

         try:
-            resp = requests.post(
-                options.get("url"), data=json_dumps(payload), timeout=5.0
-            )
+            resp = requests.post(options.get("url"), data=json_dumps(payload), timeout=5.0)
             logging.warning(resp.text)
             if resp.status_code != 200:
-                logging.error(
-                    "Mattermost webhook send ERROR. status_code => {status}".format(
-                        status=resp.status_code
-                    )
-                )
+                logging.error("Mattermost webhook send ERROR. status_code => {status}".format(status=resp.status_code))
         except Exception:
             logging.exception("Mattermost webhook send ERROR.")


@@ -1,10 +1,10 @@
import logging import logging
import requests
from string import Template from string import Template
from redash.destinations import * import requests
from redash.destinations import BaseDestination, register
from redash.utils import json_dumps from redash.utils import json_dumps
from redash.serializers import serialize_alert
def json_string_substitute(j, substitutions): def json_string_substitute(j, substitutions):
@@ -26,30 +26,26 @@ def json_string_substitute(j, substitutions):
class MicrosoftTeamsWebhook(BaseDestination): class MicrosoftTeamsWebhook(BaseDestination):
ALERTS_DEFAULT_MESSAGE_TEMPLATE = json_dumps({
    "@type": "MessageCard",
    "@context": "http://schema.org/extensions",
    "themeColor": "0076D7",
    "summary": "A Redash Alert was Triggered",
    "sections": [{
        "activityTitle": "A Redash Alert was Triggered",
        "facts": [{
            "name": "Alert Name",
            "value": "{alert_name}"
        }, {
            "name": "Alert URL",
            "value": "{alert_url}"
        }, {
            "name": "Query",
            "value": "{query_text}"
        }, {
            "name": "Query URL",
            "value": "{query_url}"
        }],
        "markdown": True
    }]
})
ALERTS_DEFAULT_MESSAGE_TEMPLATE = json_dumps(
    {
        "@type": "MessageCard",
        "@context": "http://schema.org/extensions",
        "themeColor": "0076D7",
        "summary": "A Redash Alert was Triggered",
        "sections": [
            {
                "activityTitle": "A Redash Alert was Triggered",
                "facts": [
                    {"name": "Alert Name", "value": "{alert_name}"},
                    {"name": "Alert URL", "value": "{alert_url}"},
                    {"name": "Query", "value": "{query_text}"},
                    {"name": "Query URL", "value": "{query_url}"},
                ],
                "markdown": True,
            }
        ],
    }
)
@classmethod @classmethod
def name(cls): def name(cls):
@@ -64,10 +60,7 @@ class MicrosoftTeamsWebhook(BaseDestination):
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
"url": { "url": {"type": "string", "title": "Microsoft Teams Webhook URL"},
"type": "string",
"title": "Microsoft Teams Webhook URL"
},
"message_template": { "message_template": {
"type": "string", "type": "string",
"default": MicrosoftTeamsWebhook.ALERTS_DEFAULT_MESSAGE_TEMPLATE, "default": MicrosoftTeamsWebhook.ALERTS_DEFAULT_MESSAGE_TEMPLATE,
@@ -86,26 +79,23 @@ class MicrosoftTeamsWebhook(BaseDestination):
:type app: redash.Redash :type app: redash.Redash
""" """
try: try:
alert_url = "{host}/alerts/{alert_id}".format( alert_url = "{host}/alerts/{alert_id}".format(host=host, alert_id=alert.id)
host=host, alert_id=alert.id
)
query_url = "{host}/queries/{query_id}".format( query_url = "{host}/queries/{query_id}".format(host=host, query_id=query.id)
host=host, query_id=query.id
)
message_template = options.get( message_template = options.get("message_template", MicrosoftTeamsWebhook.ALERTS_DEFAULT_MESSAGE_TEMPLATE)
"message_template", MicrosoftTeamsWebhook.ALERTS_DEFAULT_MESSAGE_TEMPLATE
)
# Doing a string Template substitution here because the template contains braces, which # Doing a string Template substitution here because the template contains braces, which
# result in KeyErrors when attempting string.format # result in KeyErrors when attempting string.format
payload = json_string_substitute(message_template, {
    "alert_name": alert.name,
    "alert_url": alert_url,
    "query_text": query.query_text,
    "query_url": query_url
})
payload = json_string_substitute(
    message_template,
    {
        "alert_name": alert.name,
        "alert_url": alert_url,
        "query_text": query.query_text,
        "query_url": query_url,
    },
)
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
@@ -116,11 +106,7 @@ class MicrosoftTeamsWebhook(BaseDestination):
timeout=5.0, timeout=5.0,
) )
if resp.status_code != 200: if resp.status_code != 200:
logging.error( logging.error("MS Teams Webhook send ERROR. status_code => {status}".format(status=resp.status_code))
"MS Teams Webhook send ERROR. status_code => {status}".format(
status=resp.status_code
)
)
except Exception: except Exception:
logging.exception("MS Teams Webhook send ERROR.") logging.exception("MS Teams Webhook send ERROR.")

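A note on the comment in the MicrosoftTeamsWebhook hunk above: str.format cannot be applied to a JSON template because every literal "{" is parsed as a replacement field. A minimal standalone sketch of the pitfall and the string.Template workaround (illustrative strings only, not Redash's actual json_string_substitute helper):

    from string import Template

    template = '{"summary": "Alert", "name": "{alert_name}"}'
    try:
        # str.format parses the opening JSON brace as a replacement field
        template.format(alert_name="CPU high")
    except KeyError as exc:
        print("str.format fails on the JSON braces:", exc)

    # string.Template only reacts to $-prefixed names, so literal braces are inert
    safe = Template('{"summary": "Alert", "name": "$alert_name"}')
    print(safe.substitute(alert_name="CPU high"))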
View File

@@ -1,5 +1,6 @@
import logging import logging
from redash.destinations import *
from redash.destinations import BaseDestination, register
enabled = True enabled = True
@@ -10,7 +11,6 @@ except ImportError:
class PagerDuty(BaseDestination): class PagerDuty(BaseDestination):
KEY_STRING = "{alert_id}_{query_id}" KEY_STRING = "{alert_id}_{query_id}"
DESCRIPTION_STR = "Alert: {alert_name}" DESCRIPTION_STR = "Alert: {alert_name}"
@@ -41,7 +41,6 @@ class PagerDuty(BaseDestination):
return "creative-commons-pd-alt" return "creative-commons-pd-alt"
def notify(self, alert, query, user, new_state, app, host, options): def notify(self, alert, query, user, new_state, app, host, options):
if alert.custom_subject: if alert.custom_subject:
default_desc = alert.custom_subject default_desc = alert.custom_subject
elif options.get("description"): elif options.get("description"):
@@ -73,7 +72,6 @@ class PagerDuty(BaseDestination):
data["event_action"] = "resolve" data["event_action"] = "resolve"
try: try:
ev = pypd.EventV2.create(data=data) ev = pypd.EventV2.create(data=data)
logging.warning(ev) logging.warning(ev)

View File

@@ -1,7 +1,8 @@
import logging import logging
import requests import requests
from redash.destinations import * from redash.destinations import BaseDestination, register
from redash.utils import json_dumps from redash.utils import json_dumps
@@ -25,16 +26,12 @@ class Slack(BaseDestination):
fields = [ fields = [
{ {
"title": "Query", "title": "Query",
"value": "{host}/queries/{query_id}".format( "value": "{host}/queries/{query_id}".format(host=host, query_id=query.id),
host=host, query_id=query.id
),
"short": True, "short": True,
}, },
{ {
"title": "Alert", "title": "Alert",
"value": "{host}/alerts/{alert_id}".format( "value": "{host}/alerts/{alert_id}".format(host=host, alert_id=alert.id),
host=host, alert_id=alert.id
),
"short": True, "short": True,
}, },
] ]
@@ -53,16 +50,10 @@ class Slack(BaseDestination):
payload = {"attachments": [{"text": text, "color": color, "fields": fields}]} payload = {"attachments": [{"text": text, "color": color, "fields": fields}]}
try: try:
resp = requests.post( resp = requests.post(options.get("url"), data=json_dumps(payload), timeout=5.0)
options.get("url"), data=json_dumps(payload), timeout=5.0
)
logging.warning(resp.text) logging.warning(resp.text)
if resp.status_code != 200: if resp.status_code != 200:
logging.error( logging.error("Slack send ERROR. status_code => {status}".format(status=resp.status_code))
"Slack send ERROR. status_code => {status}".format(
status=resp.status_code
)
)
except Exception: except Exception:
logging.exception("Slack send ERROR.") logging.exception("Slack send ERROR.")

View File

@@ -1,10 +1,11 @@
import logging import logging
import requests import requests
from requests.auth import HTTPBasicAuth from requests.auth import HTTPBasicAuth
from redash.destinations import * from redash.destinations import BaseDestination, register
from redash.utils import json_dumps
from redash.serializers import serialize_alert from redash.serializers import serialize_alert
from redash.utils import json_dumps
class Webhook(BaseDestination): class Webhook(BaseDestination):
@@ -37,11 +38,7 @@ class Webhook(BaseDestination):
data["alert"]["title"] = alert.custom_subject data["alert"]["title"] = alert.custom_subject
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
auth = ( auth = HTTPBasicAuth(options.get("username"), options.get("password")) if options.get("username") else None
HTTPBasicAuth(options.get("username"), options.get("password"))
if options.get("username")
else None
)
resp = requests.post( resp = requests.post(
options.get("url"), options.get("url"),
data=json_dumps(data), data=json_dumps(data),
@@ -50,11 +47,7 @@ class Webhook(BaseDestination):
timeout=5.0, timeout=5.0,
) )
if resp.status_code != 200: if resp.status_code != 200:
logging.error( logging.error("webhook send ERROR. status_code => {status}".format(status=resp.status_code))
"webhook send ERROR. status_code => {status}".format(
status=resp.status_code
)
)
except Exception: except Exception:
logging.exception("webhook send ERROR.") logging.exception("webhook send ERROR.")

View File

@@ -24,13 +24,13 @@ def status_api():
def init_app(app): def init_app(app):
from redash.handlers import ( from redash.handlers import (
embed,
queries,
static,
authentication,
admin, admin,
setup, authentication,
embed,
organization, organization,
queries,
setup,
static,
) )
app.register_blueprint(routes) app.register_blueprint(routes)

View File

@@ -1,14 +1,13 @@
from flask import request from flask_login import current_user, login_required
from flask_login import login_required, current_user
from redash import models, redis_connection from redash import models, redis_connection
from redash.authentication import current_org from redash.authentication import current_org
from redash.handlers import routes from redash.handlers import routes
from redash.handlers.base import json_response, record_event from redash.handlers.base import json_response, record_event
from redash.monitor import rq_status
from redash.permissions import require_super_admin from redash.permissions import require_super_admin
from redash.serializers import QuerySerializer from redash.serializers import QuerySerializer
from redash.utils import json_loads from redash.utils import json_loads
from redash.monitor import rq_status
@routes.route("/api/admin/queries/outdated", methods=["GET"]) @routes.route("/api/admin/queries/outdated", methods=["GET"])
@@ -29,13 +28,14 @@ def outdated_queries():
record_event( record_event(
current_org, current_org,
current_user._get_current_object(), current_user._get_current_object(),
{"action": "list", "object_type": "outdated_queries",}, {
"action": "list",
"object_type": "outdated_queries",
},
) )
response = { response = {
"queries": QuerySerializer( "queries": QuerySerializer(outdated_queries, with_stats=True, with_last_modified_by=False).serialize(),
outdated_queries, with_stats=True, with_last_modified_by=False
).serialize(),
"updated_at": manager_status["last_refresh_at"], "updated_at": manager_status["last_refresh_at"],
} }
return json_response(response) return json_response(response)

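The import churn in hunks like the one above is isort at work: imports are sorted alphabetically inside three blocks, standard library, third-party, then first-party, with a blank line between blocks (black, meanwhile, keeps any call or literal that carries a trailing comma exploded onto multiple lines, which is why the record_event dict in this file went multi-line). A hypothetical module header showing the layout isort produces; the names are illustrative only:

    # Standard-library imports come first, alphabetized
    import logging
    import time

    # Third-party packages form the second block
    import requests
    from flask import request

    # First-party (project) imports close the header
    from redash import models
    from redash.handlers.base import BaseResource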
View File

@@ -1,52 +1,43 @@
import time
from flask import request from flask import request
from funcy import project from funcy import project
from redash import models from redash import models
from redash.serializers import serialize_alert from redash.handlers.base import (
from redash.handlers.base import BaseResource, get_object_or_404, require_fields BaseResource,
get_object_or_404,
require_fields,
)
from redash.permissions import ( from redash.permissions import (
require_access, require_access,
require_admin_or_owner, require_admin_or_owner,
require_permission, require_permission,
view_only, view_only,
) )
from redash.utils import json_dumps from redash.serializers import serialize_alert
class AlertResource(BaseResource): class AlertResource(BaseResource):
def get(self, alert_id): def get(self, alert_id):
alert = get_object_or_404( alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, self.current_org)
models.Alert.get_by_id_and_org, alert_id, self.current_org
)
require_access(alert, self.current_user, view_only) require_access(alert, self.current_user, view_only)
self.record_event( self.record_event({"action": "view", "object_id": alert.id, "object_type": "alert"})
{"action": "view", "object_id": alert.id, "object_type": "alert"}
)
return serialize_alert(alert) return serialize_alert(alert)
def post(self, alert_id): def post(self, alert_id):
req = request.get_json(True) req = request.get_json(True)
params = project(req, ("options", "name", "query_id", "rearm")) params = project(req, ("options", "name", "query_id", "rearm"))
alert = get_object_or_404( alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, self.current_org)
models.Alert.get_by_id_and_org, alert_id, self.current_org
)
require_admin_or_owner(alert.user.id) require_admin_or_owner(alert.user.id)
self.update_model(alert, params) self.update_model(alert, params)
models.db.session.commit() models.db.session.commit()
self.record_event( self.record_event({"action": "edit", "object_id": alert.id, "object_type": "alert"})
{"action": "edit", "object_id": alert.id, "object_type": "alert"}
)
return serialize_alert(alert) return serialize_alert(alert)
def delete(self, alert_id): def delete(self, alert_id):
alert = get_object_or_404( alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, self.current_org)
models.Alert.get_by_id_and_org, alert_id, self.current_org
)
require_admin_or_owner(alert.user_id) require_admin_or_owner(alert.user_id)
models.db.session.delete(alert) models.db.session.delete(alert)
models.db.session.commit() models.db.session.commit()
@@ -54,30 +45,22 @@ class AlertResource(BaseResource):
class AlertMuteResource(BaseResource): class AlertMuteResource(BaseResource):
def post(self, alert_id): def post(self, alert_id):
alert = get_object_or_404( alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, self.current_org)
models.Alert.get_by_id_and_org, alert_id, self.current_org
)
require_admin_or_owner(alert.user.id) require_admin_or_owner(alert.user.id)
alert.options["muted"] = True alert.options["muted"] = True
models.db.session.commit() models.db.session.commit()
self.record_event( self.record_event({"action": "mute", "object_id": alert.id, "object_type": "alert"})
{"action": "mute", "object_id": alert.id, "object_type": "alert"}
)
def delete(self, alert_id): def delete(self, alert_id):
alert = get_object_or_404( alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, self.current_org)
models.Alert.get_by_id_and_org, alert_id, self.current_org
)
require_admin_or_owner(alert.user.id) require_admin_or_owner(alert.user.id)
alert.options["muted"] = False alert.options["muted"] = False
models.db.session.commit() models.db.session.commit()
self.record_event( self.record_event({"action": "unmute", "object_id": alert.id, "object_type": "alert"})
{"action": "unmute", "object_id": alert.id, "object_type": "alert"}
)
class AlertListResource(BaseResource): class AlertListResource(BaseResource):
@@ -100,19 +83,14 @@ class AlertListResource(BaseResource):
models.db.session.flush() models.db.session.flush()
models.db.session.commit() models.db.session.commit()
self.record_event( self.record_event({"action": "create", "object_id": alert.id, "object_type": "alert"})
{"action": "create", "object_id": alert.id, "object_type": "alert"}
)
return serialize_alert(alert) return serialize_alert(alert)
@require_permission("list_alerts") @require_permission("list_alerts")
def get(self): def get(self):
self.record_event({"action": "list", "object_type": "alert"}) self.record_event({"action": "list", "object_type": "alert"})
return [ return [serialize_alert(alert) for alert in models.Alert.all(group_ids=self.current_user.group_ids)]
serialize_alert(alert)
for alert in models.Alert.all(group_ids=self.current_user.group_ids)
]
class AlertSubscriptionListResource(BaseResource): class AlertSubscriptionListResource(BaseResource):
@@ -124,9 +102,7 @@ class AlertSubscriptionListResource(BaseResource):
kwargs = {"alert": alert, "user": self.current_user} kwargs = {"alert": alert, "user": self.current_user}
if "destination_id" in req: if "destination_id" in req:
destination = models.NotificationDestination.get_by_id_and_org( destination = models.NotificationDestination.get_by_id_and_org(req["destination_id"], self.current_org)
req["destination_id"], self.current_org
)
kwargs["destination"] = destination kwargs["destination"] = destination
subscription = models.AlertSubscription(**kwargs) subscription = models.AlertSubscription(**kwargs)
@@ -160,6 +136,4 @@ class AlertSubscriptionResource(BaseResource):
models.db.session.delete(subscription) models.db.session.delete(subscription)
models.db.session.commit() models.db.session.commit()
self.record_event( self.record_event({"action": "unsubscribe", "object_id": alert_id, "object_type": "alert"})
{"action": "unsubscribe", "object_id": alert_id, "object_type": "alert"}
)

View File

@@ -4,19 +4,19 @@ from werkzeug.wrappers import Response
from redash.handlers.alerts import ( from redash.handlers.alerts import (
AlertListResource, AlertListResource,
AlertResource,
AlertMuteResource, AlertMuteResource,
AlertResource,
AlertSubscriptionListResource, AlertSubscriptionListResource,
AlertSubscriptionResource, AlertSubscriptionResource,
) )
from redash.handlers.base import org_scoped_rule from redash.handlers.base import org_scoped_rule
from redash.handlers.dashboards import ( from redash.handlers.dashboards import (
MyDashboardsResource,
DashboardFavoriteListResource, DashboardFavoriteListResource,
DashboardListResource, DashboardListResource,
DashboardResource, DashboardResource,
DashboardShareResource, DashboardShareResource,
DashboardTagsResource, DashboardTagsResource,
MyDashboardsResource,
PublicDashboardResource, PublicDashboardResource,
) )
from redash.handlers.data_sources import ( from redash.handlers.data_sources import (
@@ -38,7 +38,10 @@ from redash.handlers.destinations import (
DestinationTypeListResource, DestinationTypeListResource,
) )
from redash.handlers.events import EventsResource from redash.handlers.events import EventsResource
from redash.handlers.favorites import DashboardFavoriteResource, QueryFavoriteResource from redash.handlers.favorites import (
DashboardFavoriteResource,
QueryFavoriteResource,
)
from redash.handlers.groups import ( from redash.handlers.groups import (
GroupDataSourceListResource, GroupDataSourceListResource,
GroupDataSourceResource, GroupDataSourceResource,
@@ -59,15 +62,15 @@ from redash.handlers.queries import (
QueryListResource, QueryListResource,
QueryRecentResource, QueryRecentResource,
QueryRefreshResource, QueryRefreshResource,
QueryRegenerateApiKeyResource,
QueryResource, QueryResource,
QuerySearchResource, QuerySearchResource,
QueryTagsResource, QueryTagsResource,
QueryRegenerateApiKeyResource,
) )
from redash.handlers.query_results import ( from redash.handlers.query_results import (
JobResource, JobResource,
QueryResultDropdownResource,
QueryDropdownsResource, QueryDropdownsResource,
QueryResultDropdownResource,
QueryResultListResource, QueryResultListResource,
QueryResultResource, QueryResultResource,
) )
@@ -112,9 +115,7 @@ def json_representation(data, code, headers=None):
api.add_org_resource(AlertResource, "/api/alerts/<alert_id>", endpoint="alert") api.add_org_resource(AlertResource, "/api/alerts/<alert_id>", endpoint="alert")
api.add_org_resource( api.add_org_resource(AlertMuteResource, "/api/alerts/<alert_id>/mute", endpoint="alert_mute")
AlertMuteResource, "/api/alerts/<alert_id>/mute", endpoint="alert_mute"
)
api.add_org_resource( api.add_org_resource(
AlertSubscriptionListResource, AlertSubscriptionListResource,
"/api/alerts/<alert_id>/subscriptions", "/api/alerts/<alert_id>/subscriptions",
@@ -128,9 +129,7 @@ api.add_org_resource(
api.add_org_resource(AlertListResource, "/api/alerts", endpoint="alerts") api.add_org_resource(AlertListResource, "/api/alerts", endpoint="alerts")
api.add_org_resource(DashboardListResource, "/api/dashboards", endpoint="dashboards") api.add_org_resource(DashboardListResource, "/api/dashboards", endpoint="dashboards")
api.add_org_resource( api.add_org_resource(DashboardResource, "/api/dashboards/<dashboard_id>", endpoint="dashboard")
DashboardResource, "/api/dashboards/<dashboard_id>", endpoint="dashboard"
)
api.add_org_resource( api.add_org_resource(
PublicDashboardResource, PublicDashboardResource,
"/api/dashboards/public/<token>", "/api/dashboards/public/<token>",
@@ -142,18 +141,10 @@ api.add_org_resource(
endpoint="dashboard_share", endpoint="dashboard_share",
) )
api.add_org_resource( api.add_org_resource(DataSourceTypeListResource, "/api/data_sources/types", endpoint="data_source_types")
DataSourceTypeListResource, "/api/data_sources/types", endpoint="data_source_types" api.add_org_resource(DataSourceListResource, "/api/data_sources", endpoint="data_sources")
) api.add_org_resource(DataSourceSchemaResource, "/api/data_sources/<data_source_id>/schema")
api.add_org_resource( api.add_org_resource(DatabricksDatabaseListResource, "/api/databricks/databases/<data_source_id>")
DataSourceListResource, "/api/data_sources", endpoint="data_sources"
)
api.add_org_resource(
DataSourceSchemaResource, "/api/data_sources/<data_source_id>/schema"
)
api.add_org_resource(
DatabricksDatabaseListResource, "/api/databricks/databases/<data_source_id>"
)
api.add_org_resource( api.add_org_resource(
DatabricksSchemaResource, DatabricksSchemaResource,
"/api/databricks/databases/<data_source_id>/<database_name>/tables", "/api/databricks/databases/<data_source_id>/<database_name>/tables",
@@ -162,19 +153,13 @@ api.add_org_resource(
DatabricksTableColumnListResource, DatabricksTableColumnListResource,
"/api/databricks/databases/<data_source_id>/<database_name>/columns/<table_name>", "/api/databricks/databases/<data_source_id>/<database_name>/columns/<table_name>",
) )
api.add_org_resource( api.add_org_resource(DataSourcePauseResource, "/api/data_sources/<data_source_id>/pause")
DataSourcePauseResource, "/api/data_sources/<data_source_id>/pause"
)
api.add_org_resource(DataSourceTestResource, "/api/data_sources/<data_source_id>/test") api.add_org_resource(DataSourceTestResource, "/api/data_sources/<data_source_id>/test")
api.add_org_resource( api.add_org_resource(DataSourceResource, "/api/data_sources/<data_source_id>", endpoint="data_source")
DataSourceResource, "/api/data_sources/<data_source_id>", endpoint="data_source"
)
api.add_org_resource(GroupListResource, "/api/groups", endpoint="groups") api.add_org_resource(GroupListResource, "/api/groups", endpoint="groups")
api.add_org_resource(GroupResource, "/api/groups/<group_id>", endpoint="group") api.add_org_resource(GroupResource, "/api/groups/<group_id>", endpoint="group")
api.add_org_resource( api.add_org_resource(GroupMemberListResource, "/api/groups/<group_id>/members", endpoint="group_members")
GroupMemberListResource, "/api/groups/<group_id>/members", endpoint="group_members"
)
api.add_org_resource( api.add_org_resource(
GroupMemberResource, GroupMemberResource,
"/api/groups/<group_id>/members/<user_id>", "/api/groups/<group_id>/members/<user_id>",
@@ -193,12 +178,8 @@ api.add_org_resource(
api.add_org_resource(EventsResource, "/api/events", endpoint="events") api.add_org_resource(EventsResource, "/api/events", endpoint="events")
api.add_org_resource( api.add_org_resource(QueryFavoriteListResource, "/api/queries/favorites", endpoint="query_favorites")
QueryFavoriteListResource, "/api/queries/favorites", endpoint="query_favorites" api.add_org_resource(QueryFavoriteResource, "/api/queries/<query_id>/favorite", endpoint="query_favorite")
)
api.add_org_resource(
QueryFavoriteResource, "/api/queries/<query_id>/favorite", endpoint="query_favorite"
)
api.add_org_resource( api.add_org_resource(
DashboardFavoriteListResource, DashboardFavoriteListResource,
"/api/dashboards/favorites", "/api/dashboards/favorites",
@@ -213,28 +194,16 @@ api.add_org_resource(
api.add_org_resource(MyDashboardsResource, "/api/dashboards/my", endpoint="my_dashboards") api.add_org_resource(MyDashboardsResource, "/api/dashboards/my", endpoint="my_dashboards")
api.add_org_resource(QueryTagsResource, "/api/queries/tags", endpoint="query_tags") api.add_org_resource(QueryTagsResource, "/api/queries/tags", endpoint="query_tags")
api.add_org_resource( api.add_org_resource(DashboardTagsResource, "/api/dashboards/tags", endpoint="dashboard_tags")
DashboardTagsResource, "/api/dashboards/tags", endpoint="dashboard_tags"
)
api.add_org_resource( api.add_org_resource(QuerySearchResource, "/api/queries/search", endpoint="queries_search")
QuerySearchResource, "/api/queries/search", endpoint="queries_search" api.add_org_resource(QueryRecentResource, "/api/queries/recent", endpoint="recent_queries")
) api.add_org_resource(QueryArchiveResource, "/api/queries/archive", endpoint="queries_archive")
api.add_org_resource(
QueryRecentResource, "/api/queries/recent", endpoint="recent_queries"
)
api.add_org_resource(
QueryArchiveResource, "/api/queries/archive", endpoint="queries_archive"
)
api.add_org_resource(QueryListResource, "/api/queries", endpoint="queries") api.add_org_resource(QueryListResource, "/api/queries", endpoint="queries")
api.add_org_resource(MyQueriesResource, "/api/queries/my", endpoint="my_queries") api.add_org_resource(MyQueriesResource, "/api/queries/my", endpoint="my_queries")
api.add_org_resource( api.add_org_resource(QueryRefreshResource, "/api/queries/<query_id>/refresh", endpoint="query_refresh")
QueryRefreshResource, "/api/queries/<query_id>/refresh", endpoint="query_refresh"
)
api.add_org_resource(QueryResource, "/api/queries/<query_id>", endpoint="query") api.add_org_resource(QueryResource, "/api/queries/<query_id>", endpoint="query")
api.add_org_resource( api.add_org_resource(QueryForkResource, "/api/queries/<query_id>/fork", endpoint="query_fork")
QueryForkResource, "/api/queries/<query_id>/fork", endpoint="query_fork"
)
api.add_org_resource( api.add_org_resource(
QueryRegenerateApiKeyResource, QueryRegenerateApiKeyResource,
"/api/queries/<query_id>/regenerate_api_key", "/api/queries/<query_id>/regenerate_api_key",
@@ -252,9 +221,7 @@ api.add_org_resource(
endpoint="check_permissions", endpoint="check_permissions",
) )
api.add_org_resource( api.add_org_resource(QueryResultListResource, "/api/query_results", endpoint="query_results")
QueryResultListResource, "/api/query_results", endpoint="query_results"
)
api.add_org_resource( api.add_org_resource(
QueryResultDropdownResource, QueryResultDropdownResource,
"/api/queries/<query_id>/dropdown", "/api/queries/<query_id>/dropdown",
@@ -283,9 +250,7 @@ api.add_org_resource(
api.add_org_resource(UserListResource, "/api/users", endpoint="users") api.add_org_resource(UserListResource, "/api/users", endpoint="users")
api.add_org_resource(UserResource, "/api/users/<user_id>", endpoint="user") api.add_org_resource(UserResource, "/api/users/<user_id>", endpoint="user")
api.add_org_resource( api.add_org_resource(UserInviteResource, "/api/users/<user_id>/invite", endpoint="user_invite")
UserInviteResource, "/api/users/<user_id>/invite", endpoint="user_invite"
)
api.add_org_resource( api.add_org_resource(
UserResetPasswordResource, UserResetPasswordResource,
"/api/users/<user_id>/reset_password", "/api/users/<user_id>/reset_password",
@@ -296,13 +261,9 @@ api.add_org_resource(
"/api/users/<user_id>/regenerate_api_key", "/api/users/<user_id>/regenerate_api_key",
endpoint="user_regenerate_api_key", endpoint="user_regenerate_api_key",
) )
api.add_org_resource( api.add_org_resource(UserDisableResource, "/api/users/<user_id>/disable", endpoint="user_disable")
UserDisableResource, "/api/users/<user_id>/disable", endpoint="user_disable"
)
api.add_org_resource( api.add_org_resource(VisualizationListResource, "/api/visualizations", endpoint="visualizations")
VisualizationListResource, "/api/visualizations", endpoint="visualizations"
)
api.add_org_resource( api.add_org_resource(
VisualizationResource, VisualizationResource,
"/api/visualizations/<visualization_id>", "/api/visualizations/<visualization_id>",
@@ -312,23 +273,11 @@ api.add_org_resource(
api.add_org_resource(WidgetListResource, "/api/widgets", endpoint="widgets") api.add_org_resource(WidgetListResource, "/api/widgets", endpoint="widgets")
api.add_org_resource(WidgetResource, "/api/widgets/<int:widget_id>", endpoint="widget") api.add_org_resource(WidgetResource, "/api/widgets/<int:widget_id>", endpoint="widget")
api.add_org_resource( api.add_org_resource(DestinationTypeListResource, "/api/destinations/types", endpoint="destination_types")
DestinationTypeListResource, "/api/destinations/types", endpoint="destination_types" api.add_org_resource(DestinationResource, "/api/destinations/<destination_id>", endpoint="destination")
) api.add_org_resource(DestinationListResource, "/api/destinations", endpoint="destinations")
api.add_org_resource(
DestinationResource, "/api/destinations/<destination_id>", endpoint="destination"
)
api.add_org_resource(
DestinationListResource, "/api/destinations", endpoint="destinations"
)
api.add_org_resource( api.add_org_resource(QuerySnippetResource, "/api/query_snippets/<snippet_id>", endpoint="query_snippet")
QuerySnippetResource, "/api/query_snippets/<snippet_id>", endpoint="query_snippet" api.add_org_resource(QuerySnippetListResource, "/api/query_snippets", endpoint="query_snippets")
)
api.add_org_resource(
QuerySnippetListResource, "/api/query_snippets", endpoint="query_snippets"
)
api.add_org_resource( api.add_org_resource(OrganizationSettings, "/api/settings/organization", endpoint="organization_settings")
OrganizationSettings, "/api/settings/organization", endpoint="organization_settings"
)

View File

@@ -1,13 +1,13 @@
import logging import logging
from flask import abort, flash, redirect, render_template, request, url_for from flask import abort, flash, redirect, render_template, request, url_for
from flask_login import current_user, login_required, login_user, logout_user from flask_login import current_user, login_required, login_user, logout_user
from itsdangerous import BadSignature, SignatureExpired
from sqlalchemy.orm.exc import NoResultFound
from redash import __version__, limiter, models, settings from redash import __version__, limiter, models, settings
from redash.authentication import current_org, get_login_url, get_next_path from redash.authentication import current_org, get_login_url, get_next_path
from redash.authentication.account import ( from redash.authentication.account import (
BadSignature,
SignatureExpired,
send_password_reset_email, send_password_reset_email,
send_user_disabled_email, send_user_disabled_email,
send_verify_email, send_verify_email,
@@ -16,16 +16,13 @@ from redash.authentication.account import (
from redash.handlers import routes from redash.handlers import routes
from redash.handlers.base import json_response, org_scoped_rule from redash.handlers.base import json_response, org_scoped_rule
from redash.version_check import get_latest_version from redash.version_check import get_latest_version
from sqlalchemy.orm.exc import NoResultFound
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_google_auth_url(next_path): def get_google_auth_url(next_path):
if settings.MULTI_ORG: if settings.MULTI_ORG:
google_auth_url = url_for( google_auth_url = url_for("google_oauth.authorize_org", next=next_path, org_slug=current_org.slug)
"google_oauth.authorize_org", next=next_path, org_slug=current_org.slug
)
else: else:
google_auth_url = url_for("google_oauth.authorize", next=next_path) google_auth_url = url_for("google_oauth.authorize", next=next_path)
return google_auth_url return google_auth_url
@@ -65,8 +62,7 @@ def render_token_login_page(template, org_slug, token, invite):
render_template( render_template(
"error.html", "error.html",
error_message=( error_message=(
"This invitation has already been accepted. " "This invitation has already been accepted. Please try resetting your password instead."
"Please try resetting your password instead."
), ),
), ),
400, 400,
@@ -126,9 +122,7 @@ def verify(token, org_slug=None):
org = current_org._get_current_object() org = current_org._get_current_object()
user = models.User.get_by_id_and_org(user_id, org) user = models.User.get_by_id_and_org(user_id, org)
except (BadSignature, NoResultFound): except (BadSignature, NoResultFound):
logger.exception( logger.exception("Failed to verify email verification token: %s, org=%s", token, org_slug)
"Failed to verify email verification token: %s, org=%s", token, org_slug
)
return ( return (
render_template( render_template(
"error.html", "error.html",
@@ -175,11 +169,7 @@ def verification_email(org_slug=None):
if not current_user.is_email_verified: if not current_user.is_email_verified:
send_verify_email(current_user, current_org) send_verify_email(current_user, current_org)
return json_response( return json_response({"message": "Please check your email inbox in order to verify your email address."})
{
"message": "Please check your email inbox in order to verify your email address."
}
)
@routes.route(org_scoped_rule("/login"), methods=["GET", "POST"]) @routes.route(org_scoped_rule("/login"), methods=["GET", "POST"])
@@ -187,9 +177,9 @@ def verification_email(org_slug=None):
def login(org_slug=None): def login(org_slug=None):
# We intentionally use == as otherwise it won't actually use the proxy. So weird :O # We intentionally use == as otherwise it won't actually use the proxy. So weird :O
# noinspection PyComparisonWithNone # noinspection PyComparisonWithNone
if current_org == None and not settings.MULTI_ORG: if current_org == None and not settings.MULTI_ORG: # noqa: E711
return redirect("/setup") return redirect("/setup")
elif current_org == None: elif current_org == None: # noqa: E711
return redirect("/") return redirect("/")
index_url = url_for("redash.index", org_slug=org_slug) index_url = url_for("redash.index", org_slug=org_slug)
@@ -198,16 +188,11 @@ def login(org_slug=None):
if current_user.is_authenticated: if current_user.is_authenticated:
return redirect(next_path) return redirect(next_path)
if request.method == "POST" and current_org.get_setting("auth_password_login_enabled"): if request.method == "POST" and current_org.get_setting("auth_password_login_enabled"):
try: try:
org = current_org._get_current_object() org = current_org._get_current_object()
user = models.User.get_by_email_and_org(request.form["email"], org) user = models.User.get_by_email_and_org(request.form["email"], org)
if ( if user and not user.is_disabled and user.verify_password(request.form["password"]):
user
and not user.is_disabled
and user.verify_password(request.form["password"])
):
remember = "remember" in request.form remember = "remember" in request.form
login_user(user, remember=remember) login_user(user, remember=remember)
return redirect(next_path) return redirect(next_path)
@@ -218,8 +203,6 @@ def login(org_slug=None):
elif request.method == "POST" and not current_org.get_setting("auth_password_login_enabled"): elif request.method == "POST" and not current_org.get_setting("auth_password_login_enabled"):
flash("Password login is not enabled for your organization.") flash("Password login is not enabled for your organization.")
google_auth_url = get_google_auth_url(next_path) google_auth_url = get_google_auth_url(next_path)
return render_template( return render_template(
@@ -280,20 +263,13 @@ def client_config():
else: else:
client_config = {} client_config = {}
if ( if current_user.has_permission("admin") and current_org.get_setting("beacon_consent") is None:
current_user.has_permission("admin")
and current_org.get_setting("beacon_consent") is None
):
client_config["showBeaconConsentMessage"] = True client_config["showBeaconConsentMessage"] = True
defaults = { defaults = {
"allowScriptsInUserInput": settings.ALLOW_SCRIPTS_IN_USER_INPUT, "allowScriptsInUserInput": settings.ALLOW_SCRIPTS_IN_USER_INPUT,
"showPermissionsControl": current_org.get_setting( "showPermissionsControl": current_org.get_setting("feature_show_permissions_control"),
"feature_show_permissions_control" "hidePlotlyModeBar": current_org.get_setting("hide_plotly_mode_bar"),
),
"hidePlotlyModeBar": current_org.get_setting(
"hide_plotly_mode_bar"
),
"disablePublicUrls": current_org.get_setting("disable_public_urls"), "disablePublicUrls": current_org.get_setting("disable_public_urls"),
"allowCustomJSVisualizations": settings.FEATURE_ALLOW_CUSTOM_JS_VISUALIZATIONS, "allowCustomJSVisualizations": settings.FEATURE_ALLOW_CUSTOM_JS_VISUALIZATIONS,
"autoPublishNamedQueries": settings.FEATURE_AUTO_PUBLISH_NAMED_QUERIES, "autoPublishNamedQueries": settings.FEATURE_AUTO_PUBLISH_NAMED_QUERIES,
@@ -330,9 +306,7 @@ def messages():
@routes.route("/api/config", methods=["GET"]) @routes.route("/api/config", methods=["GET"])
def config(org_slug=None): def config(org_slug=None):
return json_response( return json_response({"org_slug": current_org.slug, "client_config": client_config()})
{"org_slug": current_org.slug, "client_config": client_config()}
)
@routes.route(org_scoped_rule("/api/session"), methods=["GET"]) @routes.route(org_scoped_rule("/api/session"), methods=["GET"])

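The == None comparisons above (and their # noqa: E711 markers) are deliberate: current_org is a werkzeug LocalProxy, so `is None` compares the proxy object itself and is always False, while `==` is forwarded to the wrapped object. A minimal sketch of the same trap with a hand-rolled proxy (hypothetical class, not werkzeug's implementation):

    class Proxy:
        """Forwards equality checks to a lazily resolved object, like a LocalProxy."""

        def __init__(self, resolve):
            self._resolve = resolve

        def __eq__(self, other):
            return self._resolve() == other

    current = None  # no organization resolved for this request
    current_org = Proxy(lambda: current)

    print(current_org is None)   # False: the proxy object itself is not None
    print(current_org == None)   # True: __eq__ reaches the wrapped value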
View File

@@ -15,9 +15,7 @@ from redash.models import db
from redash.tasks import record_event as record_event_task from redash.tasks import record_event as record_event_task
from redash.utils import json_dumps from redash.utils import json_dumps
routes = Blueprint( routes = Blueprint("redash", __name__, template_folder=settings.fix_assets_path("templates"))
"redash", __name__, template_folder=settings.fix_assets_path("templates")
)
class BaseResource(Resource): class BaseResource(Resource):
@@ -116,9 +114,7 @@ def json_response(response):
def filter_by_tags(result_set, column): def filter_by_tags(result_set, column):
if request.args.getlist("tags"): if request.args.getlist("tags"):
tags = request.args.getlist("tags") tags = request.args.getlist("tags")
result_set = result_set.filter( result_set = result_set.filter(cast(column, postgresql.ARRAY(db.Text)).contains(tags))
cast(column, postgresql.ARRAY(db.Text)).contains(tags)
)
return result_set return result_set

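filter_by_tags above leans on PostgreSQL array containment: casting the column to ARRAY(Text) and calling .contains(tags) compiles to the @> operator, keeping only rows whose tag array includes every requested tag. A sketch of the same filter against a hypothetical model, assuming SQLAlchemy on PostgreSQL:

    from sqlalchemy import Text, cast
    from sqlalchemy.dialects import postgresql

    def filter_by_tags(query, column, tags):
        # Emits SQL along the lines of: WHERE CAST(tags AS TEXT[]) @> ARRAY['finance', 'weekly']
        if tags:
            query = query.filter(cast(column, postgresql.ARRAY(Text)).contains(tags))
        return query

    # usage (hypothetical model): filter_by_tags(session.query(Dashboard), Dashboard.tags, ["finance", "weekly"])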
View File

@@ -1,15 +1,16 @@
from flask import request, url_for from flask import request, url_for
from funcy import project, partial
from flask_restful import abort from flask_restful import abort
from funcy import partial, project
from sqlalchemy.orm.exc import StaleDataError
from redash import models from redash import models
from redash.handlers.base import ( from redash.handlers.base import (
BaseResource, BaseResource,
get_object_or_404,
paginate,
filter_by_tags, filter_by_tags,
order_results as _order_results, get_object_or_404,
) )
from redash.handlers.base import order_results as _order_results
from redash.handlers.base import paginate
from redash.permissions import ( from redash.permissions import (
can_modify, can_modify,
require_admin_or_owner, require_admin_or_owner,
@@ -17,12 +18,7 @@ from redash.permissions import (
require_permission, require_permission,
) )
from redash.security import csp_allows_embeding from redash.security import csp_allows_embeding
from redash.serializers import ( from redash.serializers import DashboardSerializer, public_dashboard
DashboardSerializer,
public_dashboard,
)
from sqlalchemy.orm.exc import StaleDataError
# Ordering map for relationships # Ordering map for relationships
order_map = { order_map = {
@@ -32,9 +28,7 @@ order_map = {
"-created_at": "-created_at", "-created_at": "-created_at",
} }
order_results = partial( order_results = partial(_order_results, default_order="-created_at", allowed_orders=order_map)
_order_results, default_order="-created_at", allowed_orders=order_map
)
class DashboardListResource(BaseResource): class DashboardListResource(BaseResource):
@@ -61,9 +55,7 @@ class DashboardListResource(BaseResource):
search_term, search_term,
) )
else: else:
results = models.Dashboard.all( results = models.Dashboard.all(self.current_org, self.current_user.group_ids, self.current_user.id)
self.current_org, self.current_user.group_ids, self.current_user.id
)
results = filter_by_tags(results, models.Dashboard.tags) results = filter_by_tags(results, models.Dashboard.tags)
@@ -83,9 +75,7 @@ class DashboardListResource(BaseResource):
) )
if search_term: if search_term:
self.record_event( self.record_event({"action": "search", "object_type": "dashboard", "term": search_term})
{"action": "search", "object_type": "dashboard", "term": search_term}
)
else: else:
self.record_event({"action": "list", "object_type": "dashboard"}) self.record_event({"action": "list", "object_type": "dashboard"})
@@ -142,12 +132,7 @@ class MyDashboardsResource(BaseResource):
page = request.args.get("page", 1, type=int) page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 25, type=int) page_size = request.args.get("page_size", 25, type=int)
return paginate( return paginate(ordered_results, page, page_size, DashboardSerializer)
ordered_results,
page,
page_size,
DashboardSerializer
)
class DashboardResource(BaseResource): class DashboardResource(BaseResource):
@@ -193,9 +178,7 @@ class DashboardResource(BaseResource):
fn = models.Dashboard.get_by_id_and_org fn = models.Dashboard.get_by_id_and_org
dashboard = get_object_or_404(fn, dashboard_id, self.current_org) dashboard = get_object_or_404(fn, dashboard_id, self.current_org)
response = DashboardSerializer( response = DashboardSerializer(dashboard, with_widgets=True, user=self.current_user).serialize()
dashboard, with_widgets=True, user=self.current_user
).serialize()
api_key = models.ApiKey.get_by_object(dashboard) api_key = models.ApiKey.get_by_object(dashboard)
if api_key: if api_key:
@@ -209,9 +192,7 @@ class DashboardResource(BaseResource):
response["can_edit"] = can_modify(dashboard, self.current_user) response["can_edit"] = can_modify(dashboard, self.current_user)
self.record_event( self.record_event({"action": "view", "object_id": dashboard.id, "object_type": "dashboard"})
{"action": "view", "object_id": dashboard.id, "object_type": "dashboard"}
)
return response return response
@@ -262,13 +243,9 @@ class DashboardResource(BaseResource):
except StaleDataError: except StaleDataError:
abort(409) abort(409)
result = DashboardSerializer( result = DashboardSerializer(dashboard, with_widgets=True, user=self.current_user).serialize()
dashboard, with_widgets=True, user=self.current_user
).serialize()
self.record_event( self.record_event({"action": "edit", "object_id": dashboard.id, "object_type": "dashboard"})
{"action": "edit", "object_id": dashboard.id, "object_type": "dashboard"}
)
return result return result
@@ -285,14 +262,10 @@ class DashboardResource(BaseResource):
dashboard.is_archived = True dashboard.is_archived = True
dashboard.record_changes(changed_by=self.current_user) dashboard.record_changes(changed_by=self.current_user)
models.db.session.add(dashboard) models.db.session.add(dashboard)
d = DashboardSerializer( d = DashboardSerializer(dashboard, with_widgets=True, user=self.current_user).serialize()
dashboard, with_widgets=True, user=self.current_user
).serialize()
models.db.session.commit() models.db.session.commit()
self.record_event( self.record_event({"action": "archive", "object_id": dashboard.id, "object_type": "dashboard"})
{"action": "archive", "object_id": dashboard.id, "object_type": "dashboard"}
)
return d return d
@@ -396,9 +369,7 @@ class DashboardFavoriteListResource(BaseResource):
self.current_user.id, self.current_user.id,
search_term, search_term,
) )
favorites = models.Dashboard.favorites( favorites = models.Dashboard.favorites(self.current_user, base_query=base_query)
self.current_user, base_query=base_query
)
else: else:
favorites = models.Dashboard.favorites(self.current_user) favorites = models.Dashboard.favorites(self.current_user)

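The order_results assignment near the top of this file is plain partial application: default_order and allowed_orders are pinned once at module level, and each call site supplies only what varies. A small self-contained sketch of the pattern (the helper below is a hypothetical stand-in, not redash.handlers.base.order_results; funcy's partial plays the same role as functools.partial here):

    from functools import partial

    def order_results(results, default_order, allowed_orders, requested=None):
        # Hypothetical stand-in: fall back to the default when the requested order is not whitelisted
        order = requested if requested in allowed_orders else default_order
        key = order.lstrip("-")
        return sorted(results, key=lambda row: row[key], reverse=order.startswith("-"))

    order_map = {"name": "name", "-name": "-name", "created_at": "created_at", "-created_at": "-created_at"}
    dashboard_order = partial(order_results, default_order="-created_at", allowed_orders=order_map)

    rows = [{"name": "a", "created_at": 2}, {"name": "b", "created_at": 1}]
    print(dashboard_order(rows))                    # newest first (the pinned default)
    print(dashboard_order(rows, requested="name"))  # caller overrides per request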
View File

@@ -7,7 +7,11 @@ from funcy import project
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from redash import models from redash import models
from redash.handlers.base import BaseResource, get_object_or_404, require_fields from redash.handlers.base import (
BaseResource,
get_object_or_404,
require_fields,
)
from redash.permissions import ( from redash.permissions import (
require_access, require_access,
require_admin, require_admin,
@@ -17,28 +21,22 @@ from redash.permissions import (
from redash.query_runner import ( from redash.query_runner import (
get_configuration_schema_for_query_runner_type, get_configuration_schema_for_query_runner_type,
query_runners, query_runners,
NotSupported,
) )
from redash.serializers import serialize_job
from redash.tasks.general import get_schema, test_connection
from redash.utils import filter_none from redash.utils import filter_none
from redash.utils.configuration import ConfigurationContainer, ValidationError from redash.utils.configuration import ConfigurationContainer, ValidationError
from redash.tasks.general import test_connection, get_schema
from redash.serializers import serialize_job
class DataSourceTypeListResource(BaseResource): class DataSourceTypeListResource(BaseResource):
@require_admin @require_admin
def get(self): def get(self):
return [ return [q.to_dict() for q in sorted(query_runners.values(), key=lambda q: q.name().lower())]
q.to_dict()
for q in sorted(query_runners.values(), key=lambda q: q.name().lower())
]
class DataSourceResource(BaseResource): class DataSourceResource(BaseResource):
def get(self, data_source_id): def get(self, data_source_id):
data_source = get_object_or_404( data_source = get_object_or_404(models.DataSource.get_by_id_and_org, data_source_id, self.current_org)
models.DataSource.get_by_id_and_org, data_source_id, self.current_org
)
require_access(data_source, self.current_user, view_only) require_access(data_source, self.current_user, view_only)
ds = {} ds = {}
@@ -47,19 +45,13 @@ class DataSourceResource(BaseResource):
ds = data_source.to_dict(all=self.current_user.has_permission("admin")) ds = data_source.to_dict(all=self.current_user.has_permission("admin"))
# add view_only info, required for frontend permissions # add view_only info, required for frontend permissions
ds["view_only"] = all( ds["view_only"] = all(project(data_source.groups, self.current_user.group_ids).values())
project(data_source.groups, self.current_user.group_ids).values() self.record_event({"action": "view", "object_id": data_source_id, "object_type": "datasource"})
)
self.record_event(
{"action": "view", "object_id": data_source_id, "object_type": "datasource"}
)
return ds return ds
@require_admin @require_admin
def post(self, data_source_id): def post(self, data_source_id):
data_source = models.DataSource.get_by_id_and_org( data_source = models.DataSource.get_by_id_and_org(data_source_id, self.current_org)
data_source_id, self.current_org
)
req = request.get_json(True) req = request.get_json(True)
schema = get_configuration_schema_for_query_runner_type(req["type"]) schema = get_configuration_schema_for_query_runner_type(req["type"])
@@ -81,24 +73,18 @@ class DataSourceResource(BaseResource):
if req["name"] in str(e): if req["name"] in str(e):
abort( abort(
400, 400,
message="Data source with the name {} already exists.".format( message="Data source with the name {} already exists.".format(req["name"]),
req["name"]
),
) )
abort(400) abort(400)
self.record_event( self.record_event({"action": "edit", "object_id": data_source.id, "object_type": "datasource"})
{"action": "edit", "object_id": data_source.id, "object_type": "datasource"}
)
return data_source.to_dict(all=True) return data_source.to_dict(all=True)
@require_admin @require_admin
def delete(self, data_source_id): def delete(self, data_source_id):
data_source = models.DataSource.get_by_id_and_org( data_source = models.DataSource.get_by_id_and_org(data_source_id, self.current_org)
data_source_id, self.current_org
)
data_source.delete() data_source.delete()
self.record_event( self.record_event(
@@ -118,9 +104,7 @@ class DataSourceListResource(BaseResource):
if self.current_user.has_permission("admin"): if self.current_user.has_permission("admin"):
data_sources = models.DataSource.all(self.current_org) data_sources = models.DataSource.all(self.current_org)
else: else:
data_sources = models.DataSource.all( data_sources = models.DataSource.all(self.current_org, group_ids=self.current_user.group_ids)
self.current_org, group_ids=self.current_user.group_ids
)
response = {} response = {}
for ds in data_sources: for ds in data_sources:
@@ -129,14 +113,10 @@ class DataSourceListResource(BaseResource):
try: try:
d = ds.to_dict() d = ds.to_dict()
d["view_only"] = all( d["view_only"] = all(project(ds.groups, self.current_user.group_ids).values())
project(ds.groups, self.current_user.group_ids).values()
)
response[ds.id] = d response[ds.id] = d
except AttributeError: except AttributeError:
logging.exception( logging.exception("Error with DataSource#to_dict (data source id: %d)", ds.id)
"Error with DataSource#to_dict (data source id: %d)", ds.id
)
self.record_event( self.record_event(
{ {
@@ -171,9 +151,7 @@ class DataSourceListResource(BaseResource):
if req["name"] in str(e): if req["name"] in str(e):
abort( abort(
400, 400,
message="Data source with the name {} already exists.".format( message="Data source with the name {} already exists.".format(req["name"]),
req["name"]
),
) )
abort(400) abort(400)
@@ -191,9 +169,7 @@ class DataSourceListResource(BaseResource):
class DataSourceSchemaResource(BaseResource): class DataSourceSchemaResource(BaseResource):
def get(self, data_source_id): def get(self, data_source_id):
data_source = get_object_or_404( data_source = get_object_or_404(models.DataSource.get_by_id_and_org, data_source_id, self.current_org)
models.DataSource.get_by_id_and_org, data_source_id, self.current_org
)
require_access(data_source, self.current_user, view_only) require_access(data_source, self.current_user, view_only)
refresh = request.args.get("refresh") is not None refresh = request.args.get("refresh") is not None
@@ -211,9 +187,7 @@ class DataSourceSchemaResource(BaseResource):
class DataSourcePauseResource(BaseResource): class DataSourcePauseResource(BaseResource):
@require_admin @require_admin
def post(self, data_source_id): def post(self, data_source_id):
data_source = get_object_or_404( data_source = get_object_or_404(models.DataSource.get_by_id_and_org, data_source_id, self.current_org)
models.DataSource.get_by_id_and_org, data_source_id, self.current_org
)
data = request.get_json(force=True, silent=True) data = request.get_json(force=True, silent=True)
if data: if data:
reason = data.get("reason") reason = data.get("reason")
@@ -233,9 +207,7 @@ class DataSourcePauseResource(BaseResource):
@require_admin @require_admin
def delete(self, data_source_id): def delete(self, data_source_id):
data_source = get_object_or_404( data_source = get_object_or_404(models.DataSource.get_by_id_and_org, data_source_id, self.current_org)
models.DataSource.get_by_id_and_org, data_source_id, self.current_org
)
data_source.resume() data_source.resume()
self.record_event( self.record_event(
@@ -251,9 +223,7 @@ class DataSourcePauseResource(BaseResource):
class DataSourceTestResource(BaseResource): class DataSourceTestResource(BaseResource):
@require_admin @require_admin
def post(self, data_source_id): def post(self, data_source_id):
data_source = get_object_or_404( data_source = get_object_or_404(models.DataSource.get_by_id_and_org, data_source_id, self.current_org)
models.DataSource.get_by_id_and_org, data_source_id, self.current_org
)
response = {} response = {}

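The repeated view_only computation in this file has a compact meaning: data_source.groups maps group_id to a view_only flag, funcy.project narrows that mapping to the groups the user actually belongs to, and all() marks the source view-only unless at least one membership grants full access. A toy example with hypothetical values:

    from funcy import project

    groups = {1: True, 2: False, 3: True}  # group_id -> view_only flag on the data source

    viewer_group_ids = [1, 3]
    print(all(project(groups, viewer_group_ids).values()))  # True: every membership is view-only

    editor_group_ids = [1, 2]
    print(all(project(groups, editor_group_ids).values()))  # False: group 2 grants full access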
View File

@@ -1,25 +1,21 @@
from flask_restful import abort
from flask import request from flask import request
from flask_restful import abort
from redash import models, redis_connection from redash import models, redis_connection
from redash.handlers.base import BaseResource, get_object_or_404 from redash.handlers.base import BaseResource, get_object_or_404
from redash.permissions import ( from redash.permissions import require_access, view_only
require_access,
view_only,
)
from redash.tasks.databricks import (
get_databricks_databases,
get_databricks_tables,
get_database_tables_with_columns,
get_databricks_table_columns,
)
from redash.serializers import serialize_job from redash.serializers import serialize_job
from redash.utils import json_loads, json_dumps from redash.tasks.databricks import (
get_database_tables_with_columns,
get_databricks_databases,
get_databricks_table_columns,
get_databricks_tables,
)
from redash.utils import json_loads
def _get_databricks_data_source(data_source_id, user, org): def _get_databricks_data_source(data_source_id, user, org):
data_source = get_object_or_404( data_source = get_object_or_404(models.DataSource.get_by_id_and_org, data_source_id, org)
models.DataSource.get_by_id_and_org, data_source_id, org
)
require_access(data_source, user, view_only) require_access(data_source, user, view_only)
if not data_source.type == "databricks": if not data_source.type == "databricks":
@@ -48,9 +44,7 @@ def _get_tables_from_cache(data_source_id, database_name):
class DatabricksDatabaseListResource(BaseResource): class DatabricksDatabaseListResource(BaseResource):
def get(self, data_source_id): def get(self, data_source_id):
data_source = _get_databricks_data_source( data_source = _get_databricks_data_source(data_source_id, user=self.current_user, org=self.current_org)
data_source_id, user=self.current_user, org=self.current_org
)
refresh = request.args.get("refresh") is not None refresh = request.args.get("refresh") is not None
if not refresh: if not refresh:
@@ -59,17 +53,13 @@ class DatabricksDatabaseListResource(BaseResource):
if cached_databases is not None: if cached_databases is not None:
return cached_databases return cached_databases
job = get_databricks_databases.delay( job = get_databricks_databases.delay(data_source.id, redis_key=_databases_key(data_source_id))
data_source.id, redis_key=_databases_key(data_source_id)
)
return serialize_job(job) return serialize_job(job)
class DatabricksSchemaResource(BaseResource): class DatabricksSchemaResource(BaseResource):
def get(self, data_source_id, database_name): def get(self, data_source_id, database_name):
data_source = _get_databricks_data_source( data_source = _get_databricks_data_source(data_source_id, user=self.current_user, org=self.current_org)
data_source_id, user=self.current_user, org=self.current_org
)
refresh = request.args.get("refresh") is not None refresh = request.args.get("refresh") is not None
if not refresh: if not refresh:
@@ -89,9 +79,7 @@ class DatabricksSchemaResource(BaseResource):
class DatabricksTableColumnListResource(BaseResource): class DatabricksTableColumnListResource(BaseResource):
def get(self, data_source_id, database_name, table_name): def get(self, data_source_id, database_name, table_name):
data_source = _get_databricks_data_source( data_source = _get_databricks_data_source(data_source_id, user=self.current_user, org=self.current_org)
data_source_id, user=self.current_user, org=self.current_org
)
job = get_databricks_table_columns.delay(data_source.id, database_name, table_name) job = get_databricks_table_columns.delay(data_source.id, database_name, table_name)
return serialize_job(job) return serialize_job(job)

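All three Databricks resources above share one shape: answer from a Redis-backed cache when possible, otherwise enqueue a background job under a per-source key and return a serialized job for the client to poll. A rough sketch of that shape (the cache and enqueue arguments are assumptions for illustration, not Redash's actual task wiring):

    import json

    def get_databases(cache, enqueue, data_source_id, refresh=False):
        # cache:   object with get(key) -> bytes | None, e.g. a redis client
        # enqueue: callable(data_source_id, redis_key=...) returning a job with an .id
        key = "databricks:databases:{}".format(data_source_id)  # hypothetical key scheme
        if not refresh:
            cached = cache.get(key)
            if cached is not None:
                return json.loads(cached)  # fast path: schema already materialized
        job = enqueue(data_source_id, redis_key=key)  # slow path: fetch in the background
        return {"job": {"id": job.id, "status": "queued"}}  # client polls until the job lands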
View File

@@ -21,9 +21,7 @@ class DestinationTypeListResource(BaseResource):
class DestinationResource(BaseResource): class DestinationResource(BaseResource):
@require_admin @require_admin
def get(self, destination_id): def get(self, destination_id):
destination = models.NotificationDestination.get_by_id_and_org( destination = models.NotificationDestination.get_by_id_and_org(destination_id, self.current_org)
destination_id, self.current_org
)
d = destination.to_dict(all=True) d = destination.to_dict(all=True)
self.record_event( self.record_event(
{ {
@@ -36,9 +34,7 @@ class DestinationResource(BaseResource):
@require_admin @require_admin
def post(self, destination_id): def post(self, destination_id):
destination = models.NotificationDestination.get_by_id_and_org( destination = models.NotificationDestination.get_by_id_and_org(destination_id, self.current_org)
destination_id, self.current_org
)
req = request.get_json(True) req = request.get_json(True)
schema = get_configuration_schema_for_destination_type(req["type"]) schema = get_configuration_schema_for_destination_type(req["type"])
@@ -58,9 +54,7 @@ class DestinationResource(BaseResource):
if "name" in str(e): if "name" in str(e):
abort( abort(
400, 400,
message="Alert Destination with the name {} already exists.".format( message="Alert Destination with the name {} already exists.".format(req["name"]),
req["name"]
),
) )
abort(500) abort(500)
@@ -68,9 +62,7 @@ class DestinationResource(BaseResource):
@require_admin @require_admin
def delete(self, destination_id): def delete(self, destination_id):
destination = models.NotificationDestination.get_by_id_and_org( destination = models.NotificationDestination.get_by_id_and_org(destination_id, self.current_org)
destination_id, self.current_org
)
models.db.session.delete(destination) models.db.session.delete(destination)
models.db.session.commit() models.db.session.commit()
@@ -135,9 +127,7 @@ class DestinationListResource(BaseResource):
if "name" in str(e): if "name" in str(e):
abort( abort(
400, 400,
message="Alert Destination with the name {} already exists.".format( message="Alert Destination with the name {} already exists.".format(req["name"]),
req["name"]
),
) )
abort(500) abort(500)

View File

@@ -1,13 +1,18 @@
from flask import request from flask import request
from .authentication import current_org
from flask_login import current_user, login_required from flask_login import current_user, login_required
from redash import models from redash import models
from redash.handlers import routes from redash.handlers import routes
from redash.handlers.base import get_object_or_404, org_scoped_rule, record_event from redash.handlers.base import (
get_object_or_404,
org_scoped_rule,
record_event,
)
from redash.handlers.static import render_index from redash.handlers.static import render_index
from redash.security import csp_allows_embeding from redash.security import csp_allows_embeding
from .authentication import current_org
@routes.route( @routes.route(
org_scoped_rule("/embed/query/<query_id>/visualization/<visualization_id>"), org_scoped_rule("/embed/query/<query_id>/visualization/<visualization_id>"),

View File

@@ -1,6 +1,6 @@
from flask import request
import geolite2 import geolite2
import maxminddb import maxminddb
from flask import request
from user_agents import parse as parse_ua from user_agents import parse as parse_ua
from redash.handlers.base import BaseResource, paginate from redash.handlers.base import BaseResource, paginate
@@ -44,9 +44,7 @@ def serialize_event(event):
} }
if event.user_id: if event.user_id:
d["user_name"] = event.additional_properties.get( d["user_name"] = event.additional_properties.get("user_name", "User {}".format(event.user_id))
"user_name", "User {}".format(event.user_id)
)
if not event.user_id: if not event.user_id:
d["user_name"] = event.additional_properties.get("api_key", "Unknown") d["user_name"] = event.additional_properties.get("api_key", "Unknown")

View File

@@ -1,21 +1,16 @@
from flask import request
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from redash import models from redash import models
from redash.handlers.base import BaseResource, get_object_or_404, paginate from redash.handlers.base import BaseResource, get_object_or_404
from redash.permissions import require_access, view_only from redash.permissions import require_access, view_only
class QueryFavoriteResource(BaseResource): class QueryFavoriteResource(BaseResource):
def post(self, query_id): def post(self, query_id):
query = get_object_or_404( query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org)
models.Query.get_by_id_and_org, query_id, self.current_org
)
require_access(query, self.current_user, view_only) require_access(query, self.current_user, view_only)
fav = models.Favorite( fav = models.Favorite(org_id=self.current_org.id, object=query, user=self.current_user)
org_id=self.current_org.id, object=query, user=self.current_user
)
models.db.session.add(fav) models.db.session.add(fav)
try: try:
@@ -26,14 +21,10 @@ class QueryFavoriteResource(BaseResource):
else: else:
raise e raise e
self.record_event( self.record_event({"action": "favorite", "object_id": query.id, "object_type": "query"})
{"action": "favorite", "object_id": query.id, "object_type": "query"}
)
def delete(self, query_id): def delete(self, query_id):
query = get_object_or_404( query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org)
models.Query.get_by_id_and_org, query_id, self.current_org
)
require_access(query, self.current_user, view_only) require_access(query, self.current_user, view_only)
models.Favorite.query.filter( models.Favorite.query.filter(
@@ -43,19 +34,13 @@ class QueryFavoriteResource(BaseResource):
).delete() ).delete()
models.db.session.commit() models.db.session.commit()
self.record_event( self.record_event({"action": "favorite", "object_id": query.id, "object_type": "query"})
{"action": "favorite", "object_id": query.id, "object_type": "query"}
)
class DashboardFavoriteResource(BaseResource): class DashboardFavoriteResource(BaseResource):
def post(self, object_id): def post(self, object_id):
dashboard = get_object_or_404( dashboard = get_object_or_404(models.Dashboard.get_by_id_and_org, object_id, self.current_org)
models.Dashboard.get_by_id_and_org, object_id, self.current_org fav = models.Favorite(org_id=self.current_org.id, object=dashboard, user=self.current_user)
)
fav = models.Favorite(
org_id=self.current_org.id, object=dashboard, user=self.current_user
)
models.db.session.add(fav) models.db.session.add(fav)
try: try:
@@ -75,9 +60,7 @@ class DashboardFavoriteResource(BaseResource):
) )
def delete(self, object_id): def delete(self, object_id):
dashboard = get_object_or_404( dashboard = get_object_or_404(models.Dashboard.get_by_id_and_org, object_id, self.current_org)
models.Dashboard.get_by_id_and_org, object_id, self.current_org
)
models.Favorite.query.filter( models.Favorite.query.filter(
models.Favorite.object == dashboard, models.Favorite.object == dashboard,
models.Favorite.user == self.current_user, models.Favorite.user == self.current_user,

View File

@@ -1,9 +1,9 @@
import time
from flask import request from flask import request
from flask_restful import abort from flask_restful import abort
from redash import models from redash import models
from redash.permissions import require_admin, require_permission
from redash.handlers.base import BaseResource, get_object_or_404 from redash.handlers.base import BaseResource, get_object_or_404
from redash.permissions import require_admin, require_permission
class GroupListResource(BaseResource): class GroupListResource(BaseResource):
@@ -14,9 +14,7 @@ class GroupListResource(BaseResource):
models.db.session.add(group) models.db.session.add(group)
models.db.session.commit() models.db.session.commit()
self.record_event( self.record_event({"action": "create", "object_id": group.id, "object_type": "group"})
{"action": "create", "object_id": group.id, "object_type": "group"}
)
return group.to_dict() return group.to_dict()
@@ -24,13 +22,9 @@ class GroupListResource(BaseResource):
if self.current_user.has_permission("admin"): if self.current_user.has_permission("admin"):
groups = models.Group.all(self.current_org) groups = models.Group.all(self.current_org)
else: else:
groups = models.Group.query.filter( groups = models.Group.query.filter(models.Group.id.in_(self.current_user.group_ids))
models.Group.id.in_(self.current_user.group_ids)
)
self.record_event( self.record_event({"action": "list", "object_id": "groups", "object_type": "group"})
{"action": "list", "object_id": "groups", "object_type": "group"}
)
return [g.to_dict() for g in groups] return [g.to_dict() for g in groups]
@@ -46,24 +40,17 @@ class GroupResource(BaseResource):
group.name = request.json["name"] group.name = request.json["name"]
models.db.session.commit() models.db.session.commit()
self.record_event( self.record_event({"action": "edit", "object_id": group.id, "object_type": "group"})
{"action": "edit", "object_id": group.id, "object_type": "group"}
)
return group.to_dict() return group.to_dict()
def get(self, group_id): def get(self, group_id):
if not ( if not (self.current_user.has_permission("admin") or int(group_id) in self.current_user.group_ids):
self.current_user.has_permission("admin")
or int(group_id) in self.current_user.group_ids
):
abort(403) abort(403)
group = models.Group.get_by_id_and_org(group_id, self.current_org) group = models.Group.get_by_id_and_org(group_id, self.current_org)
self.record_event( self.record_event({"action": "view", "object_id": group_id, "object_type": "group"})
{"action": "view", "object_id": group_id, "object_type": "group"}
)
return group.to_dict() return group.to_dict()
@@ -103,10 +90,7 @@ class GroupMemberListResource(BaseResource):
@require_permission("list_users") @require_permission("list_users")
def get(self, group_id): def get(self, group_id):
if not ( if not (self.current_user.has_permission("admin") or int(group_id) in self.current_user.group_ids):
self.current_user.has_permission("admin")
or int(group_id) in self.current_user.group_ids
):
abort(403) abort(403)
members = models.Group.members(group_id) members = models.Group.members(group_id)
@@ -140,9 +124,7 @@ class GroupDataSourceListResource(BaseResource):
@require_admin @require_admin
def post(self, group_id): def post(self, group_id):
data_source_id = request.json["data_source_id"] data_source_id = request.json["data_source_id"]
data_source = models.DataSource.get_by_id_and_org( data_source = models.DataSource.get_by_id_and_org(data_source_id, self.current_org)
data_source_id, self.current_org
)
group = models.Group.get_by_id_and_org(group_id, self.current_org) group = models.Group.get_by_id_and_org(group_id, self.current_org)
data_source_group = data_source.add_group(group) data_source_group = data_source.add_group(group)
@@ -161,18 +143,14 @@ class GroupDataSourceListResource(BaseResource):
@require_admin @require_admin
def get(self, group_id): def get(self, group_id):
group = get_object_or_404( group = get_object_or_404(models.Group.get_by_id_and_org, group_id, self.current_org)
models.Group.get_by_id_and_org, group_id, self.current_org
)
# TODO: move to models # TODO: move to models
data_sources = models.DataSource.query.join(models.DataSourceGroup).filter( data_sources = models.DataSource.query.join(models.DataSourceGroup).filter(
models.DataSourceGroup.group == group models.DataSourceGroup.group == group
) )
self.record_event( self.record_event({"action": "list", "object_id": group_id, "object_type": "group"})
{"action": "list", "object_id": group_id, "object_type": "group"}
)
return [ds.to_dict(with_permissions_for=group) for ds in data_sources] return [ds.to_dict(with_permissions_for=group) for ds in data_sources]
@@ -180,9 +158,7 @@ class GroupDataSourceListResource(BaseResource):
class GroupDataSourceResource(BaseResource): class GroupDataSourceResource(BaseResource):
@require_admin @require_admin
def post(self, group_id, data_source_id): def post(self, group_id, data_source_id):
data_source = models.DataSource.get_by_id_and_org( data_source = models.DataSource.get_by_id_and_org(data_source_id, self.current_org)
data_source_id, self.current_org
)
group = models.Group.get_by_id_and_org(group_id, self.current_org) group = models.Group.get_by_id_and_org(group_id, self.current_org)
view_only = request.json["view_only"] view_only = request.json["view_only"]
@@ -203,9 +179,7 @@ class GroupDataSourceResource(BaseResource):
@require_admin @require_admin
def delete(self, group_id, data_source_id): def delete(self, group_id, data_source_id):
data_source = models.DataSource.get_by_id_and_org( data_source = models.DataSource.get_by_id_and_org(data_source_id, self.current_org)
data_source_id, self.current_org
)
group = models.Group.get_by_id_and_org(group_id, self.current_org) group = models.Group.get_by_id_and_org(group_id, self.current_org)
data_source.remove_group(group) data_source.remove_group(group)

View File

@@ -1,9 +1,9 @@
from flask_login import current_user, login_required from flask_login import current_user, login_required
from redash import models from redash import models
from redash.authentication import current_org
from redash.handlers import routes from redash.handlers import routes
from redash.handlers.base import json_response, org_scoped_rule from redash.handlers.base import json_response, org_scoped_rule
from redash.authentication import current_org
@routes.route(org_scoped_rule("/api/organization/status"), methods=["GET"]) @routes.route(org_scoped_rule("/api/organization/status"), methods=["GET"])
@@ -12,14 +12,10 @@ def organization_status(org_slug=None):
counters = { counters = {
"users": models.User.all(current_org).count(), "users": models.User.all(current_org).count(),
"alerts": models.Alert.all(group_ids=current_user.group_ids).count(), "alerts": models.Alert.all(group_ids=current_user.group_ids).count(),
"data_sources": models.DataSource.all( "data_sources": models.DataSource.all(current_org, group_ids=current_user.group_ids).count(),
current_org, group_ids=current_user.group_ids "queries": models.Query.all_queries(current_user.group_ids, current_user.id, include_drafts=True).count(),
).count(),
"queries": models.Query.all_queries(
current_user.group_ids, current_user.id, include_drafts=True
).count(),
"dashboards": models.Dashboard.query.filter( "dashboards": models.Dashboard.query.filter(
models.Dashboard.org == current_org, models.Dashboard.is_archived == False models.Dashboard.org == current_org, models.Dashboard.is_archived == False  # noqa: E712
).count(), ).count(),
} }
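
The archived-dashboard filter keeps a == comparison (with # noqa: E712 to quiet flake8) because SQLAlchemy overloads == to build a SQL expression, while Python's `is` cannot be overloaded: `column is False` evaluates to a plain bool before any query exists, so the filter would silently match nothing or error, depending on the SQLAlchemy version. A minimal sketch with SQLAlchemy Core (an illustration, not part of the commit):

from sqlalchemy import Boolean, Column, Integer, MetaData, Table

dashboards = Table(
    "dashboards",
    MetaData(),
    Column("id", Integer),
    Column("is_archived", Boolean),
)

print(dashboards.c.is_archived == False)    # noqa: E712 -- a SQL expression object
print(dashboards.c.is_archived.is_(False))  # the flake8-clean equivalent
print(dashboards.c.is_archived is False)    # plain Python False, useless in a filter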

View File

@@ -1,12 +1,12 @@
from collections import defaultdict from collections import defaultdict
from redash.handlers.base import BaseResource, get_object_or_404
from redash.models import AccessPermission, Query, Dashboard, User, db
from redash.permissions import require_admin_or_owner, ACCESS_TYPES
from flask import request from flask import request
from flask_restful import abort from flask_restful import abort
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
from redash.handlers.base import BaseResource, get_object_or_404
from redash.models import AccessPermission, Dashboard, Query, User, db
from redash.permissions import ACCESS_TYPES, require_admin_or_owner
model_to_types = {"queries": Query, "dashboards": Dashboard} model_to_types = {"queries": Query, "dashboards": Dashboard}
@@ -51,9 +51,7 @@ class ObjectPermissionsListResource(BaseResource):
except NoResultFound: except NoResultFound:
abort(400, message="User not found.") abort(400, message="User not found.")
permission = AccessPermission.grant( permission = AccessPermission.grant(obj, access_type, grantee, self.current_user)
obj, access_type, grantee, self.current_user
)
db.session.commit() db.session.commit()
self.record_event( self.record_event(

View File

@@ -2,8 +2,8 @@ import sqlparse
from flask import jsonify, request, url_for from flask import jsonify, request, url_for
from flask_login import login_required from flask_login import login_required
from flask_restful import abort from flask_restful import abort
from sqlalchemy.orm.exc import StaleDataError
from funcy import partial from funcy import partial
from sqlalchemy.orm.exc import StaleDataError
from redash import models, settings from redash import models, settings
from redash.authentication.org_resolving import current_org from redash.authentication.org_resolving import current_org
@@ -11,12 +11,11 @@ from redash.handlers.base import (
BaseResource, BaseResource,
filter_by_tags, filter_by_tags,
get_object_or_404, get_object_or_404,
org_scoped_rule,
paginate,
routes,
order_results as _order_results,
) )
from redash.handlers.base import order_results as _order_results
from redash.handlers.base import org_scoped_rule, paginate, routes
from redash.handlers.query_results import run_query from redash.handlers.query_results import run_query
from redash.models.parameterized_query import ParameterizedQuery
from redash.permissions import ( from redash.permissions import (
can_modify, can_modify,
not_view_only, not_view_only,
@@ -26,10 +25,8 @@ from redash.permissions import (
require_permission, require_permission,
view_only, view_only,
) )
from redash.utils import collect_parameters_from_request
from redash.serializers import QuerySerializer from redash.serializers import QuerySerializer
from redash.models.parameterized_query import ParameterizedQuery from redash.utils import collect_parameters_from_request
# Ordering map for relationships # Ordering map for relationships
order_map = { order_map = {
@@ -47,9 +44,7 @@ order_map = {
"-created_by": "-users-name", "-created_by": "-users-name",
} }
order_results = partial( order_results = partial(_order_results, default_order="-created_at", allowed_orders=order_map)
_order_results, default_order="-created_at", allowed_orders=order_map
)
@routes.route(org_scoped_rule("/api/queries/format"), methods=["POST"]) @routes.route(org_scoped_rule("/api/queries/format"), methods=["POST"])
@@ -64,9 +59,7 @@ def format_sql_query(org_slug=None):
arguments = request.get_json(force=True) arguments = request.get_json(force=True)
query = arguments.get("query", "") query = arguments.get("query", "")
return jsonify( return jsonify({"query": sqlparse.format(query, **settings.SQLPARSE_FORMAT_OPTIONS)})
{"query": sqlparse.format(query, **settings.SQLPARSE_FORMAT_OPTIONS)}
)
class QuerySearchResource(BaseResource): class QuerySearchResource(BaseResource):
@@ -107,14 +100,8 @@ class QueryRecentResource(BaseResource):
Responds with a list of :ref:`query <query-response-label>` objects. Responds with a list of :ref:`query <query-response-label>` objects.
""" """
results = ( results = models.Query.by_user(self.current_user).order_by(models.Query.updated_at.desc()).limit(10)
models.Query.by_user(self.current_user) return QuerySerializer(results, with_last_modified_by=False, with_user=False).serialize()
.order_by(models.Query.updated_at.desc())
.limit(10)
)
return QuerySerializer(
results, with_last_modified_by=False, with_user=False
).serialize()
class BaseQueryListResource(BaseResource): class BaseQueryListResource(BaseResource):
@@ -128,9 +115,7 @@ class BaseQueryListResource(BaseResource):
multi_byte_search=current_org.get_setting("multi_byte_search_enabled"), multi_byte_search=current_org.get_setting("multi_byte_search_enabled"),
) )
else: else:
results = models.Query.all_queries( results = models.Query.all_queries(self.current_user.group_ids, self.current_user.id, include_drafts=True)
self.current_user.group_ids, self.current_user.id, include_drafts=True
)
return filter_by_tags(results, models.Query.tags) return filter_by_tags(results, models.Query.tags)
@require_permission("view_query") @require_permission("view_query")
@@ -170,9 +155,7 @@ class BaseQueryListResource(BaseResource):
) )
if search_term: if search_term:
self.record_event( self.record_event({"action": "search", "object_type": "query", "term": search_term})
{"action": "search", "object_type": "query", "term": search_term}
)
else: else:
self.record_event({"action": "list", "object_type": "query"}) self.record_event({"action": "list", "object_type": "query"})
@@ -181,9 +164,7 @@ class BaseQueryListResource(BaseResource):
def require_access_to_dropdown_queries(user, query_def): def require_access_to_dropdown_queries(user, query_def):
parameters = query_def.get("options", {}).get("parameters", []) parameters = query_def.get("options", {}).get("parameters", [])
dropdown_query_ids = set( dropdown_query_ids = set([str(p["queryId"]) for p in parameters if p["type"] == "query"])
[str(p["queryId"]) for p in parameters if p["type"] == "query"]
)
if dropdown_query_ids: if dropdown_query_ids:
groups = models.Query.all_groups_for_query_ids(dropdown_query_ids) groups = models.Query.all_groups_for_query_ids(dropdown_query_ids)
@@ -234,9 +215,7 @@ class QueryListResource(BaseQueryListResource):
:>json number runtime: Runtime of last query execution, in seconds (may be null) :>json number runtime: Runtime of last query execution, in seconds (may be null)
""" """
query_def = request.get_json(force=True) query_def = request.get_json(force=True)
data_source = models.DataSource.get_by_id_and_org( data_source = models.DataSource.get_by_id_and_org(query_def.pop("data_source_id"), self.current_org)
query_def.pop("data_source_id"), self.current_org
)
require_access(data_source, self.current_user, not_view_only) require_access(data_source, self.current_user, not_view_only)
require_access_to_dropdown_queries(self.current_user, query_def) require_access_to_dropdown_queries(self.current_user, query_def)
@@ -259,9 +238,7 @@ class QueryListResource(BaseQueryListResource):
models.db.session.add(query) models.db.session.add(query)
models.db.session.commit() models.db.session.commit()
self.record_event( self.record_event({"action": "create", "object_id": query.id, "object_type": "query"})
{"action": "create", "object_id": query.id, "object_type": "query"}
)
return QuerySerializer(query, with_visualizations=True).serialize() return QuerySerializer(query, with_visualizations=True).serialize()
@@ -340,9 +317,7 @@ class QueryResource(BaseResource):
Responds with the updated :ref:`query <query-response-label>` object. Responds with the updated :ref:`query <query-response-label>` object.
""" """
query = get_object_or_404( query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org)
models.Query.get_by_id_and_org, query_id, self.current_org
)
query_def = request.get_json(force=True) query_def = request.get_json(force=True)
require_object_modify_permission(query, self.current_user) require_object_modify_permission(query, self.current_user)
@@ -367,9 +342,7 @@ class QueryResource(BaseResource):
query_def["tags"] = [tag for tag in query_def["tags"] if tag] query_def["tags"] = [tag for tag in query_def["tags"] if tag]
if "data_source_id" in query_def: if "data_source_id" in query_def:
data_source = models.DataSource.get_by_id_and_org( data_source = models.DataSource.get_by_id_and_org(query_def["data_source_id"], self.current_org)
query_def["data_source_id"], self.current_org
)
require_access(data_source, self.current_user, not_view_only) require_access(data_source, self.current_user, not_view_only)
query_def["last_modified_by"] = self.current_user query_def["last_modified_by"] = self.current_user
@@ -397,17 +370,13 @@ class QueryResource(BaseResource):
Responds with the :ref:`query <query-response-label>` contents. Responds with the :ref:`query <query-response-label>` contents.
""" """
q = get_object_or_404( q = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org)
models.Query.get_by_id_and_org, query_id, self.current_org
)
require_access(q, self.current_user, view_only) require_access(q, self.current_user, view_only)
result = QuerySerializer(q, with_visualizations=True).serialize() result = QuerySerializer(q, with_visualizations=True).serialize()
result["can_edit"] = can_modify(q, self.current_user) result["can_edit"] = can_modify(q, self.current_user)
self.record_event( self.record_event({"action": "view", "object_id": query_id, "object_type": "query"})
{"action": "view", "object_id": query_id, "object_type": "query"}
)
return result return result
@@ -418,9 +387,7 @@ class QueryResource(BaseResource):
:param query_id: ID of query to archive :param query_id: ID of query to archive
""" """
query = get_object_or_404( query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org)
models.Query.get_by_id_and_org, query_id, self.current_org
)
require_admin_or_owner(query.user_id) require_admin_or_owner(query.user_id)
query.archive(self.current_user) query.archive(self.current_user)
models.db.session.commit() models.db.session.commit()
@@ -429,9 +396,7 @@ class QueryResource(BaseResource):
class QueryRegenerateApiKeyResource(BaseResource): class QueryRegenerateApiKeyResource(BaseResource):
@require_permission("edit_query") @require_permission("edit_query")
def post(self, query_id): def post(self, query_id):
query = get_object_or_404( query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org)
models.Query.get_by_id_and_org, query_id, self.current_org
)
require_admin_or_owner(query.user_id) require_admin_or_owner(query.user_id)
query.regenerate_api_key() query.regenerate_api_key()
models.db.session.commit() models.db.session.commit()
@@ -458,16 +423,12 @@ class QueryForkResource(BaseResource):
Responds with created :ref:`query <query-response-label>` object. Responds with created :ref:`query <query-response-label>` object.
""" """
query = get_object_or_404( query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org)
models.Query.get_by_id_and_org, query_id, self.current_org
)
require_access(query.data_source, self.current_user, not_view_only) require_access(query.data_source, self.current_user, not_view_only)
forked_query = query.fork(self.current_user) forked_query = query.fork(self.current_user)
models.db.session.commit() models.db.session.commit()
self.record_event( self.record_event({"action": "fork", "object_id": query_id, "object_type": "query"})
{"action": "fork", "object_id": query_id, "object_type": "query"}
)
return QuerySerializer(forked_query, with_visualizations=True).serialize() return QuerySerializer(forked_query, with_visualizations=True).serialize()
@@ -487,17 +448,13 @@ class QueryRefreshResource(BaseResource):
if self.current_user.is_api_user(): if self.current_user.is_api_user():
abort(403, message="Please use a user API key.") abort(403, message="Please use a user API key.")
query = get_object_or_404( query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org)
models.Query.get_by_id_and_org, query_id, self.current_org
)
require_access(query, self.current_user, not_view_only) require_access(query, self.current_user, not_view_only)
parameter_values = collect_parameters_from_request(request.args) parameter_values = collect_parameters_from_request(request.args)
parameterized_query = ParameterizedQuery(query.query_text, org=self.current_org) parameterized_query = ParameterizedQuery(query.query_text, org=self.current_org)
should_apply_auto_limit = query.options.get("apply_auto_limit", False) should_apply_auto_limit = query.options.get("apply_auto_limit", False)
return run_query( return run_query(parameterized_query, parameter_values, query.data_source, query.id, should_apply_auto_limit)
parameterized_query, parameter_values, query.data_source, query.id, should_apply_auto_limit
)
class QueryTagsResource(BaseResource): class QueryTagsResource(BaseResource):

View File

@@ -1,41 +1,39 @@
import logging
import time
import unicodedata import unicodedata
from flask import make_response, request from flask import make_response, request
from flask_login import current_user from flask_login import current_user
from flask_restful import abort from flask_restful import abort
from werkzeug.urls import url_quote from werkzeug.urls import url_quote
from redash import models, settings from redash import models, settings
from redash.handlers.base import BaseResource, get_object_or_404, record_event from redash.handlers.base import BaseResource, get_object_or_404, record_event
from redash.models.parameterized_query import (
InvalidParameterError,
ParameterizedQuery,
QueryDetachedFromDataSourceError,
dropdown_values,
)
from redash.permissions import ( from redash.permissions import (
has_access, has_access,
not_view_only, not_view_only,
require_access, require_access,
require_permission,
require_any_of_permission, require_any_of_permission,
require_permission,
view_only, view_only,
) )
from redash.serializers import (
serialize_job,
serialize_query_result,
serialize_query_result_to_dsv,
serialize_query_result_to_xlsx,
)
from redash.tasks import Job from redash.tasks import Job
from redash.tasks.queries import enqueue_query from redash.tasks.queries import enqueue_query
from redash.utils import ( from redash.utils import (
collect_parameters_from_request, collect_parameters_from_request,
json_dumps, json_dumps,
utcnow,
to_filename, to_filename,
) )
from redash.models.parameterized_query import (
ParameterizedQuery,
InvalidParameterError,
QueryDetachedFromDataSourceError,
dropdown_values,
)
from redash.serializers import (
serialize_query_result,
serialize_query_result_to_dsv,
serialize_query_result_to_xlsx,
serialize_job,
)
def error_response(message, http_status=400): def error_response(message, http_status=400):
@@ -51,23 +49,15 @@ error_messages = {
"This query contains potentially unsafe parameters and cannot be executed with read-only access to this data source.", "This query contains potentially unsafe parameters and cannot be executed with read-only access to this data source.",
403, 403,
), ),
"no_permission": error_response( "no_permission": error_response("You do not have permission to run queries with this data source.", 403),
"You do not have permission to run queries with this data source.", 403 "select_data_source": error_response("Please select data source to run this query.", 401),
),
"select_data_source": error_response(
"Please select data source to run this query.", 401
),
} }
def run_query( def run_query(query, parameters, data_source, query_id, should_apply_auto_limit, max_age=0):
query, parameters, data_source, query_id, should_apply_auto_limit, max_age=0
):
if data_source.paused: if data_source.paused:
if data_source.pause_reason: if data_source.pause_reason:
message = "{} is paused ({}). Please try later.".format( message = "{} is paused ({}). Please try later.".format(data_source.name, data_source.pause_reason)
data_source.name, data_source.pause_reason
)
else: else:
message = "{} is paused. Please try later.".format(data_source.name) message = "{} is paused. Please try later.".format(data_source.name)
@@ -78,14 +68,10 @@ def run_query(
except (InvalidParameterError, QueryDetachedFromDataSourceError) as e: except (InvalidParameterError, QueryDetachedFromDataSourceError) as e:
abort(400, message=str(e)) abort(400, message=str(e))
query_text = data_source.query_runner.apply_auto_limit( query_text = data_source.query_runner.apply_auto_limit(query.text, should_apply_auto_limit)
query.text, should_apply_auto_limit
)
if query.missing_params: if query.missing_params:
return error_response( return error_response("Missing parameter value for: {}".format(", ".join(query.missing_params)))
"Missing parameter value for: {}".format(", ".join(query.missing_params))
)
if max_age == 0: if max_age == 0:
query_result = None query_result = None
@@ -107,11 +93,7 @@ def run_query(
) )
if query_result: if query_result:
return { return {"query_result": serialize_query_result(query_result, current_user.is_api_user())}
"query_result": serialize_query_result(
query_result, current_user.is_api_user()
)
}
else: else:
job = enqueue_query( job = enqueue_query(
query_text, query_text,
@@ -119,9 +101,7 @@ def run_query(
current_user.id, current_user.id,
current_user.is_api_user(), current_user.is_api_user(),
metadata={ metadata={
"Username": repr(current_user) "Username": repr(current_user) if current_user.is_api_user() else current_user.email,
if current_user.is_api_user()
else current_user.email,
"query_id": query_id, "query_id": query_id,
}, },
) )
@@ -145,9 +125,7 @@ def content_disposition_filenames(attachment_filename):
attachment_filename = attachment_filename.encode("ascii") attachment_filename = attachment_filename.encode("ascii")
except UnicodeEncodeError: except UnicodeEncodeError:
filenames = { filenames = {
"filename": unicodedata.normalize("NFKD", attachment_filename).encode( "filename": unicodedata.normalize("NFKD", attachment_filename).encode("ascii", "ignore"),
"ascii", "ignore"
),
"filename*": "UTF-8''%s" % url_quote(attachment_filename, safe=b""), "filename*": "UTF-8''%s" % url_quote(attachment_filename, safe=b""),
} }
else: else:
@@ -180,18 +158,14 @@ class QueryResultListResource(BaseResource):
max_age = -1 max_age = -1
max_age = int(max_age) max_age = int(max_age)
query_id = params.get("query_id", "adhoc") query_id = params.get("query_id", "adhoc")
parameters = params.get( parameters = params.get("parameters", collect_parameters_from_request(request.args))
"parameters", collect_parameters_from_request(request.args)
)
parameterized_query = ParameterizedQuery(query, org=self.current_org) parameterized_query = ParameterizedQuery(query, org=self.current_org)
should_apply_auto_limit = params.get("apply_auto_limit", False) should_apply_auto_limit = params.get("apply_auto_limit", False)
data_source_id = params.get("data_source_id") data_source_id = params.get("data_source_id")
if data_source_id: if data_source_id:
data_source = models.DataSource.get_by_id_and_org( data_source = models.DataSource.get_by_id_and_org(params.get("data_source_id"), self.current_org)
params.get("data_source_id"), self.current_org
)
else: else:
return error_messages["select_data_source"] return error_messages["select_data_source"]
@@ -213,9 +187,7 @@ ONE_YEAR = 60 * 60 * 24 * 365.25
class QueryResultDropdownResource(BaseResource): class QueryResultDropdownResource(BaseResource):
def get(self, query_id): def get(self, query_id):
query = get_object_or_404( query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org)
models.Query.get_by_id_and_org, query_id, self.current_org
)
require_access(query.data_source, current_user, view_only) require_access(query.data_source, current_user, view_only)
try: try:
return dropdown_values(query_id, self.current_org) return dropdown_values(query_id, self.current_org)
@@ -225,18 +197,12 @@ class QueryResultDropdownResource(BaseResource):
class QueryDropdownsResource(BaseResource): class QueryDropdownsResource(BaseResource):
def get(self, query_id, dropdown_query_id): def get(self, query_id, dropdown_query_id):
query = get_object_or_404( query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org)
models.Query.get_by_id_and_org, query_id, self.current_org
)
require_access(query, current_user, view_only) require_access(query, current_user, view_only)
related_queries_ids = [ related_queries_ids = [p["queryId"] for p in query.parameters if p["type"] == "query"]
p["queryId"] for p in query.parameters if p["type"] == "query"
]
if int(dropdown_query_id) not in related_queries_ids: if int(dropdown_query_id) not in related_queries_ids:
dropdown_query = get_object_or_404( dropdown_query = get_object_or_404(models.Query.get_by_id_and_org, dropdown_query_id, self.current_org)
models.Query.get_by_id_and_org, dropdown_query_id, self.current_org
)
require_access(dropdown_query.data_source, current_user, view_only) require_access(dropdown_query.data_source, current_user, view_only)
return dropdown_values(dropdown_query_id, self.current_org) return dropdown_values(dropdown_query_id, self.current_org)
@@ -250,9 +216,7 @@ class QueryResultResource(BaseResource):
if set(["*", origin]) & settings.ACCESS_CONTROL_ALLOW_ORIGIN: if set(["*", origin]) & settings.ACCESS_CONTROL_ALLOW_ORIGIN:
headers["Access-Control-Allow-Origin"] = origin headers["Access-Control-Allow-Origin"] = origin
headers["Access-Control-Allow-Credentials"] = str( headers["Access-Control-Allow-Credentials"] = str(settings.ACCESS_CONTROL_ALLOW_CREDENTIALS).lower()
settings.ACCESS_CONTROL_ALLOW_CREDENTIALS
).lower()
@require_any_of_permission(("view_query", "execute_query")) @require_any_of_permission(("view_query", "execute_query"))
def options(self, query_id=None, query_result_id=None, filetype="json"): def options(self, query_id=None, query_result_id=None, filetype="json"):
@@ -260,14 +224,10 @@ class QueryResultResource(BaseResource):
self.add_cors_headers(headers) self.add_cors_headers(headers)
if settings.ACCESS_CONTROL_REQUEST_METHOD: if settings.ACCESS_CONTROL_REQUEST_METHOD:
headers[ headers["Access-Control-Request-Method"] = settings.ACCESS_CONTROL_REQUEST_METHOD
"Access-Control-Request-Method"
] = settings.ACCESS_CONTROL_REQUEST_METHOD
if settings.ACCESS_CONTROL_ALLOW_HEADERS: if settings.ACCESS_CONTROL_ALLOW_HEADERS:
headers[ headers["Access-Control-Allow-Headers"] = settings.ACCESS_CONTROL_ALLOW_HEADERS
"Access-Control-Allow-Headers"
] = settings.ACCESS_CONTROL_ALLOW_HEADERS
return make_response("", 200, headers) return make_response("", 200, headers)
@@ -292,16 +252,12 @@ class QueryResultResource(BaseResource):
max_age = -1 max_age = -1
max_age = int(max_age) max_age = int(max_age)
query = get_object_or_404( query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org)
models.Query.get_by_id_and_org, query_id, self.current_org
)
allow_executing_with_view_only_permissions = query.parameterized.is_safe allow_executing_with_view_only_permissions = query.parameterized.is_safe
should_apply_auto_limit = params.get("apply_auto_limit", False) should_apply_auto_limit = params.get("apply_auto_limit", False)
if has_access( if has_access(query, self.current_user, allow_executing_with_view_only_permissions):
query, self.current_user, allow_executing_with_view_only_permissions
):
return run_query( return run_query(
query.parameterized, query.parameterized,
parameter_values, parameter_values,
@@ -346,31 +302,19 @@ class QueryResultResource(BaseResource):
query = None query = None
if query_result_id: if query_result_id:
query_result = get_object_or_404( query_result = get_object_or_404(models.QueryResult.get_by_id_and_org, query_result_id, self.current_org)
models.QueryResult.get_by_id_and_org, query_result_id, self.current_org
)
if query_id is not None: if query_id is not None:
query = get_object_or_404( query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org)
models.Query.get_by_id_and_org, query_id, self.current_org
)
if ( if query_result is None and query is not None and query.latest_query_data_id is not None:
query_result is None
and query is not None
and query.latest_query_data_id is not None
):
query_result = get_object_or_404( query_result = get_object_or_404(
models.QueryResult.get_by_id_and_org, models.QueryResult.get_by_id_and_org,
query.latest_query_data_id, query.latest_query_data_id,
self.current_org, self.current_org,
) )
if ( if query is not None and query_result is not None and self.current_user.is_api_user():
query is not None
and query_result is not None
and self.current_user.is_api_user()
):
if query.query_hash != query_result.query_hash: if query.query_hash != query_result.query_hash:
abort(404, message="No cached result found for this query.") abort(404, message="No cached result found for this query.")
@@ -409,9 +353,7 @@ class QueryResultResource(BaseResource):
self.add_cors_headers(response.headers) self.add_cors_headers(response.headers)
if should_cache: if should_cache:
response.headers.add_header( response.headers.add_header("Cache-Control", "private,max-age=%d" % ONE_YEAR)
"Cache-Control", "private,max-age=%d" % ONE_YEAR
)
filename = get_download_filename(query_result, query, filetype) filename = get_download_filename(query_result, query, filetype)
@@ -432,22 +374,16 @@ class QueryResultResource(BaseResource):
@staticmethod @staticmethod
def make_csv_response(query_result): def make_csv_response(query_result):
headers = {"Content-Type": "text/csv; charset=UTF-8"} headers = {"Content-Type": "text/csv; charset=UTF-8"}
return make_response( return make_response(serialize_query_result_to_dsv(query_result, ","), 200, headers)
serialize_query_result_to_dsv(query_result, ","), 200, headers
)
@staticmethod @staticmethod
def make_tsv_response(query_result): def make_tsv_response(query_result):
headers = {"Content-Type": "text/tab-separated-values; charset=UTF-8"} headers = {"Content-Type": "text/tab-separated-values; charset=UTF-8"}
return make_response( return make_response(serialize_query_result_to_dsv(query_result, "\t"), 200, headers)
serialize_query_result_to_dsv(query_result, "\t"), 200, headers
)
@staticmethod @staticmethod
def make_excel_response(query_result): def make_excel_response(query_result):
headers = { headers = {"Content-Type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"}
"Content-Type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
}
return make_response(serialize_query_result_to_xlsx(query_result), 200, headers) return make_response(serialize_query_result_to_xlsx(query_result), 200, headers)
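
Not every wrapped call collapses, though: black's "magic trailing comma" pins a call open when its argument list already ends with a comma, which is why the get_object_or_404 call fetching query.latest_query_data_id above stays one-argument-per-line while otherwise identical calls are joined. A minimal sketch:

import black

mode = black.Mode(line_length=120)
collapsed = "f(\n    a, b\n)\n"         # no trailing comma: black joins it
pinned = "f(\n    a,\n    b,\n)\n"      # magic trailing comma: stays exploded
print(black.format_str(collapsed, mode=mode))
print(black.format_str(pinned, mode=mode))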

View File

@@ -2,42 +2,36 @@ from flask import request
from funcy import project from funcy import project
from redash import models from redash import models
from redash.handlers.base import (
BaseResource,
get_object_or_404,
require_fields,
)
from redash.permissions import require_admin_or_owner from redash.permissions import require_admin_or_owner
from redash.handlers.base import BaseResource, require_fields, get_object_or_404
class QuerySnippetResource(BaseResource): class QuerySnippetResource(BaseResource):
def get(self, snippet_id): def get(self, snippet_id):
snippet = get_object_or_404( snippet = get_object_or_404(models.QuerySnippet.get_by_id_and_org, snippet_id, self.current_org)
models.QuerySnippet.get_by_id_and_org, snippet_id, self.current_org
)
self.record_event( self.record_event({"action": "view", "object_id": snippet_id, "object_type": "query_snippet"})
{"action": "view", "object_id": snippet_id, "object_type": "query_snippet"}
)
return snippet.to_dict() return snippet.to_dict()
def post(self, snippet_id): def post(self, snippet_id):
req = request.get_json(True) req = request.get_json(True)
params = project(req, ("trigger", "description", "snippet")) params = project(req, ("trigger", "description", "snippet"))
snippet = get_object_or_404( snippet = get_object_or_404(models.QuerySnippet.get_by_id_and_org, snippet_id, self.current_org)
models.QuerySnippet.get_by_id_and_org, snippet_id, self.current_org
)
require_admin_or_owner(snippet.user.id) require_admin_or_owner(snippet.user.id)
self.update_model(snippet, params) self.update_model(snippet, params)
models.db.session.commit() models.db.session.commit()
self.record_event( self.record_event({"action": "edit", "object_id": snippet.id, "object_type": "query_snippet"})
{"action": "edit", "object_id": snippet.id, "object_type": "query_snippet"}
)
return snippet.to_dict() return snippet.to_dict()
def delete(self, snippet_id): def delete(self, snippet_id):
snippet = get_object_or_404( snippet = get_object_or_404(models.QuerySnippet.get_by_id_and_org, snippet_id, self.current_org)
models.QuerySnippet.get_by_id_and_org, snippet_id, self.current_org
)
require_admin_or_owner(snippet.user.id) require_admin_or_owner(snippet.user.id)
models.db.session.delete(snippet) models.db.session.delete(snippet)
models.db.session.commit() models.db.session.commit()
@@ -79,7 +73,4 @@ class QuerySnippetListResource(BaseResource):
def get(self): def get(self):
self.record_event({"action": "list", "object_type": "query_snippet"}) self.record_event({"action": "list", "object_type": "query_snippet"})
return [ return [snippet.to_dict() for snippet in models.QuerySnippet.all(org=self.current_org)]
snippet.to_dict()
for snippet in models.QuerySnippet.all(org=self.current_org)
]

View File

@@ -1,7 +1,7 @@
from flask import request from flask import request
from redash.models import db, Organization from redash.handlers.base import BaseResource
from redash.handlers.base import BaseResource, record_event from redash.models import Organization, db
from redash.permissions import require_admin from redash.permissions import require_admin
from redash.settings.organization import settings as org_settings from redash.settings.organization import settings as org_settings
@@ -45,9 +45,7 @@ class OrganizationSettings(BaseResource):
previous_values[k] = self.current_org.google_apps_domains previous_values[k] = self.current_org.google_apps_domains
self.current_org.settings[Organization.SETTING_GOOGLE_APPS_DOMAINS] = v self.current_org.settings[Organization.SETTING_GOOGLE_APPS_DOMAINS] = v
else: else:
previous_values[k] = self.current_org.get_setting( previous_values[k] = self.current_org.get_setting(k, raise_on_missing=False)
k, raise_on_missing=False
)
self.current_org.set_setting(k, v) self.current_org.set_setting(k, v)
db.session.add(self.current_org) db.session.add(self.current_org)

View File

@@ -1,13 +1,13 @@
from flask import g, redirect, render_template, request, url_for from flask import g, redirect, render_template, request, url_for
from flask_login import login_user from flask_login import login_user
from wtforms import BooleanField, Form, PasswordField, StringField, validators
from wtforms.fields.html5 import EmailField
from redash import settings from redash import settings
from redash.authentication.org_resolving import current_org from redash.authentication.org_resolving import current_org
from redash.handlers.base import routes from redash.handlers.base import routes
from redash.models import Group, Organization, User, db from redash.models import Group, Organization, User, db
from redash.tasks.general import subscribe from redash.tasks.general import subscribe
from wtforms import BooleanField, Form, PasswordField, StringField, validators
from wtforms.fields.html5 import EmailField
class SetupForm(Form): class SetupForm(Form):
@@ -53,7 +53,7 @@ def create_org(org_name, user_name, email, password):
@routes.route("/setup", methods=["GET", "POST"]) @routes.route("/setup", methods=["GET", "POST"])
def setup(): def setup():
if current_org != None or settings.MULTI_ORG: if current_org != None or settings.MULTI_ORG: # noqa: E711
return redirect("/") return redirect("/")
form = SetupForm(request.form) form = SetupForm(request.form)
@@ -61,9 +61,7 @@ def setup():
form.security_notifications.data = True form.security_notifications.data = True
if request.method == "POST" and form.validate(): if request.method == "POST" and form.validate():
default_org, user = create_org( default_org, user = create_org(form.org_name.data, form.name.data, form.email.data, form.password.data)
form.org_name.data, form.name.data, form.email.data, form.password.data
)
g.org = default_org g.org = default_org
login_user(user) login_user(user)
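
The `# noqa: E711` added to the setup() guard is deliberate rather than a formatting leftover: current_org is a werkzeug LocalProxy (an assumption based on the org_resolving import above), so an identity test like `is None` inspects the proxy object itself, while == and != are forwarded to the wrapped value. A minimal sketch:

from werkzeug.local import LocalProxy

current = LocalProxy(lambda: None)  # a proxy that currently resolves to None

print(current is None)   # False -- the proxy object itself is never None
print(current == None)   # noqa: E711 -- True: the comparison reaches the wrapped value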

View File

@@ -1,39 +1,29 @@
import re
import time
from flask import request
from flask_restful import abort
from flask_login import current_user, login_user
from funcy import project
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.exc import IntegrityError
from disposable_email_domains import blacklist from disposable_email_domains import blacklist
from funcy import partial from flask import request
from flask_login import current_user, login_user
from redash import models, limiter from flask_restful import abort
from redash.permissions import ( from funcy import partial, project
require_permission, from sqlalchemy.exc import IntegrityError
require_admin_or_owner, from sqlalchemy.orm.exc import NoResultFound
is_admin_or_owner,
require_permission_or_owner,
require_admin,
)
from redash.handlers.base import (
BaseResource,
require_fields,
get_object_or_404,
paginate,
order_results as _order_results,
)
from redash import limiter, models, settings
from redash.authentication.account import ( from redash.authentication.account import (
invite_link_for_user, invite_link_for_user,
send_invite_email, send_invite_email,
send_password_reset_email, send_password_reset_email,
send_verify_email, send_verify_email,
) )
from redash.handlers.base import BaseResource, get_object_or_404
from redash.handlers.base import order_results as _order_results
from redash.handlers.base import paginate, require_fields
from redash.permissions import (
is_admin_or_owner,
require_admin,
require_admin_or_owner,
require_permission,
require_permission_or_owner,
)
from redash.settings import parse_boolean from redash.settings import parse_boolean
from redash import settings
# Ordering map for relationships # Ordering map for relationships
order_map = { order_map = {
@@ -47,9 +37,7 @@ order_map = {
"-groups": "-group_ids", "-groups": "-group_ids",
} }
order_results = partial( order_results = partial(_order_results, default_order="-created_at", allowed_orders=order_map)
_order_results, default_order="-created_at", allowed_orders=order_map
)
def invite_user(org, inviter, user, send_email=True): def invite_user(org, inviter, user, send_email=True):
@@ -73,9 +61,7 @@ def require_allowed_email(email):
class UserListResource(BaseResource): class UserListResource(BaseResource):
decorators = BaseResource.decorators + [ decorators = BaseResource.decorators + [limiter.limit("200/day;50/hour", methods=["POST"])]
limiter.limit("200/day;50/hour", methods=["POST"])
]
def get_users(self, disabled, pending, search_term): def get_users(self, disabled, pending, search_term):
if disabled: if disabled:
@@ -97,9 +83,7 @@ class UserListResource(BaseResource):
} }
) )
else: else:
self.record_event( self.record_event({"action": "list", "object_type": "user", "pending": pending})
{"action": "list", "object_type": "user", "pending": pending}
)
# order results according to passed order parameter, # order results according to passed order parameter,
# special-casing search queries where the database # special-casing search queries where the database
@@ -131,9 +115,7 @@ class UserListResource(BaseResource):
disabled = request.args.get("disabled", "false") # get enabled users by default disabled = request.args.get("disabled", "false") # get enabled users by default
disabled = parse_boolean(disabled) disabled = parse_boolean(disabled)
pending = request.args.get( pending = request.args.get("pending", None) # get both active and pending by default
"pending", None
) # get both active and pending by default
if pending is not None: if pending is not None:
pending = parse_boolean(pending) pending = parse_boolean(pending)
@@ -166,14 +148,10 @@ class UserListResource(BaseResource):
abort(400, message="Email already taken.") abort(400, message="Email already taken.")
abort(500) abort(500)
self.record_event( self.record_event({"action": "create", "object_id": user.id, "object_type": "user"})
{"action": "create", "object_id": user.id, "object_type": "user"}
)
should_send_invitation = "no_invite" not in request.args should_send_invitation = "no_invite" not in request.args
return invite_user( return invite_user(self.current_org, self.current_user, user, send_email=should_send_invitation)
self.current_org, self.current_user, user, send_email=should_send_invitation
)
class UserInviteResource(BaseResource): class UserInviteResource(BaseResource):
@@ -205,9 +183,7 @@ class UserRegenerateApiKeyResource(BaseResource):
user.regenerate_api_key() user.regenerate_api_key()
models.db.session.commit() models.db.session.commit()
self.record_event( self.record_event({"action": "regnerate_api_key", "object_id": user.id, "object_type": "user"})
{"action": "regnerate_api_key", "object_id": user.id, "object_type": "user"}
)
return user.to_dict(with_api_key=True) return user.to_dict(with_api_key=True)
@@ -217,32 +193,24 @@ class UserResource(BaseResource):
def get(self, user_id): def get(self, user_id):
require_permission_or_owner("list_users", user_id) require_permission_or_owner("list_users", user_id)
user = get_object_or_404( user = get_object_or_404(models.User.get_by_id_and_org, user_id, self.current_org)
models.User.get_by_id_and_org, user_id, self.current_org
)
self.record_event( self.record_event({"action": "view", "object_id": user_id, "object_type": "user"})
{"action": "view", "object_id": user_id, "object_type": "user"}
)
return user.to_dict(with_api_key=is_admin_or_owner(user_id)) return user.to_dict(with_api_key=is_admin_or_owner(user_id))
def post(self, user_id): def post(self, user_id): # noqa: C901
require_admin_or_owner(user_id) require_admin_or_owner(user_id)
user = models.User.get_by_id_and_org(user_id, self.current_org) user = models.User.get_by_id_and_org(user_id, self.current_org)
req = request.get_json(True) req = request.get_json(True)
params = project( params = project(req, ("email", "name", "password", "old_password", "group_ids"))
req, ("email", "name", "password", "old_password", "group_ids")
)
if "password" in params and "old_password" not in params: if "password" in params and "old_password" not in params:
abort(403, message="Must provide current password to update password.") abort(403, message="Must provide current password to update password.")
if "old_password" in params and not user.verify_password( if "old_password" in params and not user.verify_password(params["old_password"]):
params["old_password"]
):
abort(403, message="Incorrect current password.") abort(403, message="Incorrect current password.")
if "password" in params: if "password" in params:
@@ -266,9 +234,7 @@ class UserResource(BaseResource):
require_allowed_email(params["email"]) require_allowed_email(params["email"])
email_address_changed = "email" in params and params["email"] != user.email email_address_changed = "email" in params and params["email"] != user.email
needs_to_verify_email = ( needs_to_verify_email = email_address_changed and settings.email_server_is_configured()
email_address_changed and settings.email_server_is_configured()
)
if needs_to_verify_email: if needs_to_verify_email:
user.is_email_verified = False user.is_email_verified = False
@@ -312,13 +278,13 @@ class UserResource(BaseResource):
abort( abort(
403, 403,
message="You cannot delete your own account. " message="You cannot delete your own account. "
"Please ask another admin to do this for you.", "Please ask another admin to do this for you.", # fmt: skip
) )
elif not user.is_invitation_pending: elif not user.is_invitation_pending:
abort( abort(
403, 403,
message="You cannot delete activated users. " message="You cannot delete activated users. "
"Please disable the user instead.", "Please disable the user instead.", # fmt: skip
) )
models.db.session.delete(user) models.db.session.delete(user)
models.db.session.commit() models.db.session.commit()
@@ -336,7 +302,7 @@ class UserDisableResource(BaseResource):
abort( abort(
403, 403,
message="You cannot disable your own account. " message="You cannot disable your own account. "
"Please ask another admin to do this for you.", "Please ask another admin to do this for you.", # fmt: skip
) )
user.disable() user.disable()
models.db.session.commit() models.db.session.commit()
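
The `# fmt: skip` markers on the wrapped message strings above are black's per-statement escape hatch: the implicitly concatenated strings keep their hand-chosen line breaks instead of being joined like everything else. A minimal sketch (assuming a black release new enough to support `# fmt: skip`):

import black

src = "x = [1,  2,   3]  # fmt: skip\n"
# The marked statement comes back byte-for-byte, odd spacing and all.
print(black.format_str(src, mode=black.Mode()) == src)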

View File

@@ -2,8 +2,11 @@ from flask import request
from redash import models from redash import models
from redash.handlers.base import BaseResource, get_object_or_404 from redash.handlers.base import BaseResource, get_object_or_404
from redash.permissions import (
require_object_modify_permission,
require_permission,
)
from redash.serializers import serialize_visualization from redash.serializers import serialize_visualization
from redash.permissions import require_object_modify_permission, require_permission
from redash.utils import json_dumps from redash.utils import json_dumps
@@ -12,9 +15,7 @@ class VisualizationListResource(BaseResource):
def post(self): def post(self):
kwargs = request.get_json(force=True) kwargs = request.get_json(force=True)
query = get_object_or_404( query = get_object_or_404(models.Query.get_by_id_and_org, kwargs.pop("query_id"), self.current_org)
models.Query.get_by_id_and_org, kwargs.pop("query_id"), self.current_org
)
require_object_modify_permission(query, self.current_user) require_object_modify_permission(query, self.current_user)
kwargs["options"] = json_dumps(kwargs["options"]) kwargs["options"] = json_dumps(kwargs["options"])
@@ -29,9 +30,7 @@ class VisualizationListResource(BaseResource):
class VisualizationResource(BaseResource): class VisualizationResource(BaseResource):
@require_permission("edit_query") @require_permission("edit_query")
def post(self, visualization_id): def post(self, visualization_id):
vis = get_object_or_404( vis = get_object_or_404(models.Visualization.get_by_id_and_org, visualization_id, self.current_org)
models.Visualization.get_by_id_and_org, visualization_id, self.current_org
)
require_object_modify_permission(vis.query_rel, self.current_user) require_object_modify_permission(vis.query_rel, self.current_user)
kwargs = request.get_json(force=True) kwargs = request.get_json(force=True)
@@ -48,9 +47,7 @@ class VisualizationResource(BaseResource):
@require_permission("edit_query") @require_permission("edit_query")
def delete(self, visualization_id): def delete(self, visualization_id):
vis = get_object_or_404( vis = get_object_or_404(models.Visualization.get_by_id_and_org, visualization_id, self.current_org)
models.Visualization.get_by_id_and_org, visualization_id, self.current_org
)
require_object_modify_permission(vis.query_rel, self.current_user) require_object_modify_permission(vis.query_rel, self.current_user)
self.record_event( self.record_event(
{ {

View File

@@ -1,10 +1,9 @@
import os import os
import simplejson import simplejson
from flask import url_for from flask import url_for
WEBPACK_MANIFEST_PATH = os.path.join( WEBPACK_MANIFEST_PATH = os.path.join(os.path.dirname(__file__), "../../client/dist/", "asset-manifest.json")
os.path.dirname(__file__), "../../client/dist/", "asset-manifest.json"
)
def configure_webpack(app): def configure_webpack(app):

View File

@@ -2,13 +2,13 @@ from flask import request
from redash import models from redash import models
from redash.handlers.base import BaseResource from redash.handlers.base import BaseResource
from redash.serializers import serialize_widget
from redash.permissions import ( from redash.permissions import (
require_access, require_access,
require_object_modify_permission, require_object_modify_permission,
require_permission, require_permission,
view_only, view_only,
) )
from redash.serializers import serialize_widget
from redash.utils import json_dumps from redash.utils import json_dumps
@@ -27,9 +27,7 @@ class WidgetListResource(BaseResource):
:>json object widget: The created widget :>json object widget: The created widget
""" """
widget_properties = request.get_json(force=True) widget_properties = request.get_json(force=True)
dashboard = models.Dashboard.get_by_id_and_org( dashboard = models.Dashboard.get_by_id_and_org(widget_properties.get("dashboard_id"), self.current_org)
widget_properties.get("dashboard_id"), self.current_org
)
require_object_modify_permission(dashboard, self.current_user) require_object_modify_permission(dashboard, self.current_user)
widget_properties["options"] = json_dumps(widget_properties["options"]) widget_properties["options"] = json_dumps(widget_properties["options"])
@@ -37,9 +35,7 @@ class WidgetListResource(BaseResource):
visualization_id = widget_properties.pop("visualization_id") visualization_id = widget_properties.pop("visualization_id")
if visualization_id: if visualization_id:
visualization = models.Visualization.get_by_id_and_org( visualization = models.Visualization.get_by_id_and_org(visualization_id, self.current_org)
visualization_id, self.current_org
)
require_access(visualization.query_rel, self.current_user, view_only) require_access(visualization.query_rel, self.current_user, view_only)
else: else:
visualization = None visualization = None
@@ -82,8 +78,6 @@ class WidgetResource(BaseResource):
""" """
widget = models.Widget.get_by_id_and_org(widget_id, self.current_org) widget = models.Widget.get_by_id_and_org(widget_id, self.current_org)
require_object_modify_permission(widget.dashboard, self.current_user) require_object_modify_permission(widget.dashboard, self.current_user)
self.record_event( self.record_event({"action": "delete", "object_id": widget_id, "object_type": "widget"})
{"action": "delete", "object_id": widget_id, "object_type": "widget"}
)
models.db.session.delete(widget) models.db.session.delete(widget)
models.db.session.commit() models.db.session.commit()

View File

@@ -2,13 +2,13 @@ import logging
import time import time
from flask import g, has_request_context from flask import g, has_request_context
from redash import statsd_client
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.event import listens_for from sqlalchemy.event import listens_for
from sqlalchemy.orm.util import _ORMJoin from sqlalchemy.orm.util import _ORMJoin
from sqlalchemy.sql.selectable import Alias from sqlalchemy.sql.selectable import Alias
from redash import statsd_client
metrics_logger = logging.getLogger("metrics") metrics_logger = logging.getLogger("metrics")

View File

@@ -35,16 +35,12 @@ def calculate_metrics(response):
queries_duration, queries_duration,
) )
statsd_client.timing( statsd_client.timing("requests.{}.{}".format(endpoint, request.method.lower()), request_duration)
"requests.{}.{}".format(endpoint, request.method.lower()), request_duration
)
return response return response
MockResponse = namedtuple( MockResponse = namedtuple("MockResponse", ["status_code", "content_type", "content_length"])
"MockResponse", ["status_code", "content_type", "content_length"]
)
def calculate_metrics_on_exception(error): def calculate_metrics_on_exception(error):

View File

@@ -1,60 +1,80 @@
import datetime
import calendar import calendar
import datetime
import logging import logging
import time
import numbers import numbers
import pytz import time
from sqlalchemy import distinct, or_, and_, UniqueConstraint, cast import pytz
from sqlalchemy import UniqueConstraint, and_, cast, distinct, func, or_
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from sqlalchemy.event import listens_for from sqlalchemy.event import listens_for
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import backref, contains_eager, joinedload, subqueryload, load_only from sqlalchemy.orm import (
backref,
contains_eager,
joinedload,
load_only,
subqueryload,
)
from sqlalchemy.orm.exc import NoResultFound # noqa: F401 from sqlalchemy.orm.exc import NoResultFound # noqa: F401
from sqlalchemy import func
from sqlalchemy_utils import generic_relationship from sqlalchemy_utils import generic_relationship
from sqlalchemy_utils.types import TSVectorType
from sqlalchemy_utils.models import generic_repr from sqlalchemy_utils.models import generic_repr
from sqlalchemy_utils.types import TSVectorType
from sqlalchemy_utils.types.encrypted.encrypted_type import FernetEngine from sqlalchemy_utils.types.encrypted.encrypted_type import FernetEngine
from redash import redis_connection, utils, settings from redash import redis_connection, settings, utils
from redash.destinations import ( from redash.destinations import (
get_configuration_schema_for_destination_type, get_configuration_schema_for_destination_type,
get_destination, get_destination,
) )
from redash.metrics import database # noqa: F401 from redash.metrics import database # noqa: F401
from redash.models.base import (
Column,
GFKBase,
SearchBaseQuery,
db,
gfk_type,
key_type,
primary_key,
)
from redash.models.changes import Change, ChangeTrackingMixin # noqa
from redash.models.mixins import BelongsToOrgMixin, TimestampMixin
from redash.models.organizations import Organization
from redash.models.parameterized_query import ParameterizedQuery
from redash.models.types import (
Configuration,
EncryptedConfiguration,
MutableDict,
MutableList,
PseudoJSON,
pseudo_json_cast_property,
)
from redash.models.users import ( # noqa
AccessPermission,
AnonymousUser,
ApiUser,
Group,
User,
)
from redash.query_runner import ( from redash.query_runner import (
with_ssh_tunnel,
get_configuration_schema_for_query_runner_type,
get_query_runner,
TYPE_BOOLEAN, TYPE_BOOLEAN,
TYPE_DATE, TYPE_DATE,
TYPE_DATETIME, TYPE_DATETIME,
BaseQueryRunner) BaseQueryRunner,
get_configuration_schema_for_query_runner_type,
get_query_runner,
with_ssh_tunnel,
)
from redash.utils import ( from redash.utils import (
base_url,
gen_query_hash,
generate_token, generate_token,
json_dumps, json_dumps,
json_loads, json_loads,
mustache_render, mustache_render,
base_url,
sentry, sentry,
gen_query_hash)
from redash.utils.configuration import ConfigurationContainer
from redash.models.parameterized_query import ParameterizedQuery
from .base import db, gfk_type, Column, GFKBase, SearchBaseQuery, key_type, primary_key
from .changes import ChangeTrackingMixin, Change # noqa
from .mixins import BelongsToOrgMixin, TimestampMixin
from .organizations import Organization
from .types import (
EncryptedConfiguration,
Configuration,
MutableDict,
MutableList,
PseudoJSON,
pseudo_json_cast_property
) )
from .users import AccessPermission, AnonymousUser, ApiUser, Group, User # noqa from redash.utils.configuration import ConfigurationContainer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -93,18 +113,14 @@ class DataSource(BelongsToOrgMixin, db.Model):
options = Column( options = Column(
"encrypted_options", "encrypted_options",
ConfigurationContainer.as_mutable( ConfigurationContainer.as_mutable(
EncryptedConfiguration( EncryptedConfiguration(db.Text, settings.DATASOURCE_SECRET_KEY, FernetEngine)
db.Text, settings.DATASOURCE_SECRET_KEY, FernetEngine
)
), ),
) )
queue_name = Column(db.String(255), default="queries") queue_name = Column(db.String(255), default="queries")
scheduled_queue_name = Column(db.String(255), default="scheduled_queries") scheduled_queue_name = Column(db.String(255), default="scheduled_queries")
created_at = Column(db.DateTime(True), default=db.func.now()) created_at = Column(db.DateTime(True), default=db.func.now())
data_source_groups = db.relationship( data_source_groups = db.relationship("DataSourceGroup", back_populates="data_source", cascade="all")
"DataSourceGroup", back_populates="data_source", cascade="all"
)
__tablename__ = "data_sources" __tablename__ = "data_sources"
__table_args__ = (db.Index("data_sources_org_id_name", "org_id", "name"),) __table_args__ = (db.Index("data_sources_org_id_name", "org_id", "name"),)
@@ -122,7 +138,7 @@ class DataSource(BelongsToOrgMixin, db.Model):
"syntax": self.query_runner.syntax, "syntax": self.query_runner.syntax,
"paused": self.paused, "paused": self.paused,
"pause_reason": self.pause_reason, "pause_reason": self.pause_reason,
"supports_auto_limit": self.query_runner.supports_auto_limit "supports_auto_limit": self.query_runner.supports_auto_limit,
} }
if all: if all:
@@ -151,9 +167,7 @@ class DataSource(BelongsToOrgMixin, db.Model):
@classmethod @classmethod
def create_with_group(cls, *args, **kwargs): def create_with_group(cls, *args, **kwargs):
data_source = cls(*args, **kwargs) data_source = cls(*args, **kwargs)
data_source_group = DataSourceGroup( data_source_group = DataSourceGroup(data_source=data_source, group=data_source.org.default_group)
data_source=data_source, group=data_source.org.default_group
)
db.session.add_all([data_source, data_source_group]) db.session.add_all([data_source, data_source_group])
return data_source return data_source
@@ -162,9 +176,7 @@ class DataSource(BelongsToOrgMixin, db.Model):
data_sources = cls.query.filter(cls.org == org).order_by(cls.id.asc()) data_sources = cls.query.filter(cls.org == org).order_by(cls.id.asc())
if group_ids: if group_ids:
data_sources = data_sources.join(DataSourceGroup).filter( data_sources = data_sources.join(DataSourceGroup).filter(DataSourceGroup.group_id.in_(group_ids))
DataSourceGroup.group_id.in_(group_ids)
)
return data_sources.distinct() return data_sources.distinct()
@@ -173,9 +185,7 @@ class DataSource(BelongsToOrgMixin, db.Model):
return cls.query.filter(cls.id == _id).one() return cls.query.filter(cls.id == _id).one()
def delete(self): def delete(self):
Query.query.filter(Query.data_source == self).update( Query.query.filter(Query.data_source == self).update(dict(data_source_id=None, latest_query_data_id=None))
dict(data_source_id=None, latest_query_data_id=None)
)
QueryResult.query.filter(QueryResult.data_source == self).delete() QueryResult.query.filter(QueryResult.data_source == self).delete()
res = db.session.delete(self) res = db.session.delete(self)
db.session.commit() db.session.commit()
@@ -200,9 +210,7 @@ class DataSource(BelongsToOrgMixin, db.Model):
try: try:
out_schema = self._sort_schema(schema) out_schema = self._sort_schema(schema)
except Exception: except Exception:
logging.exception( logging.exception("Error sorting schema columns for data_source {}".format(self.id))
"Error sorting schema columns for data_source {}".format(self.id)
)
out_schema = schema out_schema = schema
finally: finally:
redis_connection.set(self._schema_key, json_dumps(out_schema)) redis_connection.set(self._schema_key, json_dumps(out_schema))
@@ -243,15 +251,11 @@ class DataSource(BelongsToOrgMixin, db.Model):
return dsg return dsg
def remove_group(self, group): def remove_group(self, group):
DataSourceGroup.query.filter( DataSourceGroup.query.filter(DataSourceGroup.group == group, DataSourceGroup.data_source == self).delete()
DataSourceGroup.group == group, DataSourceGroup.data_source == self
).delete()
db.session.commit() db.session.commit()
def update_group_permission(self, group, view_only): def update_group_permission(self, group, view_only):
dsg = DataSourceGroup.query.filter( dsg = DataSourceGroup.query.filter(DataSourceGroup.group == group, DataSourceGroup.data_source == self).one()
DataSourceGroup.group == group, DataSourceGroup.data_source == self
).one()
dsg.view_only = view_only dsg.view_only = view_only
db.session.add(dsg) db.session.add(dsg)
return dsg return dsg
@@ -314,9 +318,7 @@ class DBPersistence(object):
self._data = data self._data = data
QueryResultPersistence = ( QueryResultPersistence = settings.dynamic_settings.QueryResultPersistence or DBPersistence
settings.dynamic_settings.QueryResultPersistence or DBPersistence
)
@generic_repr("id", "org_id", "data_source_id", "query_hash", "runtime", "retrieved_at") @generic_repr("id", "org_id", "data_source_id", "query_hash", "runtime", "retrieved_at")
@@ -351,11 +353,9 @@ class QueryResult(db.Model, QueryResultPersistence, BelongsToOrgMixin):
@classmethod @classmethod
def unused(cls, days=7): def unused(cls, days=7):
age_threshold = datetime.datetime.now() - datetime.timedelta(days=days) age_threshold = datetime.datetime.now() - datetime.timedelta(days=days)
return ( return (cls.query.filter(Query.id.is_(None), cls.retrieved_at < age_threshold).outerjoin(Query)).options(
cls.query.filter( load_only("id")
Query.id.is_(None), cls.retrieved_at < age_threshold )
).outerjoin(Query)
).options(load_only("id"))
@classmethod @classmethod
def get_latest(cls, data_source, query, max_age=0): def get_latest(cls, data_source, query, max_age=0):
@@ -365,16 +365,13 @@ class QueryResult(db.Model, QueryResultPersistence, BelongsToOrgMixin):
max_age = settings.QUERY_RESULTS_EXPIRED_TTL max_age = settings.QUERY_RESULTS_EXPIRED_TTL
if max_age == -1: if max_age == -1:
query = cls.query.filter( query = cls.query.filter(cls.query_hash == query_hash, cls.data_source == data_source)
cls.query_hash == query_hash, cls.data_source == data_source
)
else: else:
query = cls.query.filter( query = cls.query.filter(
cls.query_hash == query_hash, cls.query_hash == query_hash,
cls.data_source == data_source, cls.data_source == data_source,
( (
db.func.timezone("utc", cls.retrieved_at) db.func.timezone("utc", cls.retrieved_at) + datetime.timedelta(seconds=max_age)
+ datetime.timedelta(seconds=max_age)
>= db.func.timezone("utc", db.func.now()) >= db.func.timezone("utc", db.func.now())
), ),
) )
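In plain Python, the freshness rule encoded in the SQL above amounts to the following sketch (is_fresh is a hypothetical helper, not part of the codebase):

import datetime

def is_fresh(retrieved_at, now, max_age):
    # max_age == -1 accepts any cached result for the same (query_hash, data_source);
    # otherwise the result must have been retrieved within the last max_age seconds (UTC).
    return max_age == -1 or retrieved_at + datetime.timedelta(seconds=max_age) >= now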
@@ -382,9 +379,7 @@ class QueryResult(db.Model, QueryResultPersistence, BelongsToOrgMixin):
return query.order_by(cls.retrieved_at.desc()).first() return query.order_by(cls.retrieved_at.desc()).first()
@classmethod @classmethod
def store_result( def store_result(cls, org, data_source, query_hash, query, data, run_time, retrieved_at):
cls, org, data_source, query_hash, query, data, run_time, retrieved_at
):
query_result = cls( query_result = cls(
org_id=org, org_id=org,
query_hash=query_hash, query_hash=query_hash,
@@ -405,9 +400,7 @@ class QueryResult(db.Model, QueryResultPersistence, BelongsToOrgMixin):
return self.data_source.groups return self.data_source.groups
def should_schedule_next( def should_schedule_next(previous_iteration, now, interval, time=None, day_of_week=None, failures=0):
previous_iteration, now, interval, time=None, day_of_week=None, failures=0
):
# if time exists then interval > 23 hours (82800s) # if time exists then interval > 23 hours (82800s)
# if day_of_week exists then interval > 6 days (518400s) # if day_of_week exists then interval > 6 days (518400s)
if time is None: if time is None:
@@ -421,32 +414,23 @@ def should_schedule_next(
# - The query scheduled to run at 23:59. # - The query scheduled to run at 23:59.
# - The scheduler wakes up at 00:01. # - The scheduler wakes up at 00:01.
# - With a naive timestamp comparison, the run would be skipped. # - With a naive timestamp comparison, the run would be skipped.
normalized_previous_iteration = previous_iteration.replace( normalized_previous_iteration = previous_iteration.replace(hour=hour, minute=minute)
hour=hour, minute=minute
)
if normalized_previous_iteration > previous_iteration: if normalized_previous_iteration > previous_iteration:
previous_iteration = normalized_previous_iteration - datetime.timedelta( previous_iteration = normalized_previous_iteration - datetime.timedelta(days=1)
days=1
)
days_delay = int(interval) / 60 / 60 / 24 days_delay = int(interval) / 60 / 60 / 24
days_to_add = 0 days_to_add = 0
if day_of_week is not None: if day_of_week is not None:
days_to_add = ( days_to_add = list(calendar.day_name).index(day_of_week) - normalized_previous_iteration.weekday()
list(calendar.day_name).index(day_of_week)
- normalized_previous_iteration.weekday()
)
next_iteration = ( next_iteration = (
previous_iteration previous_iteration + datetime.timedelta(days=days_delay) + datetime.timedelta(days=days_to_add)
+ datetime.timedelta(days=days_delay)
+ datetime.timedelta(days=days_to_add)
).replace(hour=hour, minute=minute) ).replace(hour=hour, minute=minute)
if failures: if failures:
try: try:
next_iteration += datetime.timedelta(minutes=2 ** failures) next_iteration += datetime.timedelta(minutes=2**failures)
except OverflowError: except OverflowError:
return False return False
return now > next_iteration return now > next_iteration
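A minimal sketch of the midnight edge case described in the comments above, with illustrative values: a daily 23:59 schedule whose previous run fired two minutes late, at 00:01 the next day.

import datetime

previous_iteration = datetime.datetime(2023, 7, 11, 0, 1)  # the 2023-07-10 23:59 run, two minutes late
normalized = previous_iteration.replace(hour=23, minute=59)  # 2023-07-11 23:59
if normalized > previous_iteration:
    # attribute the late run to the 2023-07-10 23:59 slot
    previous_iteration = normalized - datetime.timedelta(days=1)
next_iteration = (previous_iteration + datetime.timedelta(days=1)).replace(hour=23, minute=59)
assert next_iteration == datetime.datetime(2023, 7, 11, 23, 59)  # today's run is not skipped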
@@ -475,9 +459,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
org = db.relationship(Organization, backref="queries") org = db.relationship(Organization, backref="queries")
data_source_id = Column(key_type("DataSource"), db.ForeignKey("data_sources.id"), nullable=True) data_source_id = Column(key_type("DataSource"), db.ForeignKey("data_sources.id"), nullable=True)
data_source = db.relationship(DataSource, backref="queries") data_source = db.relationship(DataSource, backref="queries")
latest_query_data_id = Column( latest_query_data_id = Column(key_type("QueryResult"), db.ForeignKey("query_results.id"), nullable=True)
key_type("QueryResult"), db.ForeignKey("query_results.id"), nullable=True
)
latest_query_data = db.relationship(QueryResult) latest_query_data = db.relationship(QueryResult)
name = Column(db.String(255)) name = Column(db.String(255))
description = Column(db.String(4096), nullable=True) description = Column(db.String(4096), nullable=True)
@@ -487,9 +469,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
user_id = Column(key_type("User"), db.ForeignKey("users.id")) user_id = Column(key_type("User"), db.ForeignKey("users.id"))
user = db.relationship(User, foreign_keys=[user_id]) user = db.relationship(User, foreign_keys=[user_id])
last_modified_by_id = Column(key_type("User"), db.ForeignKey("users.id"), nullable=True) last_modified_by_id = Column(key_type("User"), db.ForeignKey("users.id"), nullable=True)
last_modified_by = db.relationship( last_modified_by = db.relationship(User, backref="modified_queries", foreign_keys=[last_modified_by_id])
User, backref="modified_queries", foreign_keys=[last_modified_by_id]
)
is_archived = Column(db.Boolean, default=False, index=True) is_archived = Column(db.Boolean, default=False, index=True)
is_draft = Column(db.Boolean, default=True, index=True) is_draft = Column(db.Boolean, default=True, index=True)
schedule = Column(MutableDict.as_mutable(PseudoJSON), nullable=True) schedule = Column(MutableDict.as_mutable(PseudoJSON), nullable=True)
@@ -507,9 +487,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
), ),
nullable=True, nullable=True,
) )
tags = Column( tags = Column("tags", MutableList.as_mutable(postgresql.ARRAY(db.Unicode)), nullable=True)
"tags", MutableList.as_mutable(postgresql.ARRAY(db.Unicode)), nullable=True
)
query_class = SearchBaseQuery query_class = SearchBaseQuery
__tablename__ = "queries" __tablename__ = "queries"
@@ -551,37 +529,27 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
return query return query
@classmethod @classmethod
def all_queries( def all_queries(cls, group_ids, user_id=None, include_drafts=False, include_archived=False):
cls, group_ids, user_id=None, include_drafts=False, include_archived=False
):
query_ids = ( query_ids = (
db.session.query(distinct(cls.id)) db.session.query(distinct(cls.id))
.join( .join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id)
DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id
)
.filter(Query.is_archived.is_(include_archived)) .filter(Query.is_archived.is_(include_archived))
.filter(DataSourceGroup.group_id.in_(group_ids)) .filter(DataSourceGroup.group_id.in_(group_ids))
) )
queries = ( queries = (
cls.query.options( cls.query.options(
joinedload(Query.user), joinedload(Query.user),
joinedload(Query.latest_query_data).load_only( joinedload(Query.latest_query_data).load_only("runtime", "retrieved_at"),
"runtime", "retrieved_at"
),
) )
.filter(cls.id.in_(query_ids)) .filter(cls.id.in_(query_ids))
# Adding outer joins to be able to order by relationship # Adding outer joins to be able to order by relationship
.outerjoin(User, User.id == Query.user_id) .outerjoin(User, User.id == Query.user_id)
.outerjoin(QueryResult, QueryResult.id == Query.latest_query_data_id) .outerjoin(QueryResult, QueryResult.id == Query.latest_query_data_id)
.options( .options(contains_eager(Query.user), contains_eager(Query.latest_query_data))
contains_eager(Query.user), contains_eager(Query.latest_query_data)
)
) )
if not include_drafts: if not include_drafts:
queries = queries.filter( queries = queries.filter(or_(Query.is_draft.is_(False), Query.user_id == user_id))
or_(Query.is_draft.is_(False), Query.user_id == user_id)
)
return queries return queries
@classmethod @classmethod
@@ -597,9 +565,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
@classmethod @classmethod
def all_tags(cls, user, include_drafts=False): def all_tags(cls, user, include_drafts=False):
queries = cls.all_queries( queries = cls.all_queries(group_ids=user.group_ids, user_id=user.id, include_drafts=include_drafts)
group_ids=user.group_ids, user_id=user.id, include_drafts=include_drafts
)
tag_column = func.unnest(cls.tags).label("tag") tag_column = func.unnest(cls.tags).label("tag")
usage_count = func.count(1).label("usage_count") usage_count = func.count(1).label("usage_count")
@@ -628,18 +594,13 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
query query
for query in queries for query in queries
if query.schedule["until"] is not None if query.schedule["until"] is not None
and pytz.utc.localize( and pytz.utc.localize(datetime.datetime.strptime(query.schedule["until"], "%Y-%m-%d")) <= now
datetime.datetime.strptime(query.schedule["until"], "%Y-%m-%d")
)
<= now
] ]
@classmethod @classmethod
def outdated_queries(cls): def outdated_queries(cls):
queries = ( queries = (
Query.query.options( Query.query.options(joinedload(Query.latest_query_data).load_only("retrieved_at"))
joinedload(Query.latest_query_data).load_only("retrieved_at")
)
.filter(Query.schedule.isnot(None)) .filter(Query.schedule.isnot(None))
.order_by(Query.id) .order_by(Query.id)
.all() .all()
@@ -655,9 +616,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
continue continue
if query.schedule["until"]: if query.schedule["until"]:
schedule_until = pytz.utc.localize( schedule_until = pytz.utc.localize(datetime.datetime.strptime(query.schedule["until"], "%Y-%m-%d"))
datetime.datetime.strptime(query.schedule["until"], "%Y-%m-%d")
)
if schedule_until <= now: if schedule_until <= now:
continue continue
@@ -685,9 +644,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
% (query.id, repr(e)) % (query.id, repr(e))
) )
logging.info(message) logging.info(message)
sentry.capture_exception( sentry.capture_exception(type(e)(message).with_traceback(e.__traceback__))
type(e)(message).with_traceback(e.__traceback__)
)
return list(outdated_queries.values()) return list(outdated_queries.values())
@@ -713,9 +670,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
# Since tsvector doesn't work well with CJK languages, use `ilike` too # Since tsvector doesn't work well with CJK languages, use `ilike` too
pattern = "%{}%".format(term) pattern = "%{}%".format(term)
return ( return (
all_queries.filter( all_queries.filter(or_(cls.name.ilike(pattern), cls.description.ilike(pattern)))
or_(cls.name.ilike(pattern), cls.description.ilike(pattern))
)
.order_by(Query.id) .order_by(Query.id)
.limit(limit) .limit(limit)
) )
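Illustratively (the term is hypothetical), the `ilike` fallback turns a CJK search term that tsvector tokenizes poorly into a case-insensitive substring match:

term = "売上"  # hypothetical CJK search term
pattern = "%{}%".format(term)
# resulting predicate: name ILIKE '%売上%' OR description ILIKE '%売上%'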
@@ -732,18 +687,14 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
query = ( query = (
cls.query.filter(Event.created_at > (db.func.current_date() - 7)) cls.query.filter(Event.created_at > (db.func.current_date() - 7))
.join(Event, Query.id == Event.object_id.cast(db.Integer)) .join(Event, Query.id == Event.object_id.cast(db.Integer))
.join( .join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id)
DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id
)
.filter( .filter(
Event.action.in_( Event.action.in_(["edit", "execute", "edit_name", "edit_description", "view_source"]),
["edit", "execute", "edit_name", "edit_description", "view_source"] Event.object_id is not None,
),
Event.object_id != None,
Event.object_type == "query", Event.object_type == "query",
DataSourceGroup.group_id.in_(group_ids), DataSourceGroup.group_id.in_(group_ids),
or_(Query.is_draft == False, Query.user_id == user_id), or_(Query.is_draft.is_(False), Query.user_id == user_id),
Query.is_archived == False, Query.is_archived.is_(False),
) )
.group_by(Event.object_id, Query.id) .group_by(Event.object_id, Query.id)
.order_by(db.desc(db.func.count(0))) .order_by(db.desc(db.func.count(0)))
@@ -806,16 +757,12 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
kwargs = {a: getattr(self, a) for a in forked_list} kwargs = {a: getattr(self, a) for a in forked_list}
# Query.create will add a default TABLE visualization, so use the constructor to create a bare copy of the query # Query.create will add a default TABLE visualization, so use the constructor to create a bare copy of the query
forked_query = Query( forked_query = Query(name="Copy of (#{}) {}".format(self.id, self.name), user=user, **kwargs)
name="Copy of (#{}) {}".format(self.id, self.name), user=user, **kwargs
)
for v in sorted(self.visualizations, key=lambda v: v.id): for v in sorted(self.visualizations, key=lambda v: v.id):
forked_v = v.copy() forked_v = v.copy()
forked_v["query_rel"] = forked_query forked_v["query_rel"] = forked_query
fv = Visualization( fv = Visualization(**forked_v) # it will magically add it to `forked_query.visualizations`
**forked_v
) # it will magically add it to `forked_query.visualizations`
db.session.add(fv) db.session.add(fv)
db.session.add(forked_query) db.session.add(forked_query)
@@ -898,9 +845,7 @@ class Favorite(TimestampMixin, db.Model):
user = db.relationship(User, backref="favorites") user = db.relationship(User, backref="favorites")
__tablename__ = "favorites" __tablename__ = "favorites"
__table_args__ = ( __table_args__ = (UniqueConstraint("object_type", "object_id", "user_id", name="unique_favorite"),)
UniqueConstraint("object_type", "object_id", "user_id", name="unique_favorite"),
)
@classmethod @classmethod
def is_favorite(cls, user, object): def is_favorite(cls, user, object):
@@ -966,9 +911,7 @@ def next_state(op, value, threshold):
return new_state return new_state
@generic_repr( @generic_repr("id", "name", "query_id", "user_id", "state", "last_triggered_at", "rearm")
"id", "name", "query_id", "user_id", "state", "last_triggered_at", "rearm"
)
class Alert(TimestampMixin, BelongsToOrgMixin, db.Model): class Alert(TimestampMixin, BelongsToOrgMixin, db.Model):
UNKNOWN_STATE = "unknown" UNKNOWN_STATE = "unknown"
OK_STATE = "ok" OK_STATE = "ok"
@@ -993,9 +936,7 @@ class Alert(TimestampMixin, BelongsToOrgMixin, db.Model):
return ( return (
cls.query.options(joinedload(Alert.user), joinedload(Alert.query_rel)) cls.query.options(joinedload(Alert.user), joinedload(Alert.query_rel))
.join(Query) .join(Query)
.join( .join(DataSourceGroup, DataSourceGroup.data_source_id == Query.data_source_id)
DataSourceGroup, DataSourceGroup.data_source_id == Query.data_source_id
)
.filter(DataSourceGroup.group_id.in_(group_ids)) .filter(DataSourceGroup.group_id.in_(group_ids))
) )
@@ -1019,9 +960,7 @@ class Alert(TimestampMixin, BelongsToOrgMixin, db.Model):
return new_state return new_state
def subscribers(self): def subscribers(self):
return User.query.join(AlertSubscription).filter( return User.query.join(AlertSubscription).filter(AlertSubscription.alert == self)
AlertSubscription.alert == self
)
def render_template(self, template): def render_template(self, template):
if template is None: if template is None:
@@ -1043,9 +982,7 @@ class Alert(TimestampMixin, BelongsToOrgMixin, db.Model):
"ALERT_CONDITION": self.options["op"], "ALERT_CONDITION": self.options["op"],
"ALERT_THRESHOLD": self.options["value"], "ALERT_THRESHOLD": self.options["value"],
"QUERY_NAME": self.query_rel.name, "QUERY_NAME": self.query_rel.name,
"QUERY_URL": "{host}/queries/{query_id}".format( "QUERY_URL": "{host}/queries/{query_id}".format(host=host, query_id=self.query_rel.id),
host=host, query_id=self.query_rel.id
),
"QUERY_RESULT_VALUE": result_value, "QUERY_RESULT_VALUE": result_value,
"QUERY_RESULT_ROWS": data["rows"], "QUERY_RESULT_ROWS": data["rows"],
"QUERY_RESULT_COLS": data["columns"], "QUERY_RESULT_COLS": data["columns"],
@@ -1081,9 +1018,7 @@ def generate_slug(ctx):
@gfk_type @gfk_type
@generic_repr( @generic_repr("id", "name", "slug", "user_id", "org_id", "version", "is_archived", "is_draft")
"id", "name", "slug", "user_id", "org_id", "version", "is_archived", "is_draft"
)
class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
id = primary_key("Dashboard") id = primary_key("Dashboard")
version = Column(db.Integer) version = Column(db.Integer)
@@ -1099,12 +1034,8 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model
is_archived = Column(db.Boolean, default=False, index=True) is_archived = Column(db.Boolean, default=False, index=True)
is_draft = Column(db.Boolean, default=True, index=True) is_draft = Column(db.Boolean, default=True, index=True)
widgets = db.relationship("Widget", backref="dashboard", lazy="dynamic") widgets = db.relationship("Widget", backref="dashboard", lazy="dynamic")
tags = Column( tags = Column("tags", MutableList.as_mutable(postgresql.ARRAY(db.Unicode)), nullable=True)
"tags", MutableList.as_mutable(postgresql.ARRAY(db.Unicode)), nullable=True options = Column(MutableDict.as_mutable(postgresql.JSON), server_default="{}", default={})
)
options = Column(
MutableDict.as_mutable(postgresql.JSON), server_default="{}", default={}
)
__tablename__ = "dashboards" __tablename__ = "dashboards"
__mapper_args__ = {"version_id_col": version} __mapper_args__ = {"version_id_col": version}
@@ -1119,39 +1050,27 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model
@classmethod @classmethod
def all(cls, org, group_ids, user_id): def all(cls, org, group_ids, user_id):
query = ( query = (
Dashboard.query.options( Dashboard.query.options(joinedload(Dashboard.user).load_only("id", "name", "details", "email"))
joinedload(Dashboard.user).load_only( .distinct(cls.lowercase_name, Dashboard.created_at, Dashboard.slug)
"id", "name", "details", "email"
)
).distinct(cls.lowercase_name, Dashboard.created_at, Dashboard.slug)
.outerjoin(Widget) .outerjoin(Widget)
.outerjoin(Visualization) .outerjoin(Visualization)
.outerjoin(Query) .outerjoin(Query)
.outerjoin( .outerjoin(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id)
DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id
)
.filter( .filter(
Dashboard.is_archived == False, Dashboard.is_archived.is_(False),
( (DataSourceGroup.group_id.in_(group_ids) | (Dashboard.user_id == user_id)),
DataSourceGroup.group_id.in_(group_ids)
| (Dashboard.user_id == user_id)
),
Dashboard.org == org, Dashboard.org == org,
) )
) )
query = query.filter( query = query.filter(or_(Dashboard.user_id == user_id, Dashboard.is_draft.is_(False)))
or_(Dashboard.user_id == user_id, Dashboard.is_draft == False)
)
return query return query
@classmethod @classmethod
def search(cls, org, groups_ids, user_id, search_term): def search(cls, org, groups_ids, user_id, search_term):
# TODO: switch to FTS # TODO: switch to FTS
return cls.all(org, groups_ids, user_id).filter( return cls.all(org, groups_ids, user_id).filter(cls.name.ilike("%{}%".format(search_term)))
cls.name.ilike("%{}%".format(search_term))
)
@classmethod @classmethod
def search_by_user(cls, term, user, limit=None): def search_by_user(cls, term, user, limit=None):
@@ -1237,12 +1156,8 @@ class Visualization(TimestampMixin, BelongsToOrgMixin, db.Model):
@generic_repr("id", "visualization_id", "dashboard_id") @generic_repr("id", "visualization_id", "dashboard_id")
class Widget(TimestampMixin, BelongsToOrgMixin, db.Model): class Widget(TimestampMixin, BelongsToOrgMixin, db.Model):
id = primary_key("Widget") id = primary_key("Widget")
visualization_id = Column( visualization_id = Column(key_type("Visualization"), db.ForeignKey("visualizations.id"), nullable=True)
key_type("Visualization"), db.ForeignKey("visualizations.id"), nullable=True visualization = db.relationship(Visualization, backref=backref("widgets", cascade="delete"))
)
visualization = db.relationship(
Visualization, backref=backref("widgets", cascade="delete")
)
text = Column(db.Text, nullable=True) text = Column(db.Text, nullable=True)
width = Column(db.Integer) width = Column(db.Integer)
options = Column(db.Text) options = Column(db.Text)
@@ -1258,9 +1173,7 @@ class Widget(TimestampMixin, BelongsToOrgMixin, db.Model):
return super(Widget, cls).get_by_id_and_org(object_id, org, Dashboard) return super(Widget, cls).get_by_id_and_org(object_id, org, Dashboard)
@generic_repr( @generic_repr("id", "object_type", "object_id", "action", "user_id", "org_id", "created_at")
"id", "object_type", "object_id", "action", "user_id", "org_id", "created_at"
)
class Event(db.Model): class Event(db.Model):
id = primary_key("Event") id = primary_key("Event")
org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id")) org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id"))
@@ -1270,9 +1183,7 @@ class Event(db.Model):
action = Column(db.String(255)) action = Column(db.String(255))
object_type = Column(db.String(255)) object_type = Column(db.String(255))
object_id = Column(db.String(255), nullable=True) object_id = Column(db.String(255), nullable=True)
additional_properties = Column( additional_properties = Column(MutableDict.as_mutable(PseudoJSON), nullable=True, default={})
MutableDict.as_mutable(PseudoJSON), nullable=True, default={}
)
created_at = Column(db.DateTime(True), default=db.func.now()) created_at = Column(db.DateTime(True), default=db.func.now())
__tablename__ = "events" __tablename__ = "events"
@@ -1332,20 +1243,18 @@ class ApiKey(TimestampMixin, GFKBase, db.Model):
created_by = db.relationship(User) created_by = db.relationship(User)
__tablename__ = "api_keys" __tablename__ = "api_keys"
__table_args__ = ( __table_args__ = (db.Index("api_keys_object_type_object_id", "object_type", "object_id"),)
db.Index("api_keys_object_type_object_id", "object_type", "object_id"),
)
@classmethod @classmethod
def get_by_api_key(cls, api_key): def get_by_api_key(cls, api_key):
return cls.query.filter(cls.api_key == api_key, cls.active == True).one() return cls.query.filter(cls.api_key == api_key, cls.active.is_(True)).one()
@classmethod @classmethod
def get_by_object(cls, object): def get_by_object(cls, object):
return cls.query.filter( return cls.query.filter(
cls.object_type == object.__class__.__tablename__, cls.object_type == object.__class__.__tablename__,
cls.object_id == object.id, cls.object_id == object.id,
cls.active == True, cls.active.is_(True),
).first() ).first()
@classmethod @classmethod
@@ -1367,19 +1276,13 @@ class NotificationDestination(BelongsToOrgMixin, db.Model):
options = Column( options = Column(
"encrypted_options", "encrypted_options",
ConfigurationContainer.as_mutable( ConfigurationContainer.as_mutable(
EncryptedConfiguration( EncryptedConfiguration(db.Text, settings.DATASOURCE_SECRET_KEY, FernetEngine)
db.Text, settings.DATASOURCE_SECRET_KEY, FernetEngine
)
), ),
) )
created_at = Column(db.DateTime(True), default=db.func.now()) created_at = Column(db.DateTime(True), default=db.func.now())
__tablename__ = "notification_destinations" __tablename__ = "notification_destinations"
__table_args__ = ( __table_args__ = (db.Index("notification_destinations_org_id_name", "org_id", "name", unique=True),)
db.Index(
"notification_destinations_org_id_name", "org_id", "name", unique=True
),
)
def __str__(self): def __str__(self):
return str(self.name) return str(self.name)
@@ -1405,18 +1308,14 @@ class NotificationDestination(BelongsToOrgMixin, db.Model):
@classmethod @classmethod
def all(cls, org): def all(cls, org):
notification_destinations = cls.query.filter(cls.org == org).order_by( notification_destinations = cls.query.filter(cls.org == org).order_by(cls.id.asc())
cls.id.asc()
)
return notification_destinations return notification_destinations
def notify(self, alert, query, user, new_state, app, host): def notify(self, alert, query, user, new_state, app, host):
schema = get_configuration_schema_for_destination_type(self.type) schema = get_configuration_schema_for_destination_type(self.type)
self.options.set_schema(schema) self.options.set_schema(schema)
return self.destination.notify( return self.destination.notify(alert, query, user, new_state, app, host, self.options)
alert, query, user, new_state, app, host, self.options
)
@generic_repr("id", "user_id", "destination_id", "alert_id") @generic_repr("id", "user_id", "destination_id", "alert_id")
@@ -1451,9 +1350,7 @@ class AlertSubscription(TimestampMixin, db.Model):
@classmethod @classmethod
def all(cls, alert_id): def all(cls, alert_id):
return AlertSubscription.query.join(User).filter( return AlertSubscription.query.join(User).filter(AlertSubscription.alert_id == alert_id)
AlertSubscription.alert_id == alert_id
)
def notify(self, alert, query, user, new_state, app, host): def notify(self, alert, query, user, new_state, app, host):
if self.destination: if self.destination:

View File

@@ -85,11 +85,7 @@ class GFKBase(object):
return self._object return self._object
else: else:
object_class = _gfk_types[self.object_type] object_class = _gfk_types[self.object_type]
self._object = ( self._object = session.query(object_class).filter(object_class.id == self.object_id).first()
session.query(object_class)
.filter(object_class.id == self.object_id)
.first()
)
return self._object return self._object
@object.setter @object.setter

View File

@@ -1,7 +1,7 @@
from sqlalchemy.inspection import inspect from sqlalchemy.inspection import inspect
from sqlalchemy_utils.models import generic_repr from sqlalchemy_utils.models import generic_repr
from .base import GFKBase, db, Column, primary_key, key_type from .base import Column, GFKBase, db, key_type, primary_key
from .types import PseudoJSON from .types import PseudoJSON
@@ -39,9 +39,7 @@ class Change(GFKBase, db.Model):
@classmethod @classmethod
def last_change(cls, obj): def last_change(cls, obj):
return ( return (
cls.query.filter( cls.query.filter(cls.object_id == obj.id, cls.object_type == obj.__class__.__tablename__)
cls.object_id == obj.id, cls.object_type == obj.__class__.__tablename__
)
.order_by(cls.object_version.desc()) .order_by(cls.object_version.desc())
.first() .first()
) )

View File

@@ -1,6 +1,6 @@
from sqlalchemy.event import listens_for from sqlalchemy.event import listens_for
from .base import db, Column from .base import Column, db
class TimestampMixin(object): class TimestampMixin(object):

View File

@@ -3,10 +3,10 @@ from sqlalchemy_utils.models import generic_repr
from redash.settings.organization import settings as org_settings from redash.settings.organization import settings as org_settings
from .base import db, Column, primary_key from .base import Column, db, primary_key
from .mixins import TimestampMixin from .mixins import TimestampMixin
from .types import MutableDict, PseudoJSON from .types import MutableDict, PseudoJSON
from .users import User, Group from .users import Group, User
@generic_repr("id", "name", "slug") @generic_repr("id", "name", "slug")
@@ -36,9 +36,7 @@ class Organization(TimestampMixin, db.Model):
@property @property
def default_group(self): def default_group(self):
return self.groups.filter( return self.groups.filter(Group.name == "default", Group.type == Group.BUILTIN_GROUP).first()
Group.name == "default", Group.type == Group.BUILTIN_GROUP
).first()
@property @property
def google_apps_domains(self): def google_apps_domains(self):
@@ -80,9 +78,7 @@ class Organization(TimestampMixin, db.Model):
@property @property
def admin_group(self): def admin_group(self):
return self.groups.filter( return self.groups.filter(Group.name == "admin", Group.type == Group.BUILTIN_GROUP).first()
Group.name == "admin", Group.type == Group.BUILTIN_GROUP
).first()
def has_user(self, email): def has_user(self, email):
return self.users.filter(User.email == email).count() == 1 return self.users.filter(User.email == email).count() == 1

View File

@@ -1,10 +1,11 @@
import pystache
from functools import partial from functools import partial
from numbers import Number from numbers import Number
from redash.utils import mustache_render, json_loads
from redash.permissions import require_access, view_only import pystache
from funcy import distinct
from dateutil.parser import parse from dateutil.parser import parse
from funcy import distinct
from redash.utils import mustache_render
def _pluck_name_and_value(default_column, row): def _pluck_name_and_value(default_column, row):
@@ -21,9 +22,7 @@ def _load_result(query_id, org):
query = models.Query.get_by_id_and_org(query_id, org) query = models.Query.get_by_id_and_org(query_id, org)
if query.data_source: if query.data_source:
query_result = models.QueryResult.get_by_id_and_org( query_result = models.QueryResult.get_by_id_and_org(query.latest_query_data_id, org)
query.latest_query_data_id, org
)
return query_result.data return query_result.data
else: else:
raise QueryDetachedFromDataSourceError(query_id) raise QueryDetachedFromDataSourceError(query_id)
@@ -38,18 +37,14 @@ def dropdown_values(query_id, org):
def join_parameter_list_values(parameters, schema): def join_parameter_list_values(parameters, schema):
updated_parameters = {} updated_parameters = {}
for (key, value) in parameters.items(): for key, value in parameters.items():
if isinstance(value, list): if isinstance(value, list):
definition = next( definition = next((definition for definition in schema if definition["name"] == key), {})
(definition for definition in schema if definition["name"] == key), {}
)
multi_values_options = definition.get("multiValuesOptions", {}) multi_values_options = definition.get("multiValuesOptions", {})
separator = str(multi_values_options.get("separator", ",")) separator = str(multi_values_options.get("separator", ","))
prefix = str(multi_values_options.get("prefix", "")) prefix = str(multi_values_options.get("prefix", ""))
suffix = str(multi_values_options.get("suffix", "")) suffix = str(multi_values_options.get("suffix", ""))
updated_parameters[key] = separator.join( updated_parameters[key] = separator.join([prefix + v + suffix for v in value])
[prefix + v + suffix for v in value]
)
else: else:
updated_parameters[key] = value updated_parameters[key] = value
return updated_parameters return updated_parameters
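A worked example of the joining above, with hypothetical multi-select options:

schema = [{"name": "status", "multiValuesOptions": {"separator": ",", "prefix": "'", "suffix": "'"}}]
parameters = {"status": ["open", "closed"]}
# join_parameter_list_values(parameters, schema)
# -> {"status": "'open','closed'"}, ready to splice into an IN (...) clause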
@@ -126,16 +121,12 @@ class ParameterizedQuery(object):
self.parameters = {} self.parameters = {}
def apply(self, parameters): def apply(self, parameters):
invalid_parameter_names = [ invalid_parameter_names = [key for (key, value) in parameters.items() if not self._valid(key, value)]
key for (key, value) in parameters.items() if not self._valid(key, value)
]
if invalid_parameter_names: if invalid_parameter_names:
raise InvalidParameterError(invalid_parameter_names) raise InvalidParameterError(invalid_parameter_names)
else: else:
self.parameters.update(parameters) self.parameters.update(parameters)
self.query = mustache_render( self.query = mustache_render(self.template, join_parameter_list_values(parameters, self.schema))
self.template, join_parameter_list_values(parameters, self.schema)
)
return self return self
@@ -161,9 +152,7 @@ class ParameterizedQuery(object):
validators = { validators = {
"text": lambda value: isinstance(value, str), "text": lambda value: isinstance(value, str),
"number": _is_number, "number": _is_number,
"enum": lambda value: _is_value_within_options( "enum": lambda value: _is_value_within_options(value, enum_options, allow_multiple_values),
value, enum_options, allow_multiple_values
),
"query": lambda value: _is_value_within_options( "query": lambda value: _is_value_within_options(
value, value,
[v["value"] for v in dropdown_values(query_id, self.org)], [v["value"] for v in dropdown_values(query_id, self.org)],
@@ -199,9 +188,7 @@ class ParameterizedQuery(object):
class InvalidParameterError(Exception): class InvalidParameterError(Exception):
def __init__(self, parameters): def __init__(self, parameters):
parameter_names = ", ".join(parameters) parameter_names = ", ".join(parameters)
message = "The following parameter values are incompatible with their definitions: {}".format( message = "The following parameter values are incompatible with their definitions: {}".format(parameter_names)
parameter_names
)
super(InvalidParameterError, self).__init__(message) super(InvalidParameterError, self).__init__(message)

View File

@@ -1,10 +1,9 @@
import pytz
from sqlalchemy.types import TypeDecorator
from sqlalchemy.ext.indexable import index_property
from sqlalchemy.ext.mutable import Mutable
from sqlalchemy_utils import EncryptedType
from sqlalchemy import cast from sqlalchemy import cast
from sqlalchemy.dialects.postgresql import JSON from sqlalchemy.dialects.postgresql import JSON
from sqlalchemy.ext.indexable import index_property
from sqlalchemy.ext.mutable import Mutable
from sqlalchemy.types import TypeDecorator
from sqlalchemy_utils import EncryptedType
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
from redash.utils.configuration import ConfigurationContainer from redash.utils.configuration import ConfigurationContainer
@@ -24,9 +23,7 @@ class Configuration(TypeDecorator):
class EncryptedConfiguration(EncryptedType): class EncryptedConfiguration(EncryptedType):
def process_bind_param(self, value, dialect): def process_bind_param(self, value, dialect):
return super(EncryptedConfiguration, self).process_bind_param( return super(EncryptedConfiguration, self).process_bind_param(value.to_json(), dialect)
value.to_json(), dialect
)
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
return ConfigurationContainer.from_json( return ConfigurationContainer.from_json(
@@ -118,9 +115,11 @@ class pseudo_json_cast_property(index_property):
entity attribute as the specified cast type. Useful entity attribute as the specified cast type. Useful
for PseudoJSON columns for easier querying/filtering. for PseudoJSON columns for easier querying/filtering.
""" """
def __init__(self, cast_type, *args, **kwargs): def __init__(self, cast_type, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.cast_type = cast_type self.cast_type = cast_type
def expr(self, model): def expr(self, model):
expr = cast(getattr(model, self.attr_name), JSON)[self.index] expr = cast(getattr(model, self.attr_name), JSON)[self.index]
return expr.astext.cast(self.cast_type) return expr.astext.cast(self.cast_type)
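A hedged usage sketch (the model and key are hypothetical, with names assumed from this module): the property exposes one key of a PseudoJSON column as a typed attribute that also works in SQL filters.

class Example(db.Model):  # hypothetical model; assumes db, Column, MutableDict, PseudoJSON, primary_key are imported
    id = primary_key("Example")
    details = Column(MutableDict.as_mutable(PseudoJSON), default={})
    # expose details["retries"] as an Integer in both Python and SQL contexts
    retries = pseudo_json_cast_property(db.Integer, "details", "retries", default=0)

# enables typed filtering, e.g. Example.query.filter(Example.retries > 3)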

View File

@@ -76,9 +76,7 @@ class PermissionsCheckMixin(object):
@generic_repr("id", "name", "email") @generic_repr("id", "name", "email")
class User( class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin):
TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin
):
id = primary_key("User") id = primary_key("User")
org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id")) org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id"))
org = db.relationship("Organization", backref=db.backref("users", lazy="dynamic")) org = db.relationship("Organization", backref=db.backref("users", lazy="dynamic"))
@@ -99,18 +97,10 @@ class User(
server_default="{}", server_default="{}",
default={}, default={},
) )
active_at = json_cast_property( active_at = json_cast_property(db.DateTime(True), "details", "active_at", default=None)
db.DateTime(True), "details", "active_at", default=None _profile_image_url = json_cast_property(db.Text(), "details", "profile_image_url", default=None)
) is_invitation_pending = json_cast_property(db.Boolean(True), "details", "is_invitation_pending", default=False)
_profile_image_url = json_cast_property( is_email_verified = json_cast_property(db.Boolean(True), "details", "is_email_verified", default=True)
db.Text(), "details", "profile_image_url", default=None
)
is_invitation_pending = json_cast_property(
db.Boolean(True), "details", "is_invitation_pending", default=False
)
is_email_verified = json_cast_property(
db.Boolean(True), "details", "is_email_verified", default=True
)
__tablename__ = "users" __tablename__ = "users"
__table_args__ = (db.Index("users_org_id_email", "org_id", "email", unique=True),) __table_args__ = (db.Index("users_org_id_email", "org_id", "email", unique=True),)
@@ -182,14 +172,7 @@ class User(
@property @property
def permissions(self): def permissions(self):
# TODO: this should be cached. # TODO: this should be cached.
return list( return list(itertools.chain(*[g.permissions for g in Group.query.filter(Group.id.in_(self.group_ids))]))
itertools.chain(
*[
g.permissions
for g in Group.query.filter(Group.id.in_(self.group_ids))
]
)
)
@classmethod @classmethod
def get_by_org(cls, org): def get_by_org(cls, org):
@@ -227,9 +210,7 @@ class User(
if pending: if pending:
return base_query.filter(cls.is_invitation_pending.is_(True)) return base_query.filter(cls.is_invitation_pending.is_(True))
else: else:
return base_query.filter( return base_query.filter(cls.is_invitation_pending.isnot(True)) # check for both `false`/`null`
cls.is_invitation_pending.isnot(True)
) # check for both `false`/`null`
@classmethod @classmethod
def find_by_email(cls, email): def find_by_email(cls, email):
@@ -252,9 +233,7 @@ class User(
return AccessPermission.exists(obj, access_type, grantee=self) return AccessPermission.exists(obj, access_type, grantee=self)
def get_id(self): def get_id(self):
identity = hashlib.md5( identity = hashlib.md5("{},{}".format(self.email, self.password_hash).encode()).hexdigest()
"{},{}".format(self.email, self.password_hash).encode()
).hexdigest()
return "{0}-{1}".format(self.id, identity) return "{0}-{1}".format(self.id, identity)
@@ -279,9 +258,7 @@ class Group(db.Model, BelongsToOrgMixin):
REGULAR_GROUP = "regular" REGULAR_GROUP = "regular"
id = primary_key("Group") id = primary_key("Group")
data_sources = db.relationship( data_sources = db.relationship("DataSourceGroup", back_populates="group", cascade="all")
"DataSourceGroup", back_populates="group", cascade="all"
)
org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id")) org_id = Column(key_type("Organization"), db.ForeignKey("organizations.id"))
org = db.relationship("Organization", back_populates="groups") org = db.relationship("Organization", back_populates="groups")
type = Column(db.String(255), default=REGULAR_GROUP) type = Column(db.String(255), default=REGULAR_GROUP)
@@ -317,9 +294,7 @@ class Group(db.Model, BelongsToOrgMixin):
return list(result) return list(result)
@generic_repr( @generic_repr("id", "object_type", "object_id", "access_type", "grantor_id", "grantee_id")
"id", "object_type", "object_id", "access_type", "grantor_id", "grantee_id"
)
class AccessPermission(GFKBase, db.Model): class AccessPermission(GFKBase, db.Model):
id = primary_key("AccessPermission") id = primary_key("AccessPermission")
# 'object' defined in GFKBase # 'object' defined in GFKBase
@@ -368,9 +343,7 @@ class AccessPermission(GFKBase, db.Model):
@classmethod @classmethod
def _query(cls, obj, access_type=None, grantee=None, grantor=None): def _query(cls, obj, access_type=None, grantee=None, grantor=None):
q = cls.query.filter( q = cls.query.filter(cls.object_id == obj.id, cls.object_type == obj.__tablename__)
cls.object_id == obj.id, cls.object_type == obj.__tablename__
)
if access_type: if access_type:
q = q.filter(AccessPermission.access_type == access_type) q = q.filter(AccessPermission.access_type == access_type)

View File

@@ -1,14 +1,11 @@
from __future__ import absolute_import
import itertools
from funcy import flatten from funcy import flatten
from sqlalchemy import union_all
from redash import redis_connection, rq_redis_connection, __version__, settings
from redash.models import db, DataSource, Query, QueryResult, Dashboard, Widget
from redash.utils import json_loads
from rq import Queue, Worker from rq import Queue, Worker
from rq.job import Job from rq.job import Job
from rq.registry import StartedJobRegistry from rq.registry import StartedJobRegistry
from redash import __version__, redis_connection, rq_redis_connection, settings
from redash.models import Dashboard, Query, QueryResult, Widget, db
def get_redis_status(): def get_redis_status():
info = redis_connection.info() info = redis_connection.info()

View File

@@ -92,9 +92,7 @@ def require_super_admin(fn):
def has_permission_or_owner(permission, object_owner_id): def has_permission_or_owner(permission, object_owner_id):
return int(object_owner_id) == current_user.id or current_user.has_permission( return int(object_owner_id) == current_user.id or current_user.has_permission(permission)
permission
)
def is_admin_or_owner(object_owner_id): def is_admin_or_owner(object_owner_id):

View File

@@ -1,22 +1,23 @@
import logging
from contextlib import ExitStack
from dateutil import parser
from functools import wraps
import socket
import ipaddress import ipaddress
import logging
import socket
from contextlib import ExitStack
from functools import wraps
from urllib.parse import urlparse from urllib.parse import urlparse
import sqlparse
from dateutil import parser
from rq.timeouts import JobTimeoutException
from six import text_type from six import text_type
from sshtunnel import open_tunnel from sshtunnel import open_tunnel
from redash import settings, utils from redash import settings, utils
from redash.utils import json_loads from redash.utils import json_loads
from rq.timeouts import JobTimeoutException from redash.utils.requests_session import (
UnacceptableAddressException,
from redash.utils.requests_session import requests_or_advocate, requests_session, UnacceptableAddressException requests_or_advocate,
requests_session,
)
import sqlparse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -47,9 +48,8 @@ TYPE_STRING = "string"
TYPE_DATETIME = "datetime" TYPE_DATETIME = "datetime"
TYPE_DATE = "date" TYPE_DATE = "date"
SUPPORTED_COLUMN_TYPES = set( SUPPORTED_COLUMN_TYPES = set([TYPE_INTEGER, TYPE_FLOAT, TYPE_BOOLEAN, TYPE_STRING, TYPE_DATETIME, TYPE_DATE])
[TYPE_INTEGER, TYPE_FLOAT, TYPE_BOOLEAN, TYPE_STRING, TYPE_DATETIME, TYPE_DATE]
)
def split_sql_statements(query): def split_sql_statements(query):
def strip_trailing_comments(stmt): def strip_trailing_comments(stmt):
@@ -57,7 +57,7 @@ def split_sql_statements(query):
while idx >= 0: while idx >= 0:
tok = stmt.tokens[idx] tok = stmt.tokens[idx]
if tok.is_whitespace or sqlparse.utils.imt(tok, i=sqlparse.sql.Comment, t=sqlparse.tokens.Comment): if tok.is_whitespace or sqlparse.utils.imt(tok, i=sqlparse.sql.Comment, t=sqlparse.tokens.Comment):
stmt.tokens[idx] = sqlparse.sql.Token(sqlparse.tokens.Whitespace, ' ') stmt.tokens[idx] = sqlparse.sql.Token(sqlparse.tokens.Whitespace, " ")
else: else:
break break
idx -= 1 idx -= 1
@@ -70,7 +70,7 @@ def split_sql_statements(query):
# we expect that trailing comments already are removed # we expect that trailing comments already are removed
if not tok.is_whitespace: if not tok.is_whitespace:
if sqlparse.utils.imt(tok, t=sqlparse.tokens.Punctuation) and tok.value == ";": if sqlparse.utils.imt(tok, t=sqlparse.tokens.Punctuation) and tok.value == ";":
stmt.tokens[idx] = sqlparse.sql.Token(sqlparse.tokens.Whitespace, ' ') stmt.tokens[idx] = sqlparse.sql.Token(sqlparse.tokens.Whitespace, " ")
break break
idx -= 1 idx -= 1
return stmt return stmt
@@ -101,12 +101,14 @@ def split_sql_statements(query):
def combine_sql_statements(queries): def combine_sql_statements(queries):
return ";\n".join(queries) return ";\n".join(queries)
def find_last_keyword_idx(parsed_query): def find_last_keyword_idx(parsed_query):
for i in reversed(range(len(parsed_query.tokens))): for i in reversed(range(len(parsed_query.tokens))):
if parsed_query.tokens[i].ttype in sqlparse.tokens.Keyword: if parsed_query.tokens[i].ttype in sqlparse.tokens.Keyword:
return i return i
return -1 return -1
class InterruptException(Exception): class InterruptException(Exception):
pass pass
@@ -120,7 +122,7 @@ class BaseQueryRunner(object):
should_annotate_query = True should_annotate_query = True
noop_query = None noop_query = None
limit_query = " LIMIT 1000" limit_query = " LIMIT 1000"
limit_keywords = [ "LIMIT", "OFFSET"] limit_keywords = ["LIMIT", "OFFSET"]
def __init__(self, configuration): def __init__(self, configuration):
self.syntax = "sql" self.syntax = "sql"
@@ -225,9 +227,7 @@ class BaseQueryRunner(object):
duplicates_counter += 1 duplicates_counter += 1
column_names.append(column_name) column_names.append(column_name)
new_columns.append( new_columns.append({"name": column_name, "friendly_name": column_name, "type": col[1]})
{"name": column_name, "friendly_name": column_name, "type": col[1]}
)
return new_columns return new_columns
@@ -306,12 +306,11 @@ class BaseSQLQueryRunner(BaseQueryRunner):
limit_tokens = sqlparse.parse(self.limit_query)[0].tokens limit_tokens = sqlparse.parse(self.limit_query)[0].tokens
length = len(parsed_query.tokens) length = len(parsed_query.tokens)
if parsed_query.tokens[length - 1].ttype == sqlparse.tokens.Punctuation: if parsed_query.tokens[length - 1].ttype == sqlparse.tokens.Punctuation:
parsed_query.tokens[length - 1:length - 1] = limit_tokens parsed_query.tokens[length - 1 : length - 1] = limit_tokens
else: else:
parsed_query.tokens += limit_tokens parsed_query.tokens += limit_tokens
return str(parsed_query) return str(parsed_query)
def apply_auto_limit(self, query_text, should_apply_auto_limit): def apply_auto_limit(self, query_text, should_apply_auto_limit):
if should_apply_auto_limit: if should_apply_auto_limit:
queries = split_sql_statements(query_text) queries = split_sql_statements(query_text)
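For context, a runnable sketch of the token splice shown above, using sqlparse directly (the input statement is illustrative):

import sqlparse

parsed = sqlparse.parse("SELECT * FROM events;")[0]
limit_tokens = sqlparse.parse(" LIMIT 1000")[0].tokens
last = len(parsed.tokens) - 1
if parsed.tokens[last].ttype == sqlparse.tokens.Punctuation:
    parsed.tokens[last:last] = limit_tokens  # splice before the trailing ";"
else:
    parsed.tokens += limit_tokens
print(parsed)  # SELECT * FROM events LIMIT 1000;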
@@ -367,7 +366,6 @@ class BaseHTTPQueryRunner(BaseQueryRunner):
return None return None
def get_response(self, url, auth=None, http_method="get", **kwargs): def get_response(self, url, auth=None, http_method="get", **kwargs):
# Get authentication values if not given # Get authentication values if not given
if auth is None: if auth is None:
auth = self.get_auth() auth = self.get_auth()
@@ -389,9 +387,8 @@ class BaseHTTPQueryRunner(BaseQueryRunner):
except requests_or_advocate.HTTPError as exc: except requests_or_advocate.HTTPError as exc:
logger.exception(exc) logger.exception(exc)
error = "Failed to execute query. " "Return Code: {} Reason: {}".format( error = "Failed to execute query. "
response.status_code, response.text f"Return Code: {response.status_code} Reason: {response.text}"
)
except UnacceptableAddressException as exc: except UnacceptableAddressException as exc:
logger.exception(exc) logger.exception(exc)
error = "Can't query private addresses." error = "Can't query private addresses."
@@ -491,9 +488,7 @@ def with_ssh_tunnel(query_runner, details):
try: try:
remote_host, remote_port = query_runner.host, query_runner.port remote_host, remote_port = query_runner.host, query_runner.port
except NotImplementedError: except NotImplementedError:
raise NotImplementedError( raise NotImplementedError("SSH tunneling is not implemented for this query runner yet.")
"SSH tunneling is not implemented for this query runner yet."
)
stack = ExitStack() stack = ExitStack()
try: try:
@@ -503,11 +498,7 @@ def with_ssh_tunnel(query_runner, details):
"ssh_username": details["ssh_username"], "ssh_username": details["ssh_username"],
**settings.dynamic_settings.ssh_tunnel_auth(), **settings.dynamic_settings.ssh_tunnel_auth(),
} }
server = stack.enter_context( server = stack.enter_context(open_tunnel(bastion_address, remote_bind_address=remote_address, **auth))
open_tunnel(
bastion_address, remote_bind_address=remote_address, **auth
)
)
except Exception as error: except Exception as error:
raise type(error)("SSH tunnel: {}".format(str(error))) raise type(error)("SSH tunnel: {}".format(str(error)))
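The ExitStack above keeps the tunnel alive beyond the statement that opened it. A minimal sketch of the pattern, assuming the sshtunnel package and placeholder host names:

from contextlib import ExitStack

from sshtunnel import open_tunnel

stack = ExitStack()
server = stack.enter_context(
    open_tunnel(
        ("bastion.example.com", 22),  # placeholder bastion address
        ssh_username="redash",  # placeholder credentials
        remote_bind_address=("db.internal", 5432),  # placeholder database host
    )
)
print(server.local_bind_address)  # point the runner's host/port at this
stack.close()  # tears the tunnel down once the query is done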
View File
@@ -1,9 +1,9 @@
from .elasticsearch2 import ElasticSearch2
from . import register from . import register
from .elasticsearch2 import ElasticSearch2
try: try:
from botocore import credentials, session
from requests_aws_sign import AWSV4Sign from requests_aws_sign import AWSV4Sign
from botocore import session, credentials
enabled = True enabled = True
except ImportError: except ImportError:
View File
@@ -1,6 +1,12 @@
import logging import logging
from redash.query_runner import * from redash.query_runner import (
TYPE_BOOLEAN,
TYPE_FLOAT,
TYPE_STRING,
BaseQueryRunner,
register,
)
from redash.utils import json_dumps from redash.utils import json_dumps
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -49,7 +55,7 @@ class Arango(BaseQueryRunner):
@classmethod @classmethod
def enabled(cls): def enabled(cls):
try: try:
import arango import arango # noqa: F401
except ImportError: except ImportError:
return False return False
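Several runners in this commit tag imports like the one above with "# noqa: F401": the import exists only to probe whether an optional driver is installed. The generic shape of that check, as a sketch:

def driver_available(module_name):
    # the imported module is discarded; only ImportError matters
    try:
        __import__(module_name)
    except ImportError:
        return False
    return True

print(driver_available("arango"))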
@@ -60,18 +66,15 @@ class Arango(BaseQueryRunner):
return "arangodb" return "arangodb"
def run_query(self, query, user): def run_query(self, query, user):
client = ArangoClient(hosts='{}:{}'.format(self.configuration["host"], client = ArangoClient(hosts="{}:{}".format(self.configuration["host"], self.configuration.get("port", 8529)))
self.configuration.get("port", 8529))) db = client.db(
db = client.db(self.configuration["dbname"], self.configuration["dbname"], username=self.configuration["user"], password=self.configuration["password"]
username=self.configuration["user"], )
password=self.configuration["password"])
try: try:
cursor = db.aql.execute(query, max_runtime=self.configuration.get("timeout", 0.0)) cursor = db.aql.execute(query, max_runtime=self.configuration.get("timeout", 0.0))
result = [i for i in cursor] result = [i for i in cursor]
column_tuples = [ column_tuples = [(i, TYPE_STRING) for i in result[0].keys()]
(i, TYPE_STRING) for i in result[0].keys()
]
columns = self.fetch_columns(column_tuples) columns = self.fetch_columns(column_tuples)
data = { data = {
"columns": columns, "columns": columns,
View File
@@ -1,23 +1,28 @@
import logging import logging
import os import os
from redash.query_runner import * from redash.query_runner import (
TYPE_BOOLEAN,
TYPE_DATE,
TYPE_DATETIME,
TYPE_FLOAT,
TYPE_INTEGER,
TYPE_STRING,
BaseQueryRunner,
register,
)
from redash.settings import parse_boolean from redash.settings import parse_boolean
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ANNOTATE_QUERY = parse_boolean(os.environ.get("ATHENA_ANNOTATE_QUERY", "true")) ANNOTATE_QUERY = parse_boolean(os.environ.get("ATHENA_ANNOTATE_QUERY", "true"))
SHOW_EXTRA_SETTINGS = parse_boolean( SHOW_EXTRA_SETTINGS = parse_boolean(os.environ.get("ATHENA_SHOW_EXTRA_SETTINGS", "true"))
os.environ.get("ATHENA_SHOW_EXTRA_SETTINGS", "true")
)
ASSUME_ROLE = parse_boolean(os.environ.get("ATHENA_ASSUME_ROLE", "false")) ASSUME_ROLE = parse_boolean(os.environ.get("ATHENA_ASSUME_ROLE", "false"))
OPTIONAL_CREDENTIALS = parse_boolean( OPTIONAL_CREDENTIALS = parse_boolean(os.environ.get("ATHENA_OPTIONAL_CREDENTIALS", "true"))
os.environ.get("ATHENA_OPTIONAL_CREDENTIALS", "true")
)
try: try:
import pyathena
import boto3 import boto3
import pyathena
enabled = True enabled = True
except ImportError: except ImportError:
@@ -180,14 +185,11 @@ class Athena(BaseQueryRunner):
iterator = table_paginator.paginate(DatabaseName=database["Name"]) iterator = table_paginator.paginate(DatabaseName=database["Name"])
for table in iterator.search("TableList[]"): for table in iterator.search("TableList[]"):
table_name = "%s.%s" % (database["Name"], table["Name"]) table_name = "%s.%s" % (database["Name"], table["Name"])
if 'StorageDescriptor' not in table: if "StorageDescriptor" not in table:
logger.warning("Glue table doesn't have StorageDescriptor: %s", table_name) logger.warning("Glue table doesn't have StorageDescriptor: %s", table_name)
continue continue
if table_name not in schema: if table_name not in schema:
column = [ column = [columns["Name"] for columns in table["StorageDescriptor"]["Columns"]]
columns["Name"]
for columns in table["StorageDescriptor"]["Columns"]
]
schema[table_name] = {"name": table_name, "columns": column} schema[table_name] = {"name": table_name, "columns": column}
for partition in table.get("PartitionKeys", []): for partition in table.get("PartitionKeys", []):
schema[table_name]["columns"].append(partition["Name"]) schema[table_name]["columns"].append(partition["Name"])
@@ -225,19 +227,14 @@ class Athena(BaseQueryRunner):
kms_key=self.configuration.get("kms_key", None), kms_key=self.configuration.get("kms_key", None),
work_group=self.configuration.get("work_group", "primary"), work_group=self.configuration.get("work_group", "primary"),
formatter=SimpleFormatter(), formatter=SimpleFormatter(),
**self._get_iam_credentials(user=user) **self._get_iam_credentials(user=user),
).cursor() ).cursor()
try: try:
cursor.execute(query) cursor.execute(query)
column_tuples = [ column_tuples = [(i[0], _TYPE_MAPPINGS.get(i[1], None)) for i in cursor.description]
(i[0], _TYPE_MAPPINGS.get(i[1], None)) for i in cursor.description
]
columns = self.fetch_columns(column_tuples) columns = self.fetch_columns(column_tuples)
rows = [ rows = [dict(zip(([c["name"] for c in columns]), r)) for i, r in enumerate(cursor.fetchall())]
dict(zip(([c["name"] for c in columns]), r))
for i, r in enumerate(cursor.fetchall())
]
qbytes = None qbytes = None
athena_query_id = None athena_query_id = None
try: try:
View File
@@ -1,10 +1,18 @@
from io import StringIO
import logging
import sys
import uuid
import csv import csv
import logging
import uuid
from redash.query_runner import * from redash.query_runner import (
TYPE_DATE,
TYPE_DATETIME,
TYPE_FLOAT,
TYPE_INTEGER,
TYPE_STRING,
BaseQueryRunner,
InterruptException,
JobTimeoutException,
register,
)
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -12,7 +20,7 @@ logger = logging.getLogger(__name__)
try: try:
import atsd_client import atsd_client
from atsd_client.exceptions import SQLException from atsd_client.exceptions import SQLException
from atsd_client.services import SQLService, MetricsService from atsd_client.services import MetricsService, SQLService
enabled = True enabled = True
except ImportError: except ImportError:
View File
@@ -1,18 +1,22 @@
from redash.query_runner import BaseQueryRunner, register
from redash.query_runner import ( from redash.query_runner import (
TYPE_STRING, TYPE_BOOLEAN,
TYPE_DATE, TYPE_DATE,
TYPE_DATETIME, TYPE_DATETIME,
TYPE_INTEGER,
TYPE_FLOAT, TYPE_FLOAT,
TYPE_BOOLEAN, TYPE_INTEGER,
TYPE_STRING,
BaseQueryRunner,
register,
) )
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
try: try:
from azure.kusto.data.request import KustoClient, KustoConnectionStringBuilder, ClientRequestProperties
from azure.kusto.data.exceptions import KustoServiceError from azure.kusto.data.exceptions import KustoServiceError
from azure.kusto.data.request import (
ClientRequestProperties,
KustoClient,
KustoConnectionStringBuilder,
)
enabled = True enabled = True
except ImportError: except ImportError:
@@ -87,7 +91,6 @@ class AzureKusto(BaseQueryRunner):
return "Azure Data Explorer (Kusto)" return "Azure Data Explorer (Kusto)"
def run_query(self, query, user): def run_query(self, query, user):
kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication( kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication(
connection_string=self.configuration["cluster"], connection_string=self.configuration["cluster"],
aad_app_id=self.configuration["azure_ad_client_id"], aad_app_id=self.configuration["azure_ad_client_id"],
@@ -143,9 +146,7 @@ class AzureKusto(BaseQueryRunner):
results = json_loads(results) results = json_loads(results)
schema_as_json = json_loads(results["rows"][0]["DatabaseSchema"]) schema_as_json = json_loads(results["rows"][0]["DatabaseSchema"])
tables_list = schema_as_json["Databases"][self.configuration["database"]][ tables_list = schema_as_json["Databases"][self.configuration["database"]]["Tables"].values()
"Tables"
].values()
schema = {} schema = {}
View File
@@ -1,14 +1,22 @@
import datetime import datetime
import logging import logging
import sys
import time import time
from base64 import b64decode from base64 import b64decode
import httplib2 import httplib2
import requests
from redash import settings from redash import settings
from redash.query_runner import * from redash.query_runner import (
TYPE_BOOLEAN,
TYPE_DATETIME,
TYPE_FLOAT,
TYPE_INTEGER,
TYPE_STRING,
BaseQueryRunner,
InterruptException,
JobTimeoutException,
register,
)
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,7 +24,7 @@ logger = logging.getLogger(__name__)
try: try:
import apiclient.errors import apiclient.errors
from apiclient.discovery import build from apiclient.discovery import build
from apiclient.errors import HttpError from apiclient.errors import HttpError # noqa: F401
from oauth2client.service_account import ServiceAccountCredentials from oauth2client.service_account import ServiceAccountCredentials
enabled = True enabled = True
@@ -52,9 +60,7 @@ def transform_row(row, fields):
for column_index, cell in enumerate(row["f"]): for column_index, cell in enumerate(row["f"]):
field = fields[column_index] field = fields[column_index]
if field.get("mode") == "REPEATED": if field.get("mode") == "REPEATED":
cell_value = [ cell_value = [transform_cell(field["type"], item["v"]) for item in cell["v"]]
transform_cell(field["type"], item["v"]) for item in cell["v"]
]
else: else:
cell_value = transform_cell(field["type"], cell["v"]) cell_value = transform_cell(field["type"], cell["v"])
@@ -64,7 +70,7 @@ def transform_row(row, fields):
def _load_key(filename): def _load_key(filename):
f = file(filename, "rb") f = open(filename, "rb")
try: try:
return f.read() return f.read()
finally: finally:
@@ -180,17 +186,13 @@ class BigQuery(BaseQueryRunner):
job_data["configuration"]["query"]["useLegacySql"] = False job_data["configuration"]["query"]["useLegacySql"] = False
if self.configuration.get("userDefinedFunctionResourceUri"): if self.configuration.get("userDefinedFunctionResourceUri"):
resource_uris = self.configuration["userDefinedFunctionResourceUri"].split( resource_uris = self.configuration["userDefinedFunctionResourceUri"].split(",")
","
)
job_data["configuration"]["query"]["userDefinedFunctionResources"] = [ job_data["configuration"]["query"]["userDefinedFunctionResources"] = [
{"resourceUri": resource_uri} for resource_uri in resource_uris {"resourceUri": resource_uri} for resource_uri in resource_uris
] ]
if "maximumBillingTier" in self.configuration: if "maximumBillingTier" in self.configuration:
job_data["configuration"]["query"][ job_data["configuration"]["query"]["maximumBillingTier"] = self.configuration["maximumBillingTier"]
"maximumBillingTier"
] = self.configuration["maximumBillingTier"]
return job_data return job_data
@@ -233,9 +235,7 @@ class BigQuery(BaseQueryRunner):
{ {
"name": f["name"], "name": f["name"],
"friendly_name": f["name"], "friendly_name": f["name"],
"type": "string" "type": "string" if f.get("mode") == "REPEATED" else types_map.get(f["type"], "string"),
if f.get("mode") == "REPEATED"
else types_map.get(f["type"], "string"),
} }
for f in query_reply["schema"]["fields"] for f in query_reply["schema"]["fields"]
] ]
@@ -273,12 +273,12 @@ class BigQuery(BaseQueryRunner):
datasets = service.datasets().list(projectId=project_id).execute() datasets = service.datasets().list(projectId=project_id).execute()
result.extend(datasets.get("datasets", [])) result.extend(datasets.get("datasets", []))
nextPageToken = datasets.get('nextPageToken', None) nextPageToken = datasets.get("nextPageToken", None)
while nextPageToken is not None: while nextPageToken is not None:
datasets = service.datasets().list(projectId=project_id, pageToken=nextPageToken).execute() datasets = service.datasets().list(projectId=project_id, pageToken=nextPageToken).execute()
result.extend(datasets.get("datasets", [])) result.extend(datasets.get("datasets", []))
nextPageToken = datasets.get('nextPageToken', None) nextPageToken = datasets.get("nextPageToken", None)
return result return result
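The nextPageToken loop above is the usual discovery-client pagination idiom. The same traversal written as a generator, assuming service is the googleapiclient object this runner builds elsewhere:

def iter_datasets(service, project_id):
    # follow nextPageToken until the API stops returning one
    token = None
    while True:
        kwargs = {"projectId": project_id}
        if token:
            kwargs["pageToken"] = token
        response = service.datasets().list(**kwargs).execute()
        yield from response.get("datasets", [])
        token = response.get("nextPageToken")
        if token is None:
            break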
@@ -302,7 +302,7 @@ class BigQuery(BaseQueryRunner):
query = query_base.format(dataset_id=dataset_id) query = query_base.format(dataset_id=dataset_id)
queries.append(query) queries.append(query)
query = '\nUNION ALL\n'.join(queries) query = "\nUNION ALL\n".join(queries)
results, error = self.run_query(query, None) results, error = self.run_query(query, None)
if error is not None: if error is not None:
self._handle_run_query_error(error) self._handle_run_query_error(error)
@@ -325,14 +325,11 @@ class BigQuery(BaseQueryRunner):
try: try:
if "totalMBytesProcessedLimit" in self.configuration: if "totalMBytesProcessedLimit" in self.configuration:
limitMB = self.configuration["totalMBytesProcessedLimit"] limitMB = self.configuration["totalMBytesProcessedLimit"]
processedMB = ( processedMB = self._get_total_bytes_processed(jobs, query) / 1000.0 / 1000.0
self._get_total_bytes_processed(jobs, query) / 1000.0 / 1000.0
)
if limitMB < processedMB: if limitMB < processedMB:
return ( return (
None, None,
"Larger than %d MBytes will be processed (%f MBytes)" "Larger than %d MBytes will be processed (%f MBytes)" % (limitMB, processedMB),
% (limitMB, processedMB),
) )
data = self._get_query_result(jobs, query) data = self._get_query_result(jobs, query)
View File
@@ -1,5 +1,5 @@
import requests
import httplib2 import httplib2
import requests
try: try:
from apiclient.discovery import build from apiclient.discovery import build
@@ -10,6 +10,7 @@ except ImportError:
enabled = False enabled = False
from redash.query_runner import register from redash.query_runner import register
from .big_query import BigQuery from .big_query import BigQuery
@@ -65,9 +66,7 @@ class BigQueryGCE(BigQuery):
).content ).content
def _get_bigquery_service(self): def _get_bigquery_service(self):
credentials = gce.AppAssertionCredentials( credentials = gce.AppAssertionCredentials(scope="https://www.googleapis.com/auth/bigquery")
scope="https://www.googleapis.com/auth/bigquery"
)
http = httplib2.Http() http = httplib2.Http()
http = credentials.authorize(http) http = credentials.authorize(http)
View File
@@ -10,8 +10,8 @@ from redash.utils import JSONEncoder, json_dumps, json_loads
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster
from cassandra.util import sortedset from cassandra.util import sortedset
enabled = True enabled = True
@@ -20,12 +20,10 @@ except ImportError:
def generate_ssl_options_dict(protocol, cert_path=None): def generate_ssl_options_dict(protocol, cert_path=None):
ssl_options = { ssl_options = {"ssl_version": getattr(ssl, protocol)}
'ssl_version': getattr(ssl, protocol)
}
if cert_path is not None: if cert_path is not None:
ssl_options['ca_certs'] = cert_path ssl_options["ca_certs"] = cert_path
ssl_options['cert_reqs'] = ssl.CERT_REQUIRED ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
return ssl_options return ssl_options
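For reference, what the reflowed generate_ssl_options_dict produces for a verifying TLS 1.2 setup (the protocol name and certificate path are illustrative values):

import ssl

ssl_options = {"ssl_version": getattr(ssl, "PROTOCOL_TLSv1_2")}
cert_path = "/path/to/ca.pem"  # illustrative
if cert_path is not None:
    ssl_options["ca_certs"] = cert_path
    ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
print(ssl_options)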
@@ -60,10 +58,7 @@ class Cassandra(BaseQueryRunner):
}, },
"timeout": {"type": "number", "title": "Timeout", "default": 10}, "timeout": {"type": "number", "title": "Timeout", "default": 10},
"useSsl": {"type": "boolean", "title": "Use SSL", "default": False}, "useSsl": {"type": "boolean", "title": "Use SSL", "default": False},
"sslCertificateFile": { "sslCertificateFile": {"type": "string", "title": "SSL Certificate File"},
"type": "string",
"title": "SSL Certificate File"
},
"sslProtocol": { "sslProtocol": {
"type": "string", "type": "string",
"title": "SSL Protocol", "title": "SSL Protocol",
@@ -127,9 +122,7 @@ class Cassandra(BaseQueryRunner):
def run_query(self, query, user): def run_query(self, query, user):
connection = None connection = None
cert_path = self._generate_cert_file() cert_path = self._generate_cert_file()
if self.configuration.get("username", "") and self.configuration.get( if self.configuration.get("username", "") and self.configuration.get("password", ""):
"password", ""
):
auth_provider = PlainTextAuthProvider( auth_provider = PlainTextAuthProvider(
username="{}".format(self.configuration.get("username", "")), username="{}".format(self.configuration.get("username", "")),
password="{}".format(self.configuration.get("password", "")), password="{}".format(self.configuration.get("password", "")),
@@ -169,7 +162,7 @@ class Cassandra(BaseQueryRunner):
def _generate_cert_file(self): def _generate_cert_file(self):
cert_encoded_bytes = self.configuration.get("sslCertificateFile", None) cert_encoded_bytes = self.configuration.get("sslCertificateFile", None)
if cert_encoded_bytes: if cert_encoded_bytes:
with NamedTemporaryFile(mode='w', delete=False) as cert_file: with NamedTemporaryFile(mode="w", delete=False) as cert_file:
cert_bytes = b64decode(cert_encoded_bytes) cert_bytes = b64decode(cert_encoded_bytes)
cert_file.write(cert_bytes.decode("utf-8")) cert_file.write(cert_bytes.decode("utf-8"))
return cert_file.name return cert_file.name
@@ -182,10 +175,7 @@ class Cassandra(BaseQueryRunner):
def _get_ssl_options(self, cert_path): def _get_ssl_options(self, cert_path):
ssl_options = None ssl_options = None
if self.configuration.get("useSsl", False): if self.configuration.get("useSsl", False):
ssl_options = generate_ssl_options_dict( ssl_options = generate_ssl_options_dict(protocol=self.configuration["sslProtocol"], cert_path=cert_path)
protocol=self.configuration["sslProtocol"],
cert_path=cert_path
)
return ssl_options return ssl_options
View File
@@ -5,8 +5,16 @@ from uuid import uuid4
import requests import requests
from redash.query_runner import * from redash.query_runner import (
from redash.query_runner import split_sql_statements TYPE_DATE,
TYPE_DATETIME,
TYPE_FLOAT,
TYPE_INTEGER,
TYPE_STRING,
BaseSQLQueryRunner,
register,
split_sql_statements,
)
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -131,9 +139,7 @@ class ClickHouse(BaseSQLQueryRunner):
return r.json() return r.json()
except requests.RequestException as e: except requests.RequestException as e:
if e.response: if e.response:
details = "({}, Status Code: {})".format( details = "({}, Status Code: {})".format(e.__class__.__name__, e.response.status_code)
e.__class__.__name__, e.response.status_code
)
else: else:
details = "({})".format(e.__class__.__name__) details = "({})".format(e.__class__.__name__)
raise Exception("Connection error to: {} {}.".format(url, details)) raise Exception("Connection error to: {} {}.".format(url, details))
@@ -174,13 +180,9 @@ class ClickHouse(BaseSQLQueryRunner):
if r["type"] in ("Int64", "UInt64", "Nullable(Int64)", "Nullable(UInt64)"): if r["type"] in ("Int64", "UInt64", "Nullable(Int64)", "Nullable(UInt64)"):
columns_int64.append(column_name) columns_int64.append(column_name)
else: else:
columns_totals[column_name] = ( columns_totals[column_name] = "Total" if column_type == TYPE_STRING else None
"Total" if column_type == TYPE_STRING else None
)
columns.append( columns.append({"name": column_name, "friendly_name": column_name, "type": column_type})
{"name": column_name, "friendly_name": column_name, "type": column_type}
)
rows = response.get("data", []) rows = response.get("data", [])
for row in rows: for row in rows:
@@ -215,14 +217,10 @@ class ClickHouse(BaseSQLQueryRunner):
# for the first query # for the first query
session_id = "redash_{}".format(uuid4().hex) session_id = "redash_{}".format(uuid4().hex)
results = self._clickhouse_query( results = self._clickhouse_query(queries[0], session_id, session_check=False)
queries[0], session_id, session_check=False
)
for query in queries[1:]: for query in queries[1:]:
results = self._clickhouse_query( results = self._clickhouse_query(query, session_id, session_check=True)
query, session_id, session_check=True
)
data = json_dumps(results) data = json_dumps(results)
error = None error = None
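Background on the session_id handling above: ClickHouse's HTTP interface scopes SET options and temporary tables to a session, so the first statement creates it (session_check off) and the rest require it to exist. A hand-rolled sketch against a placeholder endpoint:

import uuid

import requests

session_id = "redash_{}".format(uuid.uuid4().hex)
url = "http://localhost:8123"  # placeholder ClickHouse endpoint

for i, statement in enumerate(["SET max_threads = 1", "SELECT 1"]):
    # session_check=1 on every statement after the first
    params = {"session_id": session_id, "session_check": "1" if i > 0 else "0"}
    requests.post(url, params=params, data=statement).raise_for_status()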
View File
@@ -1,15 +1,18 @@
import yaml
import datetime import datetime
import yaml
from redash.query_runner import BaseQueryRunner, register from redash.query_runner import BaseQueryRunner, register
from redash.utils import json_dumps, parse_human_time from redash.utils import json_dumps, parse_human_time
try: try:
import boto3 import boto3
enabled = True enabled = True
except ImportError: except ImportError:
enabled = False enabled = False
def parse_response(results): def parse_response(results):
columns = [ columns = [
{"name": "id", "type": "string"}, {"name": "id", "type": "string"},
View File
@@ -1,13 +1,15 @@
import yaml
import datetime import datetime
import time import time
import yaml
from redash.query_runner import BaseQueryRunner, register from redash.query_runner import BaseQueryRunner, register
from redash.utils import json_dumps, parse_human_time from redash.utils import json_dumps, parse_human_time
try: try:
import boto3 import boto3
from botocore.exceptions import ParamValidationError from botocore.exceptions import ParamValidationError # noqa: F401
enabled = True enabled = True
except ImportError: except ImportError:
enabled = False enabled = False
@@ -118,9 +120,7 @@ class CloudWatchInsights(BaseQueryRunner):
log_groups.append( log_groups.append(
{ {
"name": group_name, "name": group_name,
"columns": [ "columns": [field["name"] for field in fields["logGroupFields"]],
field["name"] for field in fields["logGroupFields"]
],
} }
) )
@@ -139,11 +139,7 @@ class CloudWatchInsights(BaseQueryRunner):
data = parse_response(result) data = parse_response(result)
break break
if result["status"] in ("Failed", "Timeout", "Unknown", "Cancelled"): if result["status"] in ("Failed", "Timeout", "Unknown", "Cancelled"):
raise Exception( raise Exception("CloudWatch Insights Query Execution Status: {}".format(result["status"]))
"CloudWatch Insights Query Execution Status: {}".format(
result["status"]
)
)
elif elapsed > TIMEOUT: elif elapsed > TIMEOUT:
raise Exception("Request exceeded timeout.") raise Exception("Request exceeded timeout.")
View File

@@ -4,17 +4,22 @@ seeAlso: https://documentation.eccenca.com/
seeAlso: https://eccenca.com/ seeAlso: https://eccenca.com/
""" """
import logging
import json import json
import logging
from os import environ from os import environ
from redash.query_runner import BaseQueryRunner from redash.query_runner import BaseQueryRunner
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
from . import register from . import register
try: try:
from cmem.cmempy.queries import SparqlQuery, QueryCatalog, QUERY_STRING
from cmem.cmempy.dp.proxy.graph import get_graphs_list from cmem.cmempy.dp.proxy.graph import get_graphs_list
from cmem.cmempy.queries import ( # noqa: F401
QUERY_STRING,
QueryCatalog,
SparqlQuery,
)
enabled = True enabled = True
except ImportError: except ImportError:
@@ -151,9 +156,7 @@ class CorporateMemoryQueryRunner(BaseQueryRunner):
# type of None means, there is an error in the query # type of None means, there is an error in the query
# so execution is at least tried on endpoint # so execution is at least tried on endpoint
if query_type not in ["SELECT", None]: if query_type not in ["SELECT", None]:
raise ValueError( raise ValueError("Queries of type {} can not be processed by redash.".format(query_type))
"Queries of type {} can not be processed by redash.".format(query_type)
)
self._setup_environment() self._setup_environment()
try: try:
View File
@@ -1,16 +1,21 @@
import datetime import datetime
import logging import logging
from dateutil.parser import parse from redash.query_runner import (
TYPE_BOOLEAN,
from redash.query_runner import * TYPE_DATETIME,
from redash.utils import JSONEncoder, json_dumps, json_loads, parse_human_time TYPE_FLOAT,
import json TYPE_INTEGER,
TYPE_STRING,
BaseQueryRunner,
register,
)
from redash.utils import json_dumps
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
import httplib2 # noqa: F401
import requests import requests
import httplib2
except ImportError as e: except ImportError as e:
logger.error("Failed to import: " + str(e)) logger.error("Failed to import: " + str(e))
@@ -48,9 +53,7 @@ def parse_results(results):
{ {
"name": column_name, "name": column_name,
"friendly_name": column_name, "friendly_name": column_name,
"type": TYPES_MAP.get( "type": TYPES_MAP.get(type(row[key][inner_key]), TYPE_STRING),
type(row[key][inner_key]), TYPE_STRING
),
} }
) )
@@ -104,7 +107,7 @@ class Couchbase(BaseQueryRunner):
return True return True
def test_connection(self): def test_connection(self):
result = self.call_service(self.noop_query, "") self.call_service(self.noop_query, "")
def get_buckets(self, query, name_param): def get_buckets(self, query, name_param):
defaultColumns = ["meta().id"] defaultColumns = ["meta().id"]
@@ -117,7 +120,6 @@ class Couchbase(BaseQueryRunner):
return list(schema.values()) return list(schema.values())
def get_schema(self, get_stats=False): def get_schema(self, get_stats=False):
try: try:
# Try fetch from Analytics # Try fetch from Analytics
return self.get_buckets( return self.get_buckets(
View File
@@ -1,17 +1,21 @@
import logging
import yaml
import io import io
import logging
from redash.utils.requests_session import requests_or_advocate, UnacceptableAddressException import yaml
from redash.query_runner import * from redash.query_runner import BaseQueryRunner, NotSupported, register
from redash.utils import json_dumps from redash.utils import json_dumps
from redash.utils.requests_session import (
UnacceptableAddressException,
requests_or_advocate,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
import pandas as pd
import numpy as np import numpy as np
import pandas as pd
enabled = True enabled = True
except ImportError: except ImportError:
enabled = False enabled = False
@@ -31,8 +35,8 @@ class CSV(BaseQueryRunner):
@classmethod @classmethod
def configuration_schema(cls): def configuration_schema(cls):
return { return {
'type': 'object', "type": "object",
'properties': {}, "properties": {},
} }
def __init__(self, configuration): def __init__(self, configuration):
@@ -48,37 +52,49 @@ class CSV(BaseQueryRunner):
args = {} args = {}
try: try:
args = yaml.safe_load(query) args = yaml.safe_load(query)
path = args['url'] path = args["url"]
args.pop('url', None) args.pop("url", None)
ua = args['user-agent'] ua = args["user-agent"]
args.pop('user-agent', None) args.pop("user-agent", None)
except: except Exception:
pass pass
try: try:
response = requests_or_advocate.get(url=path, headers={"User-agent": ua}) response = requests_or_advocate.get(url=path, headers={"User-agent": ua})
workbook = pd.read_csv(io.BytesIO(response.content),sep=",", **args) workbook = pd.read_csv(io.BytesIO(response.content), sep=",", **args)
df = workbook.copy() df = workbook.copy()
data = {'columns': [], 'rows': []} data = {"columns": [], "rows": []}
conversions = [ conversions = [
{'pandas_type': np.integer, 'redash_type': 'integer',}, {
{'pandas_type': np.inexact, 'redash_type': 'float',}, "pandas_type": np.integer,
{'pandas_type': np.datetime64, 'redash_type': 'datetime', 'to_redash': lambda x: x.strftime('%Y-%m-%d %H:%M:%S')}, "redash_type": "integer",
{'pandas_type': np.bool_, 'redash_type': 'boolean'}, },
{'pandas_type': np.object, 'redash_type': 'string'} {
"pandas_type": np.inexact,
"redash_type": "float",
},
{
"pandas_type": np.datetime64,
"redash_type": "datetime",
"to_redash": lambda x: x.strftime("%Y-%m-%d %H:%M:%S"),
},
{"pandas_type": np.bool_, "redash_type": "boolean"},
{"pandas_type": np.object, "redash_type": "string"},
] ]
labels = [] labels = []
for dtype, label in zip(df.dtypes, df.columns): for dtype, label in zip(df.dtypes, df.columns):
for conversion in conversions: for conversion in conversions:
if issubclass(dtype.type, conversion['pandas_type']): if issubclass(dtype.type, conversion["pandas_type"]):
data['columns'].append({'name': label, 'friendly_name': label, 'type': conversion['redash_type']}) data["columns"].append(
{"name": label, "friendly_name": label, "type": conversion["redash_type"]}
)
labels.append(label) labels.append(label)
func = conversion.get('to_redash') func = conversion.get("to_redash")
if func: if func:
df[label] = df[label].apply(func) df[label] = df[label].apply(func)
break break
data['rows'] = df[labels].replace({np.nan: None}).to_dict(orient='records') data["rows"] = df[labels].replace({np.nan: None}).to_dict(orient="records")
json_data = json_dumps(data) json_data = json_dumps(data)
error = None error = None
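The conversions table above maps pandas dtypes onto Redash column types, first match wins. A reduced, runnable version of the dispatch (np.object was still valid when this commit landed; plain object is its modern spelling):

import numpy as np
import pandas as pd

df = pd.DataFrame({"n": [1, 2], "x": [1.5, 2.5], "s": ["a", "b"]})
conversions = [
    (np.integer, "integer"),
    (np.inexact, "float"),
    (np.datetime64, "datetime"),
    (np.bool_, "boolean"),
    (object, "string"),  # fallback: everything subclasses object
]
for dtype, label in zip(df.dtypes, df.columns):
    redash_type = next(r for t, r in conversions if issubclass(dtype.type, t))
    print(label, redash_type)  # n integer / x float / s string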
@@ -97,4 +113,5 @@ class CSV(BaseQueryRunner):
def get_schema(self): def get_schema(self):
raise NotSupported() raise NotSupported()
register(CSV) register(CSV)
View File
@@ -1,13 +1,21 @@
try: try:
from databend_sqlalchemy import connector
import re import re
from databend_sqlalchemy import connector
enabled = True enabled = True
except ImportError: except ImportError:
enabled = False enabled = False
from redash.query_runner import BaseQueryRunner, register from redash.query_runner import (
from redash.query_runner import TYPE_STRING, TYPE_INTEGER, TYPE_BOOLEAN, TYPE_FLOAT, TYPE_DATETIME, TYPE_DATE TYPE_DATE,
TYPE_DATETIME,
TYPE_FLOAT,
TYPE_INTEGER,
TYPE_STRING,
BaseQueryRunner,
register,
)
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
@@ -72,12 +80,8 @@ class Databend(BaseQueryRunner):
try: try:
cursor.execute(query) cursor.execute(query)
columns = self.fetch_columns( columns = self.fetch_columns([(i[0], self._define_column_type(i[1])) for i in cursor.description])
[(i[0], self._define_column_type(i[1])) for i in cursor.description] rows = [dict(zip((column["name"] for column in columns), row)) for row in cursor]
)
rows = [
dict(zip((column["name"] for column in columns), row)) for row in cursor
]
data = {"columns": columns, "rows": rows} data = {"columns": columns, "rows": rows}
error = None error = None
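The dict(zip(...)) line above is the row-building idiom this commit collapses to one line across several DB-API runners. In isolation:

columns = [{"name": "id"}, {"name": "value"}]
fetched = [(1, "a"), (2, "b")]  # stand-in for rows off a DB-API cursor
rows = [dict(zip((column["name"] for column in columns), row)) for row in fetched]
print(rows)  # [{'id': 1, 'value': 'a'}, {'id': 2, 'value': 'b'}]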
View File
@@ -1,22 +1,22 @@
import datetime import datetime
import logging import logging
import os import os
import sqlparse
from redash import __version__, statsd_client
from redash.query_runner import ( from redash.query_runner import (
NotSupported,
register,
BaseSQLQueryRunner,
TYPE_STRING,
TYPE_BOOLEAN, TYPE_BOOLEAN,
TYPE_DATE, TYPE_DATE,
TYPE_DATETIME, TYPE_DATETIME,
TYPE_INTEGER,
TYPE_FLOAT, TYPE_FLOAT,
TYPE_INTEGER,
TYPE_STRING,
BaseSQLQueryRunner,
NotSupported,
register,
split_sql_statements,
) )
from redash.settings import cast_int_or_default from redash.settings import cast_int_or_default
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
from redash.query_runner import split_sql_statements
from redash import __version__, settings, statsd_client
try: try:
import pyodbc import pyodbc
@@ -38,6 +38,7 @@ ROW_LIMIT = cast_int_or_default(os.environ.get("DATABRICKS_ROW_LIMIT"), 20000)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _build_odbc_connection_string(**kwargs): def _build_odbc_connection_string(**kwargs):
return ";".join([f"{k}={v}" for k, v in kwargs.items()]) return ";".join([f"{k}={v}" for k, v in kwargs.items()])
@@ -104,24 +105,13 @@ class Databricks(BaseSQLQueryRunner):
if cursor.description is not None: if cursor.description is not None:
result_set = cursor.fetchmany(ROW_LIMIT) result_set = cursor.fetchmany(ROW_LIMIT)
columns = self.fetch_columns( columns = self.fetch_columns([(i[0], TYPES_MAP.get(i[1], TYPE_STRING)) for i in cursor.description])
[
(i[0], TYPES_MAP.get(i[1], TYPE_STRING))
for i in cursor.description
]
)
rows = [ rows = [dict(zip((column["name"] for column in columns), row)) for row in result_set]
dict(zip((column["name"] for column in columns), row))
for row in result_set
]
data = {"columns": columns, "rows": rows} data = {"columns": columns, "rows": rows}
if ( if len(result_set) >= ROW_LIMIT and cursor.fetchone() is not None:
len(result_set) >= ROW_LIMIT
and cursor.fetchone() is not None
):
logger.warning("Truncated result set.") logger.warning("Truncated result set.")
statsd_client.incr("redash.query_runner.databricks.truncated") statsd_client.incr("redash.query_runner.databricks.truncated")
data["truncated"] = True data["truncated"] = True
View File
@@ -1,12 +1,23 @@
import logging import logging
from redash.query_runner import * from redash.query_runner import (
TYPE_DATE,
TYPE_DATETIME,
TYPE_FLOAT,
TYPE_INTEGER,
TYPE_STRING,
BaseSQLQueryRunner,
InterruptException,
JobTimeoutException,
register,
)
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
import select import select
import ibm_db_dbi import ibm_db_dbi
types_map = { types_map = {
@@ -55,7 +66,7 @@ class DB2(BaseSQLQueryRunner):
@classmethod @classmethod
def enabled(cls): def enabled(cls):
try: try:
import ibm_db import ibm_db # noqa: F401
except ImportError: except ImportError:
return False return False
@@ -114,13 +125,8 @@ class DB2(BaseSQLQueryRunner):
cursor.execute(query) cursor.execute(query)
if cursor.description is not None: if cursor.description is not None:
columns = self.fetch_columns( columns = self.fetch_columns([(i[0], types_map.get(i[1], None)) for i in cursor.description])
[(i[0], types_map.get(i[1], None)) for i in cursor.description] rows = [dict(zip((column["name"] for column in columns), row)) for row in cursor]
)
rows = [
dict(zip((column["name"] for column in columns), row))
for row in cursor
]
data = {"columns": columns, "rows": rows} data = {"columns": columns, "rows": rows}
error = None error = None
@@ -128,7 +134,7 @@ class DB2(BaseSQLQueryRunner):
else: else:
error = "Query completed but it returned no data." error = "Query completed but it returned no data."
json_data = None json_data = None
except (select.error, OSError) as e: except (select.error, OSError):
error = "Query interrupted. Please retry." error = "Query interrupted. Please retry."
json_data = None json_data = None
except ibm_db_dbi.DatabaseError as e: except ibm_db_dbi.DatabaseError as e:
View File
@@ -81,7 +81,6 @@ class Dgraph(BaseQueryRunner):
client_stub.close() client_stub.close()
def run_query(self, query, user): def run_query(self, query, user):
json_data = None json_data = None
error = None error = None
@@ -106,9 +105,7 @@ class Dgraph(BaseQueryRunner):
header = list(set(header)) header = list(set(header))
columns = [ columns = [{"name": c, "friendly_name": c, "type": "string"} for c in header]
{"name": c, "friendly_name": c, "type": "string"} for c in header
]
# finally, assemble both the columns and data # finally, assemble both the columns and data
data = {"columns": columns, "rows": processed_data} data = {"columns": columns, "rows": processed_data}
View File
@@ -1,17 +1,17 @@
import os
import logging import logging
import os
import re import re
from dateutil import parser from dateutil import parser
from redash.query_runner import ( from redash.query_runner import (
BaseHTTPQueryRunner,
register,
TYPE_DATETIME,
TYPE_INTEGER,
TYPE_FLOAT,
TYPE_BOOLEAN, TYPE_BOOLEAN,
TYPE_DATETIME,
TYPE_FLOAT,
TYPE_INTEGER,
BaseHTTPQueryRunner,
guess_type, guess_type,
register,
) )
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
@@ -51,9 +51,7 @@ def parse_response(data):
types = {} types = {}
for c in cols: for c in cols:
columns.append( columns.append({"name": c, "type": guess_type(first_row[c]), "friendly_name": c})
{"name": c, "type": guess_type(first_row[c]), "friendly_name": c}
)
for col in columns: for col in columns:
types[col["name"]] = col["type"] types[col["name"]] = col["type"]
@@ -96,9 +94,7 @@ class Drill(BaseHTTPQueryRunner):
payload = {"queryType": "SQL", "query": query} payload = {"queryType": "SQL", "query": query}
response, error = self.get_response( response, error = self.get_response(drill_url, http_method="post", json=payload)
drill_url, http_method="post", json=payload
)
if error is not None: if error is not None:
return None, error return None, error
@@ -107,7 +103,6 @@ class Drill(BaseHTTPQueryRunner):
return json_dumps(results), None return json_dumps(results), None
def get_schema(self, get_stats=False): def get_schema(self, get_stats=False):
query = """ query = """
SELECT DISTINCT SELECT DISTINCT
TABLE_SCHEMA, TABLE_SCHEMA,
View File
@@ -5,8 +5,13 @@ try:
except ImportError: except ImportError:
enabled = False enabled = False
from redash.query_runner import BaseQueryRunner, register from redash.query_runner import (
from redash.query_runner import TYPE_STRING, TYPE_INTEGER, TYPE_BOOLEAN TYPE_BOOLEAN,
TYPE_INTEGER,
TYPE_STRING,
BaseQueryRunner,
register,
)
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
TYPES_MAP = {1: TYPE_STRING, 2: TYPE_INTEGER, 3: TYPE_BOOLEAN} TYPES_MAP = {1: TYPE_STRING, 2: TYPE_INTEGER, 3: TYPE_BOOLEAN}
@@ -49,12 +54,8 @@ class Druid(BaseQueryRunner):
try: try:
cursor.execute(query) cursor.execute(query)
columns = self.fetch_columns( columns = self.fetch_columns([(i[0], TYPES_MAP.get(i[1], None)) for i in cursor.description])
[(i[0], TYPES_MAP.get(i[1], None)) for i in cursor.description] rows = [dict(zip((column["name"] for column in columns), row)) for row in cursor]
)
rows = [
dict(zip((column["name"] for column in columns), row)) for row in cursor
]
data = {"columns": columns, "rows": rows} data = {"columns": columns, "rows": rows}
error = None error = None
View File
@@ -1,13 +1,21 @@
import logging import logging
import sys
import urllib.request
import urllib.parse
import urllib.error import urllib.error
import urllib.parse
import urllib.request
import requests import requests
from requests.auth import HTTPBasicAuth from requests.auth import HTTPBasicAuth
from redash.query_runner import * from redash.query_runner import (
TYPE_BOOLEAN,
TYPE_DATE,
TYPE_FLOAT,
TYPE_INTEGER,
TYPE_STRING,
BaseQueryRunner,
JobTimeoutException,
register,
)
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
try: try:
@@ -44,7 +52,7 @@ PYTHON_TYPES_MAPPING = {
class BaseElasticSearch(BaseQueryRunner): class BaseElasticSearch(BaseQueryRunner):
should_annotate_query = False should_annotate_query = False
DEBUG_ENABLED = False DEBUG_ENABLED = False
deprecated=True deprecated = True
@classmethod @classmethod
def configuration_schema(cls): def configuration_schema(cls):
@@ -103,9 +111,7 @@ class BaseElasticSearch(BaseQueryRunner):
mappings = r.json() mappings = r.json()
except requests.HTTPError as e: except requests.HTTPError as e:
logger.exception(e) logger.exception(e)
error = "Failed to execute query. Return Code: {0} Reason: {1}".format( error = "Failed to execute query. Return Code: {0} Reason: {1}".format(r.status_code, r.text)
r.status_code, r.text
)
mappings = None mappings = None
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.exception(e) logger.exception(e)
@@ -126,16 +132,12 @@ class BaseElasticSearch(BaseQueryRunner):
if "properties" not in index_mappings["mappings"][m]: if "properties" not in index_mappings["mappings"][m]:
continue continue
for property_name in index_mappings["mappings"][m]["properties"]: for property_name in index_mappings["mappings"][m]["properties"]:
property_data = index_mappings["mappings"][m]["properties"][ property_data = index_mappings["mappings"][m]["properties"][property_name]
property_name
]
if property_name not in mappings: if property_name not in mappings:
property_type = property_data.get("type", None) property_type = property_data.get("type", None)
if property_type: if property_type:
if property_type in ELASTICSEARCH_TYPES_MAPPING: if property_type in ELASTICSEARCH_TYPES_MAPPING:
mappings[property_name] = ELASTICSEARCH_TYPES_MAPPING[ mappings[property_name] = ELASTICSEARCH_TYPES_MAPPING[property_type]
property_type
]
else: else:
mappings[property_name] = TYPE_STRING mappings[property_name] = TYPE_STRING
# raise Exception("Unknown property type: {0}".format(property_type)) # raise Exception("Unknown property type: {0}".format(property_type))
@@ -144,8 +146,7 @@ class BaseElasticSearch(BaseQueryRunner):
def get_schema(self, *args, **kwargs): def get_schema(self, *args, **kwargs):
def parse_doc(doc, path=None): def parse_doc(doc, path=None):
"""Recursively parse a doc type dictionary """Recursively parse a doc type dictionary"""
"""
path = path or [] path = path or []
result = [] result = []
for field, description in doc["properties"].items(): for field, description in doc["properties"].items():
@@ -174,12 +175,8 @@ class BaseElasticSearch(BaseQueryRunner):
schema[name]["columns"] = sorted(set(columns)) schema[name]["columns"] = sorted(set(columns))
return list(schema.values()) return list(schema.values())
def _parse_results( def _parse_results(self, mappings, result_fields, raw_result, result_columns, result_rows): # noqa: C901
self, mappings, result_fields, raw_result, result_columns, result_rows def add_column_if_needed(mappings, column_name, friendly_name, result_columns, result_columns_index):
):
def add_column_if_needed(
mappings, column_name, friendly_name, result_columns, result_columns_index
):
if friendly_name not in result_columns_index: if friendly_name not in result_columns_index:
result_columns.append( result_columns.append(
{ {
@@ -201,14 +198,10 @@ class BaseElasticSearch(BaseQueryRunner):
return return
mappings[key] = type mappings[key] = type
add_column_if_needed( add_column_if_needed(mappings, key, key, result_columns, result_columns_index)
mappings, key, key, result_columns, result_columns_index
)
row[key] = value row[key] = value
def collect_aggregations( def collect_aggregations(mappings, rows, parent_key, data, row, result_columns, result_columns_index):
mappings, rows, parent_key, data, row, result_columns, result_columns_index
):
if isinstance(data, dict): if isinstance(data, dict):
for key, value in data.items(): for key, value in data.items():
val = collect_aggregations( val = collect_aggregations(
@@ -269,9 +262,7 @@ class BaseElasticSearch(BaseQueryRunner):
"string", "string",
) )
else: else:
collect_value( collect_value(mappings, result_row, parent_key, value["key"], "string")
mappings, result_row, parent_key, value["key"], "string"
)
return None return None
@@ -291,9 +282,7 @@ class BaseElasticSearch(BaseQueryRunner):
elif "aggregations" in raw_result: elif "aggregations" in raw_result:
if result_fields: if result_fields:
for field in result_fields: for field in result_fields:
add_column_if_needed( add_column_if_needed(mappings, field, field, result_columns, result_columns_index)
mappings, field, field, result_columns, result_columns_index
)
for key, data in raw_result["aggregations"].items(): for key, data in raw_result["aggregations"].items():
collect_aggregations( collect_aggregations(
@@ -311,9 +300,7 @@ class BaseElasticSearch(BaseQueryRunner):
elif "hits" in raw_result and "hits" in raw_result["hits"]: elif "hits" in raw_result and "hits" in raw_result["hits"]:
if result_fields: if result_fields:
for field in result_fields: for field in result_fields:
add_column_if_needed( add_column_if_needed(mappings, field, field, result_columns, result_columns_index)
mappings, field, field, result_columns, result_columns_index
)
for h in raw_result["hits"]["hits"]: for h in raw_result["hits"]["hits"]:
row = {} row = {}
@@ -323,36 +310,22 @@ class BaseElasticSearch(BaseQueryRunner):
if result_fields and column not in result_fields_index: if result_fields and column not in result_fields_index:
continue continue
add_column_if_needed( add_column_if_needed(mappings, column, column, result_columns, result_columns_index)
mappings, column, column, result_columns, result_columns_index
)
value = h[column_name][column] value = h[column_name][column]
row[column] = ( row[column] = value[0] if isinstance(value, list) and len(value) == 1 else value
value[0]
if isinstance(value, list) and len(value) == 1
else value
)
result_rows.append(row) result_rows.append(row)
else: else:
raise Exception( raise Exception("Redash failed to parse the results it got from Elasticsearch.")
"Redash failed to parse the results it got from Elasticsearch."
)
def test_connection(self): def test_connection(self):
try: try:
r = requests.get( r = requests.get("{0}/_cluster/health".format(self.server_url), auth=self.auth)
"{0}/_cluster/health".format(self.server_url), auth=self.auth
)
r.raise_for_status() r.raise_for_status()
except requests.HTTPError as e: except requests.HTTPError as e:
logger.exception(e) logger.exception(e)
raise Exception( raise Exception("Failed to execute query. Return Code: {0} Reason: {1}".format(r.status_code, r.text))
"Failed to execute query. Return Code: {0} Reason: {1}".format(
r.status_code, r.text
)
)
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.exception(e) logger.exception(e)
raise Exception("Connection refused") raise Exception("Connection refused")
@@ -363,18 +336,14 @@ class Kibana(BaseElasticSearch):
def enabled(cls): def enabled(cls):
return True return True
def _execute_simple_query( def _execute_simple_query(self, url, auth, _from, mappings, result_fields, result_columns, result_rows):
self, url, auth, _from, mappings, result_fields, result_columns, result_rows
):
url += "&from={0}".format(_from) url += "&from={0}".format(_from)
r = requests.get(url, auth=self.auth) r = requests.get(url, auth=self.auth)
r.raise_for_status() r.raise_for_status()
raw_result = r.json() raw_result = r.json()
self._parse_results( self._parse_results(mappings, result_fields, raw_result, result_columns, result_rows)
mappings, result_fields, raw_result, result_columns, result_rows
)
total = raw_result["hits"]["total"] total = raw_result["hits"]["total"]
result_size = len(raw_result["hits"]["hits"]) result_size = len(raw_result["hits"]["hits"])
@@ -421,7 +390,7 @@ class Kibana(BaseElasticSearch):
_from = 0 _from = 0
while True: while True:
query_size = size if limit >= (_from + size) else (limit - _from) query_size = size if limit >= (_from + size) else (limit - _from)
total = self._execute_simple_query( self._execute_simple_query(
url + "&size={0}".format(query_size), url + "&size={0}".format(query_size),
self.auth, self.auth,
_from, _from,
@@ -440,9 +409,8 @@ class Kibana(BaseElasticSearch):
json_data = json_dumps({"columns": result_columns, "rows": result_rows}) json_data = json_dumps({"columns": result_columns, "rows": result_rows})
except requests.HTTPError as e: except requests.HTTPError as e:
logger.exception(e) logger.exception(e)
error = "Failed to execute query. Return Code: {0} Reason: {1}".format( r = e.response
r.status_code, r.text error = "Failed to execute query. Return Code: {0} Reason: {1}".format(r.status_code, r.text)
)
json_data = None json_data = None
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.exception(e) logger.exception(e)
@@ -490,9 +458,7 @@ class ElasticSearch(BaseElasticSearch):
result_columns = [] result_columns = []
result_rows = [] result_rows = []
self._parse_results( self._parse_results(mappings, result_fields, r.json(), result_columns, result_rows)
mappings, result_fields, r.json(), result_columns, result_rows
)
json_data = json_dumps({"columns": result_columns, "rows": result_rows}) json_data = json_dumps({"columns": result_columns, "rows": result_rows})
except (KeyboardInterrupt, JobTimeoutException) as e: except (KeyboardInterrupt, JobTimeoutException) as e:
@@ -500,9 +466,7 @@ class ElasticSearch(BaseElasticSearch):
raise raise
except requests.HTTPError as e: except requests.HTTPError as e:
logger.exception(e) logger.exception(e)
error = "Failed to execute query. Return Code: {0} Reason: {1}".format( error = "Failed to execute query. Return Code: {0} Reason: {1}".format(r.status_code, r.text)
r.status_code, r.text
)
json_data = None json_data = None
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.exception(e) logger.exception(e)
View File
@@ -1,10 +1,17 @@
import logging import logging
from typing import Tuple, Optional from typing import Optional, Tuple
from redash.query_runner import * from redash.query_runner import (
TYPE_BOOLEAN,
TYPE_DATE,
TYPE_FLOAT,
TYPE_INTEGER,
TYPE_STRING,
BaseHTTPQueryRunner,
register,
)
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ELASTICSEARCH_TYPES_MAPPING = { ELASTICSEARCH_TYPES_MAPPING = {
@@ -28,7 +35,6 @@ TYPES_MAP = {
class ElasticSearch2(BaseHTTPQueryRunner): class ElasticSearch2(BaseHTTPQueryRunner):
should_annotate_query = False should_annotate_query = False
@classmethod @classmethod
@@ -37,12 +43,12 @@ class ElasticSearch2(BaseHTTPQueryRunner):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.syntax = 'json' self.syntax = "json"
def get_response(self, url, auth=None, http_method='get', **kwargs): def get_response(self, url, auth=None, http_method="get", **kwargs):
url = "{}{}".format(self.configuration["url"], url) url = "{}{}".format(self.configuration["url"], url)
headers = kwargs.pop('headers', {}) headers = kwargs.pop("headers", {})
headers['Accept'] = 'application/json' headers["Accept"] = "application/json"
return super().get_response(url, auth, http_method, headers=headers, **kwargs) return super().get_response(url, auth, http_method, headers=headers, **kwargs)
def test_connection(self): def test_connection(self):
@@ -52,11 +58,7 @@ class ElasticSearch2(BaseHTTPQueryRunner):
def run_query(self, query, user): def run_query(self, query, user):
query, url, result_fields = self._build_query(query) query, url, result_fields = self._build_query(query)
response, error = self.get_response( response, error = self.get_response(url, http_method="post", json=query)
url,
http_method='post',
json=query
)
query_results = response.json() query_results = response.json()
data = self._parse_results(result_fields, query_results) data = self._parse_results(result_fields, query_results)
error = None error = None
@@ -65,8 +67,8 @@ class ElasticSearch2(BaseHTTPQueryRunner):
def _build_query(self, query: str) -> Tuple[dict, str, Optional[list]]: def _build_query(self, query: str) -> Tuple[dict, str, Optional[list]]:
query = json_loads(query) query = json_loads(query)
index_name = query.pop('index', '') index_name = query.pop("index", "")
result_fields = query.pop('result_fields', None) result_fields = query.pop("result_fields", None)
url = "/{}/_search".format(index_name) url = "/{}/_search".format(index_name)
return query, url, result_fields return query, url, result_fields
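A query document of the shape _build_query consumes, as a hypothetical example: the index and result_fields keys are popped off and the remainder is posted to /<index>/_search:

import json

raw = """{
  "index": "events",
  "result_fields": ["user", "ts"],
  "query": {"match_all": {}},
  "size": 10
}"""
query = json.loads(raw)
index_name = query.pop("index", "")
result_fields = query.pop("result_fields", None)
print("/{}/_search".format(index_name), result_fields, query)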
@@ -77,14 +79,14 @@ class ElasticSearch2(BaseHTTPQueryRunner):
def _parse_properties(prefix: str, properties: dict): def _parse_properties(prefix: str, properties: dict):
for property_name, property_data in properties.items(): for property_name, property_data in properties.items():
if property_name not in mappings: if property_name not in mappings:
property_type = property_data.get('type', None) property_type = property_data.get("type", None)
nested_properties = property_data.get('properties', None) nested_properties = property_data.get("properties", None)
if property_type: if property_type:
mappings[index_name][prefix + property_name] = ( mappings[index_name][prefix + property_name] = ELASTICSEARCH_TYPES_MAPPING.get(
ELASTICSEARCH_TYPES_MAPPING.get(property_type, TYPE_STRING) property_type, TYPE_STRING
) )
elif nested_properties: elif nested_properties:
new_prefix = prefix + property_name + '.' new_prefix = prefix + property_name + "."
_parse_properties(new_prefix, nested_properties) _parse_properties(new_prefix, nested_properties)
for index_name in mappings_data: for index_name in mappings_data:
@@ -92,27 +94,24 @@ class ElasticSearch2(BaseHTTPQueryRunner):
index_mappings = mappings_data[index_name] index_mappings = mappings_data[index_name]
try: try:
for m in index_mappings.get("mappings", {}): for m in index_mappings.get("mappings", {}):
_parse_properties('', index_mappings['mappings'][m]['properties']) _parse_properties("", index_mappings["mappings"][m]["properties"])
except KeyError: except KeyError:
_parse_properties('', index_mappings['mappings']['properties']) _parse_properties("", index_mappings["mappings"]["properties"])
return mappings return mappings
def get_mappings(self): def get_mappings(self):
response, error = self.get_response('/_mappings') response, error = self.get_response("/_mappings")
return self._parse_mappings(response.json()) return self._parse_mappings(response.json())
def get_schema(self, *args, **kwargs): def get_schema(self, *args, **kwargs):
schema = {} schema = {}
for name, columns in self.get_mappings().items(): for name, columns in self.get_mappings().items():
schema[name] = { schema[name] = {"name": name, "columns": list(columns.keys())}
'name': name,
'columns': list(columns.keys())
}
return list(schema.values()) return list(schema.values())
@classmethod @classmethod
def _parse_results(cls, result_fields, raw_result): def _parse_results(cls, result_fields, raw_result): # noqa: C901
result_columns = [] result_columns = []
result_rows = [] result_rows = []
result_columns_index = {c["name"]: c for c in result_columns} result_columns_index = {c["name"]: c for c in result_columns}
@@ -120,11 +119,13 @@ class ElasticSearch2(BaseHTTPQueryRunner):
def add_column_if_needed(column_name, value=None): def add_column_if_needed(column_name, value=None):
if column_name not in result_columns_index: if column_name not in result_columns_index:
result_columns.append({ result_columns.append(
'name': column_name, {
'friendly_name': column_name, "name": column_name,
'type': TYPES_MAP.get(type(value), TYPE_STRING) "friendly_name": column_name,
}) "type": TYPES_MAP.get(type(value), TYPE_STRING),
}
)
result_columns_index[column_name] = result_columns[-1] result_columns_index[column_name] = result_columns[-1]
def get_row(rows, row): def get_row(rows, row):
@@ -143,23 +144,23 @@ class ElasticSearch2(BaseHTTPQueryRunner):
def parse_bucket_to_row(data, row, agg_key): def parse_bucket_to_row(data, row, agg_key):
sub_agg_key = "" sub_agg_key = ""
for key, item in data.items(): for key, item in data.items():
if key == 'key_as_string': if key == "key_as_string":
continue continue
if key == 'key': if key == "key":
if 'key_as_string' in data: if "key_as_string" in data:
collect_value(row, agg_key, data['key_as_string']) collect_value(row, agg_key, data["key_as_string"])
else: else:
collect_value(row, agg_key, data['key']) collect_value(row, agg_key, data["key"])
continue continue
if isinstance(item, (str, int, float)): if isinstance(item, (str, int, float)):
collect_value(row, agg_key + '.' + key, item) collect_value(row, agg_key + "." + key, item)
elif isinstance(item, dict): elif isinstance(item, dict):
if 'buckets' not in item: if "buckets" not in item:
for sub_key, sub_item in item.items(): for sub_key, sub_item in item.items():
collect_value( collect_value(
row, row,
agg_key + '.' + key + '.' + sub_key, agg_key + "." + key + "." + sub_key,
sub_item, sub_item,
) )
else: else:
@@ -179,18 +180,18 @@ class ElasticSearch2(BaseHTTPQueryRunner):
rows.append(row) rows.append(row)
else: else:
depth += 1 depth += 1
parse_buckets_list(rows, sub_agg_key, value[sub_agg_key]['buckets'], row, depth) parse_buckets_list(rows, sub_agg_key, value[sub_agg_key]["buckets"], row, depth)
def collect_aggregations(rows, parent_key, data, row, depth): def collect_aggregations(rows, parent_key, data, row, depth):
row = get_row(rows, row) row = get_row(rows, row)
parse_bucket_to_row(data, row, parent_key) parse_bucket_to_row(data, row, parent_key)
if 'buckets' in data: if "buckets" in data:
parse_buckets_list(rows, parent_key, data['buckets'], row, depth) parse_buckets_list(rows, parent_key, data["buckets"], row, depth)
return None return None
def get_flatten_results(dd, separator='.', prefix=''): def get_flatten_results(dd, separator=".", prefix=""):
if isinstance(dd, dict): if isinstance(dd, dict):
return { return {
prefix + separator + k if prefix else k: v prefix + separator + k if prefix else k: v
@@ -206,17 +207,17 @@ class ElasticSearch2(BaseHTTPQueryRunner):
for r in result_fields: for r in result_fields:
result_fields_index[r] = None result_fields_index[r] = None
if 'error' in raw_result: if "error" in raw_result:
error = raw_result['error'] error = raw_result["error"]
if len(error) > 10240: if len(error) > 10240:
error = error[:10240] + '... continues' error = error[:10240] + "... continues"
raise Exception(error) raise Exception(error)
elif 'aggregations' in raw_result: elif "aggregations" in raw_result:
for key, data in raw_result["aggregations"].items(): for key, data in raw_result["aggregations"].items():
collect_aggregations(result_rows, key, data, None, 0) collect_aggregations(result_rows, key, data, None, 0)
elif 'hits' in raw_result and 'hits' in raw_result['hits']: elif "hits" in raw_result and "hits" in raw_result["hits"]:
for h in raw_result["hits"]["hits"]: for h in raw_result["hits"]["hits"]:
row = {} row = {}
@@ -235,23 +236,17 @@ class ElasticSearch2(BaseHTTPQueryRunner):
else: else:
raise Exception("Redash failed to parse the results it got from Elasticsearch.") raise Exception("Redash failed to parse the results it got from Elasticsearch.")
return { return {"columns": result_columns, "rows": result_rows}
'columns': result_columns,
'rows': result_rows
}
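The hunk above cuts `get_flatten_results` off mid-comprehension. Its shape is the usual recursive dict flattener; a reconstructed sketch (not copied from the commit, so treat the body as an approximation):

def get_flatten_results(dd, separator=".", prefix=""):
    # Recursively flatten nested dicts into one level with dotted keys.
    if isinstance(dd, dict):
        return {
            prefix + separator + k if prefix else k: v
            for kk, vv in dd.items()
            for k, v in get_flatten_results(vv, separator, kk).items()
        }
    return {prefix: dd}

print(get_flatten_results({"user": {"name": "ada", "id": 7}}))
# {'user.name': 'ada', 'user.id': 7}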
class OpenDistroSQLElasticSearch(ElasticSearch2): class OpenDistroSQLElasticSearch(ElasticSearch2):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.syntax = 'sql' self.syntax = "sql"
def _build_query(self, query: str) -> Tuple[dict, str, Optional[list]]: def _build_query(self, query: str) -> Tuple[dict, str, Optional[list]]:
sql_query = { sql_query = {"query": query}
'query': query sql_query_url = "/_opendistro/_sql"
}
sql_query_url = '/_opendistro/_sql'
return sql_query, sql_query_url, None return sql_query, sql_query_url, None
@classmethod @classmethod
@@ -263,56 +258,52 @@ class OpenDistroSQLElasticSearch(ElasticSearch2):
return "elasticsearch2_OpenDistroSQLElasticSearch" return "elasticsearch2_OpenDistroSQLElasticSearch"
class XPackSQLElasticSearch(ElasticSearch2): class XPackSQLElasticSearch(ElasticSearch2):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.syntax = 'sql' self.syntax = "sql"
def _build_query(self, query: str) -> Tuple[dict, str, Optional[list]]: def _build_query(self, query: str) -> Tuple[dict, str, Optional[list]]:
sql_query = { sql_query = {"query": query}
'query': query sql_query_url = "/_xpack/sql"
}
sql_query_url = '/_xpack/sql'
return sql_query, sql_query_url, None return sql_query, sql_query_url, None
@classmethod @classmethod
def _parse_results(cls, result_fields, raw_result): def _parse_results(cls, result_fields, raw_result):
error = raw_result.get('error') error = raw_result.get("error")
if error: if error:
raise Exception(error) raise Exception(error)
rv = { rv = {
'columns': [ "columns": [
{ {
'name': c['name'], "name": c["name"],
'friendly_name': c['name'], "friendly_name": c["name"],
'type': ELASTICSEARCH_TYPES_MAPPING.get(c['type'], 'string') "type": ELASTICSEARCH_TYPES_MAPPING.get(c["type"], "string"),
} for c in raw_result['columns'] }
for c in raw_result["columns"]
], ],
'rows': [] "rows": [],
} }
query_results_rows = raw_result['rows'] query_results_rows = raw_result["rows"]
for query_results_row in query_results_rows: for query_results_row in query_results_rows:
result_row = dict() result_row = dict()
for column, column_value in zip(rv['columns'], query_results_row): for column, column_value in zip(rv["columns"], query_results_row):
result_row[column['name']] = column_value result_row[column["name"]] = column_value
rv['rows'].append(result_row) rv["rows"].append(result_row)
return rv return rv
@classmethod @classmethod
def name(cls): def name(cls):
return cls.__name__ return cls.__name__
@classmethod @classmethod
def type(cls): def type(cls):
return "elasticsearch2_XPackSQLElasticSearch" return "elasticsearch2_XPackSQLElasticSearch"
register(ElasticSearch2) register(ElasticSearch2)
register(OpenDistroSQLElasticSearch) register(OpenDistroSQLElasticSearch)
register(XPackSQLElasticSearch) register(XPackSQLElasticSearch)
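Both SQL-flavored runners end up in `_parse_results`, where each row list is zipped against the column metadata. A toy run of that zip (the payload shape mirrors an `/_xpack/sql` response, heavily simplified; type mapping dropped):

raw_result = {
    "columns": [{"name": "id", "type": "integer"}, {"name": "email", "type": "text"}],
    "rows": [[1, "a@example.com"], [2, "b@example.com"]],
}
# zip pairs each cell with its column metadata positionally.
rows = [
    {col["name"]: cell for col, cell in zip(raw_result["columns"], row)}
    for row in raw_result["rows"]
]
print(rows)  # [{'id': 1, 'email': 'a@example.com'}, {'id': 2, 'email': 'b@example.com'}]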

View File

@@ -1,6 +1,14 @@
import datetime import datetime
from redash.query_runner import * from redash.query_runner import (
TYPE_DATE,
TYPE_DATETIME,
TYPE_FLOAT,
TYPE_INTEGER,
TYPE_STRING,
BaseQueryRunner,
register,
)
from redash.utils import json_dumps from redash.utils import json_dumps
@@ -95,8 +103,7 @@ class Exasol(BaseQueryRunner):
try: try:
statement = connection.execute(query) statement = connection.execute(query)
columns = [ columns = [
{"name": n, "friendly_name": n, "type": _type_mapper(t)} {"name": n, "friendly_name": n, "type": _type_mapper(t)} for (n, t) in statement.columns().items()
for (n, t) in statement.columns().items()
] ]
cnames = statement.column_names() cnames = statement.column_names()
@@ -126,7 +133,7 @@ class Exasol(BaseQueryRunner):
statement = connection.execute(query) statement = connection.execute(query)
result = {} result = {}
for (schema, table_name, column) in statement: for schema, table_name, column in statement:
table_name_with_schema = "%s.%s" % (schema, table_name) table_name_with_schema = "%s.%s" % (schema, table_name)
if table_name_with_schema not in result: if table_name_with_schema not in result:
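The Exasol hunk ends just after the membership check; the loop is grouping `(schema, table, column)` rows by qualified table name. The grouping idiom in isolation (the triples and the dict shape are illustrative, inferred from the surrounding diff):

statement = [
    ("PUBLIC", "USERS", "ID"),
    ("PUBLIC", "USERS", "EMAIL"),
    ("SALES", "ORDERS", "TOTAL"),
]
result = {}
for schema, table_name, column in statement:
    table_name_with_schema = "%s.%s" % (schema, table_name)
    if table_name_with_schema not in result:
        result[table_name_with_schema] = {"name": table_name_with_schema, "columns": []}
    result[table_name_with_schema]["columns"].append(column)
print(list(result.values()))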

View File

@@ -1,22 +1,27 @@
import logging import logging
import yaml import yaml
from redash.utils.requests_session import requests_or_advocate, UnacceptableAddressException from redash.query_runner import BaseQueryRunner, NotSupported, register
from redash.query_runner import *
from redash.utils import json_dumps from redash.utils import json_dumps
from redash.utils.requests_session import (
UnacceptableAddressException,
requests_or_advocate,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
import pandas as pd
import xlrd
import openpyxl
import numpy as np import numpy as np
import openpyxl # noqa: F401
import pandas as pd
import xlrd # noqa: F401
enabled = True enabled = True
except ImportError: except ImportError:
enabled = False enabled = False
class Excel(BaseQueryRunner): class Excel(BaseQueryRunner):
should_annotate_query = False should_annotate_query = False
@@ -27,8 +32,8 @@ class Excel(BaseQueryRunner):
@classmethod @classmethod
def configuration_schema(cls): def configuration_schema(cls):
return { return {
'type': 'object', "type": "object",
'properties': {}, "properties": {},
} }
def __init__(self, configuration): def __init__(self, configuration):
@@ -44,12 +49,12 @@ class Excel(BaseQueryRunner):
args = {} args = {}
try: try:
args = yaml.safe_load(query) args = yaml.safe_load(query)
path = args['url'] path = args["url"]
args.pop('url', None) args.pop("url", None)
ua = args['user-agent'] ua = args["user-agent"]
args.pop('user-agent', None) args.pop("user-agent", None)
except: except Exception:
pass pass
try: try:
@@ -57,25 +62,37 @@ class Excel(BaseQueryRunner):
workbook = pd.read_excel(response.content, **args) workbook = pd.read_excel(response.content, **args)
df = workbook.copy() df = workbook.copy()
data = {'columns': [], 'rows': []} data = {"columns": [], "rows": []}
conversions = [ conversions = [
{'pandas_type': np.integer, 'redash_type': 'integer',}, {
{'pandas_type': np.inexact, 'redash_type': 'float',}, "pandas_type": np.integer,
{'pandas_type': np.datetime64, 'redash_type': 'datetime', 'to_redash': lambda x: x.strftime('%Y-%m-%d %H:%M:%S')}, "redash_type": "integer",
{'pandas_type': np.bool_, 'redash_type': 'boolean'}, },
{'pandas_type': np.object, 'redash_type': 'string'} {
"pandas_type": np.inexact,
"redash_type": "float",
},
{
"pandas_type": np.datetime64,
"redash_type": "datetime",
"to_redash": lambda x: x.strftime("%Y-%m-%d %H:%M:%S"),
},
{"pandas_type": np.bool_, "redash_type": "boolean"},
{"pandas_type": np.object, "redash_type": "string"},
] ]
labels = [] labels = []
for dtype, label in zip(df.dtypes, df.columns): for dtype, label in zip(df.dtypes, df.columns):
for conversion in conversions: for conversion in conversions:
if issubclass(dtype.type, conversion['pandas_type']): if issubclass(dtype.type, conversion["pandas_type"]):
data['columns'].append({'name': label, 'friendly_name': label, 'type': conversion['redash_type']}) data["columns"].append(
{"name": label, "friendly_name": label, "type": conversion["redash_type"]}
)
labels.append(label) labels.append(label)
func = conversion.get('to_redash') func = conversion.get("to_redash")
if func: if func:
df[label] = df[label].apply(func) df[label] = df[label].apply(func)
break break
data['rows'] = df[labels].replace({np.nan: None}).to_dict(orient='records') data["rows"] = df[labels].replace({np.nan: None}).to_dict(orient="records")
json_data = json_dumps(data) json_data = json_dumps(data)
error = None error = None
@@ -94,4 +111,5 @@ class Excel(BaseQueryRunner):
def get_schema(self): def get_schema(self):
raise NotSupported() raise NotSupported()
register(Excel) register(Excel)
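The `conversions` table above is what maps pandas dtypes onto Redash column types. A compact demo of the matching loop (requires pandas and numpy; plain `object` stands in for the commit's `np.object`, an alias NumPy has since removed in 1.24):

import numpy as np
import pandas as pd

conversions = [
    {"pandas_type": np.integer, "redash_type": "integer"},
    {"pandas_type": np.inexact, "redash_type": "float"},
    {"pandas_type": np.datetime64, "redash_type": "datetime"},
    {"pandas_type": np.bool_, "redash_type": "boolean"},
    {"pandas_type": object, "redash_type": "string"},  # np.object in the commit
]

df = pd.DataFrame({"n": [1, 2], "x": [0.5, 1.5], "s": ["a", "b"]})
for dtype, label in zip(df.dtypes, df.columns):
    for conversion in conversions:
        # First matching dtype family wins, mirroring the runner's break.
        if issubclass(dtype.type, conversion["pandas_type"]):
            print(label, "->", conversion["redash_type"])  # n -> integer, x -> float, s -> string
            break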

View File

@@ -3,19 +3,27 @@ from base64 import b64decode
from datetime import datetime from datetime import datetime
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
from redash.query_runner import * from redash.query_runner import (
TYPE_DATE,
TYPE_DATETIME,
TYPE_FLOAT,
TYPE_INTEGER,
TYPE_STRING,
BaseSQLQueryRunner,
register,
)
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
from oauth2client.service_account import ServiceAccountCredentials import httplib2
from apiclient.discovery import build from apiclient.discovery import build
from apiclient.errors import HttpError from apiclient.errors import HttpError
import httplib2 from oauth2client.service_account import ServiceAccountCredentials
enabled = True enabled = True
except ImportError as e: except ImportError:
enabled = False enabled = False
@@ -48,9 +56,7 @@ def parse_ga_response(response):
d = {} d = {}
for c, value in enumerate(r): for c, value in enumerate(r):
column_name = response["columnHeaders"][c]["name"] column_name = response["columnHeaders"][c]["name"]
column_type = [col for col in columns if col["name"] == column_name][0][ column_type = [col for col in columns if col["name"] == column_name][0]["type"]
"type"
]
# mcf results come a bit different than ga results: # mcf results come a bit different than ga results:
if isinstance(value, dict): if isinstance(value, dict):
@@ -59,9 +65,7 @@ def parse_ga_response(response):
elif "conversionPathValue" in value: elif "conversionPathValue" in value:
steps = [] steps = []
for step in value["conversionPathValue"]: for step in value["conversionPathValue"]:
steps.append( steps.append("{}:{}".format(step["interactionType"], step["nodeValue"]))
"{}:{}".format(step["interactionType"], step["nodeValue"])
)
value = ", ".join(steps) value = ", ".join(steps)
else: else:
raise Exception("Results format not supported") raise Exception("Results format not supported")
@@ -74,9 +78,7 @@ def parse_ga_response(response):
elif len(value) == 12: elif len(value) == 12:
value = datetime.strptime(value, "%Y%m%d%H%M") value = datetime.strptime(value, "%Y%m%d%H%M")
else: else:
raise Exception( raise Exception("Unknown date/time format in results: '{}'".format(value))
"Unknown date/time format in results: '{}'".format(value)
)
d[column_name] = value d[column_name] = value
rows.append(d) rows.append(d)
@@ -119,14 +121,7 @@ class GoogleAnalytics(BaseSQLQueryRunner):
return build("analytics", "v3", http=creds.authorize(httplib2.Http())) return build("analytics", "v3", http=creds.authorize(httplib2.Http()))
def _get_tables(self, schema): def _get_tables(self, schema):
accounts = ( accounts = self._get_analytics_service().management().accounts().list().execute().get("items")
self._get_analytics_service()
.management()
.accounts()
.list()
.execute()
.get("items")
)
if accounts is None: if accounts is None:
raise Exception("Failed getting accounts.") raise Exception("Failed getting accounts.")
else: else:
@@ -143,9 +138,7 @@ class GoogleAnalytics(BaseSQLQueryRunner):
for property_ in properties: for property_ in properties:
if "defaultProfileId" in property_ and "name" in property_: if "defaultProfileId" in property_ and "name" in property_:
schema[account["name"]]["columns"].append( schema[account["name"]]["columns"].append(
"{0} (ga:{1})".format( "{0} (ga:{1})".format(property_["name"], property_["defaultProfileId"])
property_["name"], property_["defaultProfileId"]
)
) )
return list(schema.values()) return list(schema.values())
@@ -162,16 +155,14 @@ class GoogleAnalytics(BaseSQLQueryRunner):
logger.debug("Analytics is about to execute query: %s", query) logger.debug("Analytics is about to execute query: %s", query)
try: try:
params = json_loads(query) params = json_loads(query)
except: except Exception:
query_string = parse_qs(urlparse(query).query, keep_blank_values=True) query_string = parse_qs(urlparse(query).query, keep_blank_values=True)
params = {k.replace('-', '_'): ",".join(v) for k,v in query_string.items()} params = {k.replace("-", "_"): ",".join(v) for k, v in query_string.items()}
if "mcf:" in params["metrics"] and "ga:" in params["metrics"]: if "mcf:" in params["metrics"] and "ga:" in params["metrics"]:
raise Exception("Can't mix mcf: and ga: metrics.") raise Exception("Can't mix mcf: and ga: metrics.")
if "mcf:" in params.get("dimensions", "") and "ga:" in params.get( if "mcf:" in params.get("dimensions", "") and "ga:" in params.get("dimensions", ""):
"dimensions", ""
):
raise Exception("Can't mix mcf: and ga: dimensions.") raise Exception("Can't mix mcf: and ga: dimensions.")
if "mcf:" in params["metrics"]: if "mcf:" in params["metrics"]:

View File

@@ -5,7 +5,16 @@ from dateutil import parser
from requests import Session from requests import Session
from xlsxwriter.utility import xl_col_to_name from xlsxwriter.utility import xl_col_to_name
from redash.query_runner import * from redash.query_runner import (
TYPE_BOOLEAN,
TYPE_DATETIME,
TYPE_FLOAT,
TYPE_INTEGER,
TYPE_STRING,
BaseQueryRunner,
guess_type,
register,
)
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -39,9 +48,7 @@ def _get_columns_and_column_names(row):
duplicate_counter += 1 duplicate_counter += 1
column_names.append(column_name) column_names.append(column_name)
columns.append( columns.append({"name": column_name, "friendly_name": column_name, "type": TYPE_STRING})
{"name": column_name, "friendly_name": column_name, "type": TYPE_STRING}
)
return columns, column_names return columns, column_names
@@ -102,10 +109,7 @@ def parse_worksheet(worksheet):
columns[j]["type"] = guess_type(value) columns[j]["type"] = guess_type(value)
column_types = [c["type"] for c in columns] column_types = [c["type"] for c in columns]
rows = [ rows = [dict(zip(column_names, _value_eval_list(row, column_types))) for row in worksheet[HEADER_INDEX + 1 :]]
dict(zip(column_names, _value_eval_list(row, column_types)))
for row in worksheet[HEADER_INDEX + 1 :]
]
data = {"columns": columns, "rows": rows} data = {"columns": columns, "rows": rows}
return data return data
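The one-liner black produced in `parse_worksheet` is the standard `dict(zip(...))` row builder: the header row supplies the keys, every later row supplies the values. In miniature (worksheet invented; `HEADER_INDEX` assumed to be 0 as in the runner, and the `_value_eval_list` type coercion omitted):

HEADER_INDEX = 0
worksheet = [["name", "age"], ["ada", "36"], ["grace", "45"]]
column_names = worksheet[HEADER_INDEX]
rows = [dict(zip(column_names, row)) for row in worksheet[HEADER_INDEX + 1:]]
print(rows)  # [{'name': 'ada', 'age': '36'}, {'name': 'grace', 'age': '45'}]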
@@ -210,9 +214,7 @@ class GoogleSpreadsheet(BaseQueryRunner):
except gspread.SpreadsheetNotFound: except gspread.SpreadsheetNotFound:
return ( return (
None, None,
"Spreadsheet ({}) not found. Make sure you used correct id.".format( "Spreadsheet ({}) not found. Make sure you used correct id.".format(key),
key
),
) )
except APIError as e: except APIError as e:
return None, parse_api_error(e) return None, parse_api_error(e)

View File

@@ -3,7 +3,13 @@ import logging
import requests import requests
from redash.query_runner import * from redash.query_runner import (
TYPE_DATETIME,
TYPE_FLOAT,
TYPE_STRING,
BaseQueryRunner,
register,
)
from redash.utils import json_dumps from redash.utils import json_dumps
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -69,11 +75,7 @@ class Graphite(BaseQueryRunner):
verify=self.verify, verify=self.verify,
) )
if r.status_code != 200: if r.status_code != 200:
raise Exception( raise Exception("Got invalid response from Graphite (http status code: {0}).".format(r.status_code))
"Got invalid response from Graphite (http status code: {0}).".format(
r.status_code
)
)
def run_query(self, query, user): def run_query(self, query, user):
url = "%s%s" % (self.base_url, "&".join(query.split("\n"))) url = "%s%s" % (self.base_url, "&".join(query.split("\n")))

View File

@@ -1,8 +1,17 @@
import logging
import sys
import base64 import base64
import logging
from redash.query_runner import * from redash.query_runner import (
TYPE_BOOLEAN,
TYPE_DATE,
TYPE_DATETIME,
TYPE_FLOAT,
TYPE_INTEGER,
TYPE_STRING,
BaseSQLQueryRunner,
JobTimeoutException,
register,
)
from redash.utils import json_dumps from redash.utils import json_dumps
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -71,27 +80,17 @@ class Hive(BaseSQLQueryRunner):
columns_query = "show columns in %s.%s" columns_query = "show columns in %s.%s"
for schema_name in [ for schema_name in [
a a for a in [str(a["database_name"]) for a in self._run_query_internal(schemas_query)] if len(a) > 0
for a in [
str(a["database_name"]) for a in self._run_query_internal(schemas_query)
]
if len(a) > 0
]: ]:
for table_name in [ for table_name in [
a a
for a in [ for a in [str(a["tab_name"]) for a in self._run_query_internal(tables_query % schema_name)]
str(a["tab_name"])
for a in self._run_query_internal(tables_query % schema_name)
]
if len(a) > 0 if len(a) > 0
]: ]:
columns = [ columns = [
a a
for a in [ for a in [
str(a["field"]) str(a["field"]) for a in self._run_query_internal(columns_query % (schema_name, table_name))
for a in self._run_query_internal(
columns_query % (schema_name, table_name)
)
] ]
if len(a) > 0 if len(a) > 0
] ]

View File

@@ -1,6 +1,15 @@
import logging import logging
from redash.query_runner import * from redash.query_runner import (
TYPE_BOOLEAN,
TYPE_DATETIME,
TYPE_FLOAT,
TYPE_INTEGER,
TYPE_STRING,
BaseSQLQueryRunner,
JobTimeoutException,
register,
)
from redash.utils import json_dumps from redash.utils import json_dumps
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -10,7 +19,7 @@ try:
from impala.error import DatabaseError, RPCError from impala.error import DatabaseError, RPCError
enabled = True enabled = True
except ImportError as e: except ImportError:
enabled = False enabled = False
COLUMN_NAME = 0 COLUMN_NAME = 0
@@ -71,18 +80,10 @@ class Impala(BaseSQLQueryRunner):
tables_query = "show tables in %s;" tables_query = "show tables in %s;"
columns_query = "show column stats %s.%s;" columns_query = "show column stats %s.%s;"
for schema_name in [ for schema_name in [str(a["name"]) for a in self._run_query_internal(schemas_query)]:
str(a["name"]) for a in self._run_query_internal(schemas_query) for table_name in [str(a["name"]) for a in self._run_query_internal(tables_query % schema_name)]:
]:
for table_name in [
str(a["name"])
for a in self._run_query_internal(tables_query % schema_name)
]:
columns = [ columns = [
str(a["Column"]) str(a["Column"]) for a in self._run_query_internal(columns_query % (schema_name, table_name))
for a in self._run_query_internal(
columns_query % (schema_name, table_name)
)
] ]
if schema_name != "default": if schema_name != "default":
@@ -93,7 +94,6 @@ class Impala(BaseSQLQueryRunner):
return list(schema_dict.values()) return list(schema_dict.values())
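The Hive and Impala schema walks above share one shape: list schemas, list tables per schema, list columns per table, skipping empties and qualifying names outside `default`. Stripped of runner plumbing (the query results here are canned stand-ins, not a real connection):

def _run_query_internal(q):
    # Hypothetical stand-in for the runner's helper.
    fake = {
        "show schemas;": [{"name": "default"}, {"name": "sales"}],
        "show tables in default;": [],
        "show tables in sales;": [{"name": "orders"}],
    }
    return fake.get(q, [{"Column": "id"}, {"Column": "total"}])

schema_dict = {}
for schema_name in [str(a["name"]) for a in _run_query_internal("show schemas;")]:
    for table_name in [str(a["name"]) for a in _run_query_internal("show tables in %s;" % schema_name)]:
        columns = [str(a["Column"]) for a in _run_query_internal("show column stats %s.%s;" % (schema_name, table_name))]
        key = table_name if schema_name == "default" else "%s.%s" % (schema_name, table_name)
        schema_dict[key] = {"name": key, "columns": columns}
print(list(schema_dict.values()))  # [{'name': 'sales.orders', 'columns': ['id', 'total']}]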
def run_query(self, query, user): def run_query(self, query, user):
connection = None connection = None
try: try:
connection = connect(**self.configuration.to_dict()) connection = connect(**self.configuration.to_dict())

View File

@@ -1,6 +1,6 @@
import logging import logging
from redash.query_runner import * from redash.query_runner import BaseQueryRunner, register
from redash.utils import json_dumps from redash.utils import json_dumps
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -42,9 +42,7 @@ def _transform_result(results):
result_row[column] = value result_row[column] = value
result_rows.append(result_row) result_rows.append(result_row)
return json_dumps( return json_dumps({"columns": [{"name": c} for c in result_columns], "rows": result_rows})
{"columns": [{"name": c} for c in result_columns], "rows": result_rows}
)
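`_transform_result` emits the standard Redash result payload: `columns` as a list of `{"name": ...}` dicts plus plain row dicts. The reformatted one-liner in miniature (stdlib `json` stands in for redash's `json_dumps`; values invented):

import json

result_columns = ["time", "value"]
result_rows = [{"time": "2023-07-11T00:00:00Z", "value": 42}]
payload = json.dumps({"columns": [{"name": c} for c in result_columns], "rows": result_rows})
print(payload)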
class InfluxDB(BaseQueryRunner): class InfluxDB(BaseQueryRunner):

View File

@@ -1,7 +1,7 @@
import re import re
from collections import OrderedDict from collections import OrderedDict
from redash.query_runner import * from redash.query_runner import TYPE_STRING, BaseHTTPQueryRunner, register
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps, json_loads
@@ -32,7 +32,7 @@ class ResultSet(object):
self.rows = self.rows + set.rows self.rows = self.rows + set.rows
def parse_issue(issue, field_mapping): def parse_issue(issue, field_mapping): # noqa: C901
result = OrderedDict() result = OrderedDict()
result["key"] = issue["key"] result["key"] = issue["key"]
@@ -45,9 +45,7 @@ def parse_issue(issue, field_mapping):
# if field mapping with dict member mappings defined get value of each member # if field mapping with dict member mappings defined get value of each member
for member_name in member_names: for member_name in member_names:
if member_name in v: if member_name in v:
result[ result[field_mapping.get_dict_output_field_name(k, member_name)] = v[member_name]
field_mapping.get_dict_output_field_name(k, member_name)
] = v[member_name]
else: else:
# these special mapping rules are kept for backwards compatibility # these special mapping rules are kept for backwards compatibility
@@ -72,9 +70,7 @@ def parse_issue(issue, field_mapping):
if member_name in listItem: if member_name in listItem:
listValues.append(listItem[member_name]) listValues.append(listItem[member_name])
if len(listValues) > 0: if len(listValues) > 0:
result[ result[field_mapping.get_dict_output_field_name(k, member_name)] = ",".join(listValues)
field_mapping.get_dict_output_field_name(k, member_name)
] = ",".join(listValues)
else: else:
# otherwise support list values only for non-dict items # otherwise support list values only for non-dict items
@@ -114,7 +110,7 @@ class FieldMapping:
member_name = None member_name = None
# check for member name contained in field name # check for member name contained in field name
member_parser = re.search("(\w+)\.(\w+)", k) member_parser = re.search(r"(\w+)\.(\w+)", k)
if member_parser: if member_parser:
field_name = member_parser.group(1) field_name = member_parser.group(1)
member_name = member_parser.group(2) member_name = member_parser.group(2)
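The raw-string regex above splits a mapped JIRA field name like `assignee.displayName` into its field and member parts. Quick check:

import re

member_parser = re.search(r"(\w+)\.(\w+)", "assignee.displayName")
if member_parser:
    print(member_parser.group(1), member_parser.group(2))  # assignee displayName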

Some files were not shown because too many files have changed in this diff.