upgrade flask (#6138)

* upgrade flask

Signed-off-by: Ye Sijun <junnplus@gmail.com>

* fix test

Signed-off-by: Ye Sijun <junnplus@gmail.com>

* override value_proc for click.prompt

Signed-off-by: Ye Sijun <junnplus@gmail.com>

---------

Signed-off-by: Ye Sijun <junnplus@gmail.com>
This commit is contained in:
Jun
2023-07-08 06:05:27 +09:00
committed by GitHub
parent d92fc98b13
commit 73f49cbf0c
22 changed files with 147 additions and 122 deletions

View File

@@ -42,7 +42,7 @@ def create_app():
app = Redash() app = Redash()
# Check and update the cached version for use by the client # Check and update the cached version for use by the client
app.before_first_request(reset_new_version_status) reset_new_version_status()
security.init_app(app) security.init_app(app)
request_metrics.init_app(app) request_metrics.init_app(app)

View File

@@ -5,15 +5,16 @@ import time
from datetime import timedelta from datetime import timedelta
from urllib.parse import urlsplit, urlunsplit from urllib.parse import urlsplit, urlunsplit
from flask import jsonify, redirect, request, url_for, session from flask import jsonify, redirect, request, session, url_for
from flask_login import LoginManager, login_user, logout_user, user_logged_in from flask_login import LoginManager, login_user, logout_user, user_logged_in
from sqlalchemy.orm.exc import NoResultFound
from werkzeug.exceptions import Unauthorized
from redash import models, settings from redash import models, settings
from redash.authentication import jwt_auth from redash.authentication import jwt_auth
from redash.authentication.org_resolving import current_org from redash.authentication.org_resolving import current_org
from redash.settings.organization import settings as org_settings from redash.settings.organization import settings as org_settings
from redash.tasks import record_event from redash.tasks import record_event
from sqlalchemy.orm.exc import NoResultFound
from werkzeug.exceptions import Unauthorized
login_manager = LoginManager() login_manager = LoginManager()
logger = logging.getLogger("authentication") logger = logging.getLogger("authentication")
@@ -216,12 +217,9 @@ def log_user_logged_in(app, user):
@login_manager.unauthorized_handler @login_manager.unauthorized_handler
def redirect_to_login(): def redirect_to_login():
if request.is_xhr or "/api/" in request.path: is_xhr = request.headers.get("X-Requested-With") == "XMLHttpRequest"
response = jsonify( if is_xhr or "/api/" in request.path:
{"message": "Couldn't find resource. Please login and try again."} return {"message": "Couldn't find resource. Please login and try again."}, 404
)
response.status_code = 404
return response
login_url = get_login_url(next=request.url, external=False) login_url = get_login_url(next=request.url, external=False)
@@ -242,14 +240,11 @@ def logout_and_redirect_to_index():
def init_app(app): def init_app(app):
from redash.authentication import ( from redash.authentication import ldap_auth, remote_user_auth, saml_auth
saml_auth, from redash.authentication.google_oauth import (
remote_user_auth, create_google_oauth_blueprint,
ldap_auth,
) )
from redash.authentication.google_oauth import create_google_oauth_blueprint
login_manager.init_app(app) login_manager.init_app(app)
login_manager.anonymous_user = models.AnonymousUser login_manager.anonymous_user = models.AnonymousUser
login_manager.REMEMBER_COOKIE_DURATION = settings.REMEMBER_COOKIE_DURATION login_manager.REMEMBER_COOKIE_DURATION = settings.REMEMBER_COOKIE_DURATION
@@ -262,7 +257,12 @@ def init_app(app):
from redash.security import csrf from redash.security import csrf
# Authlib's flask oauth client requires a Flask app to initialize # Authlib's flask oauth client requires a Flask app to initialize
for blueprint in [create_google_oauth_blueprint(app), saml_auth.blueprint, remote_user_auth.blueprint, ldap_auth.blueprint, ]: for blueprint in [
create_google_oauth_blueprint(app),
saml_auth.blueprint,
remote_user_auth.blueprint,
ldap_auth.blueprint,
]:
csrf.exempt(blueprint) csrf.exempt(blueprint)
app.register_blueprint(blueprint) app.register_blueprint(blueprint)

View File

@@ -4,14 +4,21 @@ from flask import current_app
from flask.cli import FlaskGroup, run_command, with_appcontext from flask.cli import FlaskGroup, run_command, with_appcontext
from rq import Connection from rq import Connection
from redash import __version__, create_app, settings, rq_redis_connection from redash import __version__, create_app, rq_redis_connection, settings
from redash.cli import data_sources, database, groups, organization, queries, users, rq from redash.cli import (
data_sources,
database,
groups,
organization,
queries,
rq,
users,
)
from redash.monitor import get_status from redash.monitor import get_status
def create(group): def create():
app = current_app or create_app() app = current_app or create_app()
group.app = app
@app.shell_context_processor @app.shell_context_processor
def shell_context(): def shell_context():
@@ -62,25 +69,23 @@ def send_test_mail(email=None):
""" """
Send test message to EMAIL (default: the address you defined in MAIL_DEFAULT_SENDER) Send test message to EMAIL (default: the address you defined in MAIL_DEFAULT_SENDER)
""" """
from redash import mail
from flask_mail import Message from flask_mail import Message
from redash import mail
if email is None: if email is None:
email = settings.MAIL_DEFAULT_SENDER email = settings.MAIL_DEFAULT_SENDER
mail.send( mail.send(Message(subject="Test Message from Redash", recipients=[email], body="Test message."))
Message(
subject="Test Message from Redash", recipients=[email], body="Test message."
)
)
@manager.command("shell") @manager.command("shell")
@with_appcontext @with_appcontext
def shell(): def shell():
import sys import sys
from ptpython import repl
from flask.globals import _app_ctx_stack from flask.globals import _app_ctx_stack
from ptpython import repl
app = _app_ctx_stack.top.app app = _app_ctx_stack.top.app

View File

@@ -1,6 +1,7 @@
from sys import exit from sys import exit
import click import click
from click.types import convert_type
from flask.cli import AppGroup from flask.cli import AppGroup
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
@@ -40,7 +41,7 @@ def list_command(organization=None):
) )
@manager.command() @manager.command(name="list_types")
def list_types(): def list_types():
print("Enabled Query Runners:") print("Enabled Query Runners:")
types = sorted(query_runners.keys()) types = sorted(query_runners.keys())
@@ -139,11 +140,19 @@ def new(name=None, type=None, options=None, organization="default"):
else: else:
prompt = "{} (optional)".format(prompt) prompt = "{} (optional)".format(prompt)
_type = types[prop["type"]]
def value_proc(value):
if value == default_value:
return default_value
return convert_type(_type, default_value)(value)
value = click.prompt( value = click.prompt(
prompt, prompt,
default=default_value, default=default_value,
type=types[prop["type"]], type=_type,
show_default=False, show_default=False,
value_proc=value_proc,
) )
if value != default_value: if value != default_value:
options_obj[k] = value options_obj[k] = value
@@ -154,7 +163,7 @@ def new(name=None, type=None, options=None, organization="default"):
if not options.is_valid(): if not options.is_valid():
print("Error: invalid configuration.") print("Error: invalid configuration.")
exit() exit(1)
print( print(
"Creating {} data source ({}) with options:\n{}".format( "Creating {} data source ({}) with options:\n{}".format(

View File

@@ -41,7 +41,7 @@ def load_extensions(db):
connection.execute(f'CREATE EXTENSION IF NOT EXISTS "{extension}";') connection.execute(f'CREATE EXTENSION IF NOT EXISTS "{extension}";')
@manager.command() @manager.command(name="create_tables")
def create_tables(): def create_tables():
"""Create the database tables.""" """Create the database tables."""
from redash.models import db from redash.models import db
@@ -61,7 +61,7 @@ def create_tables():
stamp() stamp()
@manager.command() @manager.command(name="drop_tables")
def drop_tables(): def drop_tables():
"""Drop the database tables.""" """Drop the database tables."""
from redash.models import db from redash.models import db

View File

@@ -1,8 +1,8 @@
from sys import exit from sys import exit
from sqlalchemy.orm.exc import NoResultFound
from flask.cli import AppGroup
from click import argument, option from click import argument, option
from flask.cli import AppGroup
from sqlalchemy.orm.exc import NoResultFound
from redash import models from redash import models
@@ -43,7 +43,7 @@ def create(name, permissions=None, organization="default"):
exit(1) exit(1)
@manager.command() @manager.command(name="change_permissions")
@argument("group_id") @argument("group_id")
@option( @option(
"--permissions", "--permissions",
@@ -119,4 +119,7 @@ def list_command(organization=None):
members = models.Group.members(group.id) members = models.Group.members(group.id)
user_names = [m.name for m in members] user_names = [m.name for m in members]
print("Users: {}".format(", ".join(user_names))) if user_names:
print("Users: {}".format(", ".join(user_names)))
else:
print("Users:")

View File

@@ -6,7 +6,7 @@ from redash import models
manager = AppGroup(help="Organization management commands.") manager = AppGroup(help="Organization management commands.")
@manager.command() @manager.command(name="set_google_apps_domains")
@argument("domains") @argument("domains")
def set_google_apps_domains(domains): def set_google_apps_domains(domains):
""" """
@@ -24,7 +24,7 @@ def set_google_apps_domains(domains):
) )
@manager.command() @manager.command(name="show_google_apps_domains")
def show_google_apps_domains(): def show_google_apps_domains():
organization = models.Organization.query.first() organization = models.Organization.query.first()
print( print(

View File

@@ -5,7 +5,7 @@ from sqlalchemy.orm.exc import NoResultFound
manager = AppGroup(help="Queries management commands.") manager = AppGroup(help="Queries management commands.")
@manager.command() @manager.command(name="add_tag")
@argument("query_id") @argument("query_id")
@argument("tag") @argument("tag")
def add_tag(query_id, tag): def add_tag(query_id, tag):
@@ -31,7 +31,7 @@ def add_tag(query_id, tag):
print("Tag added.") print("Tag added.")
@manager.command() @manager.command(name="remove_tag")
@argument("query_id") @argument("query_id")
@argument("tag") @argument("tag")
def remove_tag(query_id, tag): def remove_tag(query_id, tag):

View File

@@ -2,8 +2,8 @@ from sys import exit
from click import BOOL, argument, option, prompt from click import BOOL, argument, option, prompt
from flask.cli import AppGroup from flask.cli import AppGroup
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound
from redash import models from redash import models
from redash.handlers.users import invite_user from redash.handlers.users import invite_user
@@ -26,7 +26,7 @@ def build_groups(org, groups, is_admin):
return groups return groups
@manager.command() @manager.command(name="grant_admin")
@argument("email") @argument("email")
@option( @option(
"--org", "--org",
@@ -116,7 +116,7 @@ def create(
exit(1) exit(1)
@manager.command() @manager.command(name="create_root")
@argument("email") @argument("email")
@argument("name") @argument("name")
@option( @option(
@@ -155,9 +155,7 @@ def create_root(email, name, google_auth=False, password=None, organization="def
exit(1) exit(1)
org_slug = organization org_slug = organization
org = models.Organization.query.filter( org = models.Organization.query.filter(models.Organization.slug == org_slug).first()
models.Organization.slug == org_slug
).first()
if org is None: if org is None:
org = models.Organization(name=org_slug, slug=org_slug, settings={}) org = models.Organization(name=org_slug, slug=org_slug, settings={})

View File

@@ -1,19 +1,19 @@
import time import time
from inspect import isclass from inspect import isclass
from flask import Blueprint, current_app, request
from flask import Blueprint, current_app, request
from flask_login import current_user, login_required from flask_login import current_user, login_required
from flask_restful import Resource, abort from flask_restful import Resource, abort
from sqlalchemy import cast
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy_utils.functions import sort_query
from redash import settings from redash import settings
from redash.authentication import current_org from redash.authentication import current_org
from redash.models import db from redash.models import db
from redash.tasks import record_event as record_event_task from redash.tasks import record_event as record_event_task
from redash.utils import json_dumps from redash.utils import json_dumps
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy import cast
from sqlalchemy.dialects import postgresql
from sqlalchemy_utils import sort_query
routes = Blueprint( routes = Blueprint(
"redash", __name__, template_folder=settings.fix_assets_path("templates") "redash", __name__, template_folder=settings.fix_assets_path("templates")

View File

@@ -1,6 +1,7 @@
from flask import render_template, safe_join, send_file from flask import render_template, send_file
from flask_login import login_required from flask_login import login_required
from werkzeug.utils import safe_join
from redash import settings from redash import settings
from redash.handlers import routes from redash.handlers import routes
from redash.handlers.authentication import base_href from redash.handlers.authentication import base_href
@@ -13,7 +14,7 @@ def render_index():
response = render_template("multi_org.html", base_href=base_href()) response = render_template("multi_org.html", base_href=base_href())
else: else:
full_path = safe_join(settings.STATIC_ASSETS_PATH, "index.html") full_path = safe_join(settings.STATIC_ASSETS_PATH, "index.html")
response = send_file(full_path, **dict(cache_timeout=0, conditional=True)) response = send_file(full_path, **dict(max_age=0, conditional=True))
return response return response

View File

@@ -1,10 +1,10 @@
import functools import functools
from flask_sqlalchemy import BaseQuery, SQLAlchemy from flask_sqlalchemy import BaseQuery, SQLAlchemy
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import object_session from sqlalchemy.orm import object_session
from sqlalchemy.pool import NullPool from sqlalchemy.pool import NullPool
from sqlalchemy_searchable import make_searchable, vectorizer, SearchQueryMixin from sqlalchemy_searchable import SearchQueryMixin, make_searchable, vectorizer
from sqlalchemy.dialects import postgresql
from redash import settings from redash import settings
from redash.utils import json_dumps from redash.utils import json_dumps
@@ -15,7 +15,7 @@ class RedashSQLAlchemy(SQLAlchemy):
options.update(json_serializer=json_dumps) options.update(json_serializer=json_dumps)
if settings.SQLALCHEMY_ENABLE_POOL_PRE_PING: if settings.SQLALCHEMY_ENABLE_POOL_PRE_PING:
options.update(pool_pre_ping=True) options.update(pool_pre_ping=True)
super(RedashSQLAlchemy, self).apply_driver_hacks(app, info, options) return super(RedashSQLAlchemy, self).apply_driver_hacks(app, info, options)
def apply_pool_defaults(self, app, options): def apply_pool_defaults(self, app, options):
super(RedashSQLAlchemy, self).apply_pool_defaults(app, options) super(RedashSQLAlchemy, self).apply_pool_defaults(app, options)
@@ -25,6 +25,7 @@ class RedashSQLAlchemy(SQLAlchemy):
options["poolclass"] = NullPool options["poolclass"] = NullPool
# Remove options NullPool does not support: # Remove options NullPool does not support:
options.pop("max_overflow", None) options.pop("max_overflow", None)
return options
db = RedashSQLAlchemy(session_options={"expire_on_commit": False}) db = RedashSQLAlchemy(session_options={"expire_on_commit": False})

View File

@@ -5,21 +5,20 @@ import time
from functools import reduce from functools import reduce
from operator import or_ from operator import or_
from flask import current_app as app, url_for, request_started from flask import current_app as app
from flask_login import current_user, AnonymousUserMixin, UserMixin from flask import request_started, url_for
from flask_login import AnonymousUserMixin, UserMixin, current_user
from passlib.apps import custom_app_context as pwd_context from passlib.apps import custom_app_context as pwd_context
from sqlalchemy.exc import DBAPIError
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from sqlalchemy_utils import EmailType from sqlalchemy_utils import EmailType
from sqlalchemy_utils.models import generic_repr from sqlalchemy_utils.models import generic_repr
from redash import redis_connection from redash import redis_connection
from redash.utils import generate_token, utcnow, dt_from_timestamp from redash.utils import dt_from_timestamp, generate_token
from .base import db, Column, GFKBase, key_type, primary_key from .base import Column, GFKBase, db, key_type, primary_key
from .mixins import TimestampMixin, BelongsToOrgMixin from .mixins import BelongsToOrgMixin, TimestampMixin
from .types import json_cast_property, MutableDict, MutableList from .types import MutableDict, MutableList, json_cast_property
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -87,7 +86,9 @@ class User(
email = Column(EmailType) email = Column(EmailType)
password_hash = Column(db.String(128), nullable=True) password_hash = Column(db.String(128), nullable=True)
group_ids = Column( group_ids = Column(
"groups", MutableList.as_mutable(postgresql.ARRAY(key_type("Group"))), nullable=True "groups",
MutableList.as_mutable(postgresql.ARRAY(key_type("Group"))),
nullable=True,
) )
api_key = Column(db.String(40), default=lambda: generate_token(40), unique=True) api_key = Column(db.String(40), default=lambda: generate_token(40), unique=True)

View File

@@ -1,11 +1,11 @@
import logging import logging
import requests import requests
import semver import semver
from redash import __version__ as current_version from redash import __version__ as current_version
from redash import redis_connection from redash import redis_connection
from redash.models import db, Organization from redash.models import Organization, db
from redash.utils import json_dumps
REDIS_KEY = "new_version_available" REDIS_KEY = "new_version_available"

View File

@@ -1,23 +1,22 @@
Flask==1.1.1 Flask==2.3.2
Jinja2==2.11.3 Jinja2==3.1.2
itsdangerous==1.1.0 itsdangerous==2.1.2
black==19.10b0 click==8.1.3
click==6.7 MarkupSafe==2.1.1
MarkupSafe==1.1.1
pyOpenSSL==19.0.0 pyOpenSSL==19.0.0
httplib2==0.18.1 httplib2==0.18.1
wtforms==2.2.1 wtforms==2.2.1
Flask-RESTful==0.3.7 Flask-RESTful==0.3.10
Flask-Login==0.4.1 Flask-Login==0.6.0
Flask-SQLAlchemy==2.4.1 Flask-SQLAlchemy==2.5.1
Flask-Migrate==2.5.2 Flask-Migrate==2.5.2
flask-mail==0.9.1 flask-mail==0.9.1
flask-talisman==0.7.0 flask-talisman==0.7.0
Flask-Limiter==0.9.3 Flask-Limiter==0.9.3
Flask-WTF==0.14.3 Flask-WTF==1.1.1
passlib==1.7.1 passlib==1.7.1
aniso8601==8.0.0 aniso8601==8.0.0
blinker==1.4 blinker==1.6.2
psycopg2-binary==2.9.6 psycopg2-binary==2.9.6
python-dateutil==2.8.0 python-dateutil==2.8.0
pytz>=2019.3 pytz>=2019.3
@@ -58,7 +57,7 @@ gevent==21.12.0
sshtunnel==0.1.5 sshtunnel==0.1.5
supervisor==4.1.0 supervisor==4.1.0
supervisor_checks==0.8.1 supervisor_checks==0.8.1
werkzeug==0.16.1 werkzeug==2.3.3
# Uncomment the requirement for ldap3 if using ldap. # Uncomment the requirement for ldap3 if using ldap.
# It is not included by default because of the GPL license conflict. # It is not included by default because of the GPL license conflict.
# ldap3==2.2.4 # ldap3==2.2.4

View File

@@ -12,4 +12,4 @@ PyAthena>=1.5.0,<=1.11.5
ptvsd==4.3.2 ptvsd==4.3.2
freezegun==0.3.12 freezegun==0.3.12
watchdog==0.9.0 watchdog==0.9.0
ptpython==3.0.17 ptpython==3.0.23

View File

@@ -1,8 +1,8 @@
import os
import datetime import datetime
import logging import logging
from unittest import TestCase import os
from contextlib import contextmanager from contextlib import contextmanager
from unittest import TestCase
os.environ["REDASH_REDIS_URL"] = os.environ.get( os.environ["REDASH_REDIS_URL"] = os.environ.get(
"REDASH_REDIS_URL", "redis://localhost:6379/0" "REDASH_REDIS_URL", "redis://localhost:6379/0"
@@ -25,17 +25,16 @@ os.environ["REDASH_ENFORCE_CSRF"] = "false"
from redash import limiter, redis_connection from redash import limiter, redis_connection
from redash.app import create_app from redash.app import create_app
from redash.models import db from redash.models import db
from redash.utils import json_dumps, json_loads from redash.utils import json_dumps
from tests.factories import Factory, user_factory from tests.factories import Factory, user_factory
logging.disable(logging.INFO) logging.disable(logging.INFO)
logging.getLogger("metrics").setLevel(logging.ERROR) logging.getLogger("metrics").setLevel(logging.ERROR)
def authenticate_request(c, user): def authenticate_request(c, user):
with c.session_transaction() as sess: with c.session_transaction() as sess:
sess["user_id"] = user.get_id() sess["_user_id"] = user.get_id()
@contextmanager @contextmanager
@@ -110,11 +109,13 @@ class BaseTestCase(TestCase):
) )
return response return response
def get_request(self, path, org=None, headers=None): def get_request(self, path, org=None, headers=None, client=None):
if org: if org:
path = "/{}{}".format(org.slug, path) path = "/{}{}".format(org.slug, path)
return self.client.get(path, headers=headers) if client is None:
client = self.client
return client.get(path, headers=headers)
def post_request(self, path, data=None, org=None, headers=None): def post_request(self, path, data=None, org=None, headers=None):
if org: if org:

View File

@@ -1,6 +1,5 @@
from tests import BaseTestCase
from redash.models import Alert, AlertSubscription, db from redash.models import Alert, AlertSubscription, db
from tests import BaseTestCase
class TestAlertResourceGet(BaseTestCase): class TestAlertResourceGet(BaseTestCase):

View File

@@ -1,7 +1,8 @@
from redash import models, settings
from tests import BaseTestCase
from mock import patch from mock import patch
from redash import models
from tests import BaseTestCase
class TestUserListResourcePost(BaseTestCase): class TestUserListResourcePost(BaseTestCase):
def test_returns_403_for_non_admin(self): def test_returns_403_for_non_admin(self):
@@ -357,7 +358,7 @@ class TestUserResourcePost(BaseTestCase):
# visit profile page # visit profile page
self.make_request("get", "/api/users/{}".format(self.factory.user.id)) self.make_request("get", "/api/users/{}".format(self.factory.user.id))
with c.session_transaction() as sess: with c.session_transaction() as sess:
previous = sess["user_id"] previous = sess["_user_id"]
# change e-mail address - this will result in a new `user_id` value inside the session # change e-mail address - this will result in a new `user_id` value inside the session
self.make_request( self.make_request(
@@ -366,10 +367,13 @@ class TestUserResourcePost(BaseTestCase):
data={"email": "john@doe.com"}, data={"email": "john@doe.com"},
) )
with self.app.test_client() as c:
# force the old `user_id`, simulating that the user is logged in from another browser # force the old `user_id`, simulating that the user is logged in from another browser
with c.session_transaction() as sess: with c.session_transaction() as sess:
sess["user_id"] = previous sess["_user_id"] = previous
rv = self.get_request("/api/users/{}".format(self.factory.user.id)) rv = self.get_request(
"/api/users/{}".format(self.factory.user.id), client=c
)
self.assertEqual(rv.status_code, 404) self.assertEqual(rv.status_code, 404)
@@ -378,7 +382,7 @@ class TestUserResourcePost(BaseTestCase):
with self.client as c: with self.client as c:
with c.session_transaction() as sess: with c.session_transaction() as sess:
previous = sess["user_id"] previous = sess["_user_id"]
self.make_request( self.make_request(
"post", "post",
@@ -388,7 +392,7 @@ class TestUserResourcePost(BaseTestCase):
with self.client as c: with self.client as c:
with c.session_transaction() as sess: with c.session_transaction() as sess:
current = sess["user_id"] current = sess["_user_id"]
# make sure the session's `user_id` has changed to reflect the new identity, thus not logging the user out # make sure the session's `user_id` has changed to reflect the new identity, thus not logging the user out
self.assertNotEqual(previous, current) self.assertNotEqual(previous, current)

View File

@@ -4,6 +4,8 @@ import time
from flask import request from flask import request
from mock import patch from mock import patch
from sqlalchemy.orm.exc import NoResultFound
from redash import models, settings from redash import models, settings
from redash.authentication import ( from redash.authentication import (
api_key_load_user_from_request, api_key_load_user_from_request,
@@ -11,9 +13,10 @@ from redash.authentication import (
hmac_load_user_from_request, hmac_load_user_from_request,
sign, sign,
) )
from redash.authentication.google_oauth import create_and_login_user, verify_profile from redash.authentication.google_oauth import (
from redash.utils import utcnow create_and_login_user,
from sqlalchemy.orm.exc import NoResultFound verify_profile,
)
from tests import BaseTestCase from tests import BaseTestCase
@@ -264,9 +267,7 @@ class TestRedirectToUrlAfterLoggingIn(BaseTestCase):
data={"email": self.user.email, "password": self.password}, data={"email": self.user.email, "password": self.password},
org=self.factory.org, org=self.factory.org,
) )
self.assertEqual( self.assertEqual(response.location, "/{}/".format(self.user.org.slug))
response.location, "http://localhost/{}/".format(self.user.org.slug)
)
def test_simple_path_in_next_param(self): def test_simple_path_in_next_param(self):
response = self.post_request( response = self.post_request(
@@ -274,7 +275,7 @@ class TestRedirectToUrlAfterLoggingIn(BaseTestCase):
data={"email": self.user.email, "password": self.password}, data={"email": self.user.email, "password": self.password},
org=self.factory.org, org=self.factory.org,
) )
self.assertEqual(response.location, "http://localhost/default/queries") self.assertEqual(response.location, "queries")
def test_starts_scheme_url_in_next_param(self): def test_starts_scheme_url_in_next_param(self):
response = self.post_request( response = self.post_request(
@@ -282,7 +283,7 @@ class TestRedirectToUrlAfterLoggingIn(BaseTestCase):
data={"email": self.user.email, "password": self.password}, data={"email": self.user.email, "password": self.password},
org=self.factory.org, org=self.factory.org,
) )
self.assertEqual(response.location, "http://localhost/default/") self.assertEqual(response.location, "./")
def test_without_scheme_url_in_next_param(self): def test_without_scheme_url_in_next_param(self):
response = self.post_request( response = self.post_request(
@@ -290,7 +291,7 @@ class TestRedirectToUrlAfterLoggingIn(BaseTestCase):
data={"email": self.user.email, "password": self.password}, data={"email": self.user.email, "password": self.password},
org=self.factory.org, org=self.factory.org,
) )
self.assertEqual(response.location, "http://localhost/default/") self.assertEqual(response.location, "./")
def test_without_scheme_with_path_url_in_next_param(self): def test_without_scheme_with_path_url_in_next_param(self):
response = self.post_request( response = self.post_request(
@@ -298,7 +299,7 @@ class TestRedirectToUrlAfterLoggingIn(BaseTestCase):
data={"email": self.user.email, "password": self.password}, data={"email": self.user.email, "password": self.password},
org=self.factory.org, org=self.factory.org,
) )
self.assertEqual(response.location, "http://localhost/queries") self.assertEqual(response.location, "/queries")
class TestRemoteUserAuth(BaseTestCase): class TestRemoteUserAuth(BaseTestCase):

View File

@@ -1,12 +1,13 @@
import mock
import textwrap import textwrap
import mock
from click.testing import CliRunner from click.testing import CliRunner
from tests import BaseTestCase
from redash.utils.configuration import ConfigurationContainer
from redash.query_runner import query_runners
from redash.cli import manager from redash.cli import manager
from redash.models import DataSource, Group, Organization, User, db from redash.models import DataSource, Group, Organization, User, db
from redash.query_runner import query_runners
from redash.utils.configuration import ConfigurationContainer
from tests import BaseTestCase
class DataSourceCommandTests(BaseTestCase): class DataSourceCommandTests(BaseTestCase):
@@ -299,21 +300,21 @@ class GroupCommandTests(BaseTestCase):
Type: builtin Type: builtin
Organization: default Organization: default
Permissions: [admin,super_admin] Permissions: [admin,super_admin]
Users: Users:
-------------------- --------------------
Id: 4 Id: 4
Name: agroup Name: agroup
Type: regular Type: regular
Organization: default Organization: default
Permissions: [list_dashboards] Permissions: [list_dashboards]
Users: Users:
-------------------- --------------------
Id: 5 Id: 5
Name: bgroup Name: bgroup
Type: regular Type: regular
Organization: default Organization: default
Permissions: [list_dashboards] Permissions: [list_dashboards]
Users: Users:
-------------------- --------------------
Id: 2 Id: 2
Name: default Name: default
@@ -327,7 +328,7 @@ class GroupCommandTests(BaseTestCase):
Type: regular Type: regular
Organization: default Organization: default
Permissions: [list_dashboards] Permissions: [list_dashboards]
Users: Users:
""" """
self.assertMultiLineEqual(result.output, textwrap.dedent(output).lstrip()) self.assertMultiLineEqual(result.output, textwrap.dedent(output).lstrip())

View File

@@ -1,9 +1,9 @@
from flask_login import current_user from flask_login import current_user
from funcy import project from funcy import project
from mock import patch from mock import patch
from tests import BaseTestCase, authenticated_user
from redash import models, settings from redash import models, settings
from tests import BaseTestCase, authenticated_user
class AuthenticationTestMixin(object): class AuthenticationTestMixin(object):
@@ -22,7 +22,7 @@ class TestAuthentication(BaseTestCase):
def test_responds_with_success_for_signed_in_user(self): def test_responds_with_success_for_signed_in_user(self):
with self.client as c: with self.client as c:
with c.session_transaction() as sess: with c.session_transaction() as sess:
sess["user_id"] = self.factory.user.get_id() sess["_user_id"] = self.factory.user.get_id()
rv = self.client.get("/default/") rv = self.client.get("/default/")
self.assertEqual(200, rv.status_code) self.assertEqual(200, rv.status_code)
@@ -34,7 +34,7 @@ class TestAuthentication(BaseTestCase):
def test_redirects_for_invalid_session_identifier(self): def test_redirects_for_invalid_session_identifier(self):
with self.client as c: with self.client as c:
with c.session_transaction() as sess: with c.session_transaction() as sess:
sess["user_id"] = 100 sess["_user_id"] = 100
rv = self.client.get("/default/") rv = self.client.get("/default/")
self.assertEqual(302, rv.status_code) self.assertEqual(302, rv.status_code)
@@ -186,7 +186,7 @@ class TestLogin(BaseTestCase):
data={"email": user.email, "password": "password"}, data={"email": user.email, "password": "password"},
) )
self.assertEqual(rv.status_code, 302) self.assertEqual(rv.status_code, 302)
self.assertEqual(rv.location, "http://localhost/test") self.assertEqual(rv.location, "/test")
login_user_mock.assert_called_with(user, remember=False) login_user_mock.assert_called_with(user, remember=False)
def test_submit_incorrect_user(self): def test_submit_incorrect_user(self):
@@ -244,7 +244,9 @@ class TestLogin(BaseTestCase):
"/default/login", data={"email": user.email, "password": "password"} "/default/login", data={"email": user.email, "password": "password"}
) )
self.assertEqual(rv.status_code, 200) self.assertEqual(rv.status_code, 200)
self.assertIn("Password login is not enabled for your organization", str(rv.data)) self.assertIn(
"Password login is not enabled for your organization", str(rv.data)
)
class TestLogout(BaseTestCase): class TestLogout(BaseTestCase):