mirror of
https://github.com/getredash/redash.git
synced 2025-12-19 17:37:19 -05:00
auth tests wip
This commit is contained in:
@@ -5,6 +5,7 @@ import time
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from flask import redirect, request, jsonify, url_for
|
from flask import redirect, request, jsonify, url_for
|
||||||
|
from sqlalchemy.orm.exc import NoResultFound
|
||||||
|
|
||||||
from redash import models, settings
|
from redash import models, settings
|
||||||
from redash.authentication.org_resolving import current_org
|
from redash.authentication.org_resolving import current_org
|
||||||
@@ -36,9 +37,10 @@ def sign(key, path, expires):
|
|||||||
|
|
||||||
@login_manager.user_loader
|
@login_manager.user_loader
|
||||||
def load_user(user_id):
|
def load_user(user_id):
|
||||||
|
org = current_org._get_current_object()
|
||||||
try:
|
try:
|
||||||
return models.User.get_by_id_and_org(user_id, current_org.id)
|
return models.User.get_by_id_and_org(user_id, org)
|
||||||
except models.User.DoesNotExist:
|
except NoResultFound:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -51,14 +53,14 @@ def hmac_load_user_from_request(request):
|
|||||||
# TODO: 3600 should be a setting
|
# TODO: 3600 should be a setting
|
||||||
if signature and time.time() < expires <= time.time() + 3600:
|
if signature and time.time() < expires <= time.time() + 3600:
|
||||||
if user_id:
|
if user_id:
|
||||||
user = models.User.get_by_id(user_id)
|
user = models.User.query.get(user_id)
|
||||||
calculated_signature = sign(user.api_key, request.path, expires)
|
calculated_signature = sign(user.api_key, request.path, expires)
|
||||||
|
|
||||||
if user.api_key and signature == calculated_signature:
|
if user.api_key and signature == calculated_signature:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
if query_id:
|
if query_id:
|
||||||
query = models.Query.get(models.Query.id == query_id)
|
query = models.db.session.query(models.Query).filter(models.Query.id == query_id).one()
|
||||||
calculated_signature = sign(query.api_key, request.path, expires)
|
calculated_signature = sign(query.api_key, request.path, expires)
|
||||||
|
|
||||||
if query.api_key and signature == calculated_signature:
|
if query.api_key and signature == calculated_signature:
|
||||||
@@ -74,15 +76,16 @@ def get_user_from_api_key(api_key, query_id):
|
|||||||
user = None
|
user = None
|
||||||
|
|
||||||
# TODO: once we switch all api key storage into the ApiKey model, this code will be much simplified
|
# TODO: once we switch all api key storage into the ApiKey model, this code will be much simplified
|
||||||
|
org = current_org._get_current_object()
|
||||||
try:
|
try:
|
||||||
user = models.User.get_by_api_key_and_org(api_key, current_org.id)
|
user = models.User.get_by_api_key_and_org(api_key, org)
|
||||||
except models.User.DoesNotExist:
|
except NoResultFound:
|
||||||
try:
|
try:
|
||||||
api_key = models.ApiKey.get_by_api_key(api_key)
|
api_key = models.ApiKey.get_by_api_key(api_key)
|
||||||
user = models.ApiUser(api_key, api_key.org, [])
|
user = models.ApiUser(api_key, api_key.org, [])
|
||||||
except models.ApiKey.DoesNotExist:
|
except NoResultFound:
|
||||||
if query_id:
|
if query_id:
|
||||||
query = models.Query.get_by_id_and_org(query_id, current_org.id)
|
query = models.Query.get_by_id_and_org(query_id, org)
|
||||||
if query and query.api_key == api_key:
|
if query and query.api_key == api_key:
|
||||||
user = models.ApiUser(api_key, query.org, query.groups.keys(), name="ApiKey: Query {}".format(query.id))
|
user = models.ApiUser(api_key, query.org, query.groups.keys(), name="ApiKey: Query {}".format(query.id))
|
||||||
|
|
||||||
@@ -105,7 +108,6 @@ def get_api_key_from_request(request):
|
|||||||
def api_key_load_user_from_request(request):
|
def api_key_load_user_from_request(request):
|
||||||
api_key = get_api_key_from_request(request)
|
api_key = get_api_key_from_request(request)
|
||||||
query_id = request.view_args.get('query_id', None)
|
query_id = request.view_args.get('query_id', None)
|
||||||
|
|
||||||
user = get_user_from_api_key(api_key, query_id)
|
user = get_user_from_api_key(api_key, query_id)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ import requests
|
|||||||
from flask import redirect, url_for, Blueprint, flash, request, session
|
from flask import redirect, url_for, Blueprint, flash, request, session
|
||||||
from flask_login import login_user
|
from flask_login import login_user
|
||||||
from flask_oauthlib.client import OAuth
|
from flask_oauthlib.client import OAuth
|
||||||
|
from sqlalchemy.orm.exc import NoResultFound
|
||||||
|
|
||||||
from redash import models, settings
|
from redash import models, settings
|
||||||
from redash.authentication.org_resolving import current_org
|
from redash.authentication.org_resolving import current_org
|
||||||
|
|
||||||
@@ -63,9 +65,10 @@ def create_and_login_user(org, name, email):
|
|||||||
logger.debug("Updating user name (%r -> %r)", user_object.name, name)
|
logger.debug("Updating user name (%r -> %r)", user_object.name, name)
|
||||||
user_object.name = name
|
user_object.name = name
|
||||||
user_object.save()
|
user_object.save()
|
||||||
except models.User.DoesNotExist:
|
except NoResultFound:
|
||||||
logger.debug("Creating user object (%r)", name)
|
logger.debug("Creating user object (%r)", name)
|
||||||
user_object = models.User.create(org=org, name=name, email=email, group_ids=[org.default_group.id])
|
user_object = models.User(org=org, name=name, email=email, group_ids=[org.default_group.id])
|
||||||
|
models.db.session.add(user_object)
|
||||||
|
|
||||||
login_user(user_object, remember=True)
|
login_user(user_object, remember=True)
|
||||||
|
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ def paginate(query_set, page, page_size, serializer):
|
|||||||
'count': count,
|
'count': count,
|
||||||
'page': page,
|
'page': page,
|
||||||
'page_size': page_size,
|
'page_size': page_size,
|
||||||
'results': [serializer(result) for result in results],
|
'results': [serializer(result) for result in results.items],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -76,7 +76,8 @@ class QueryListResource(BaseResource):
|
|||||||
|
|
||||||
@require_permission('view_query')
|
@require_permission('view_query')
|
||||||
def get(self):
|
def get(self):
|
||||||
results = models.Query.all_queries(self.current_user.groups)
|
results = models.Query.all_queries([models.Group.query.get(g_id)
|
||||||
|
for g_id in self.current_user.group_ids])
|
||||||
page = request.args.get('page', 1, type=int)
|
page = request.args.get('page', 1, type=int)
|
||||||
page_size = request.args.get('page_size', 25, type=int)
|
page_size = request.args.get('page_size', 25, type=int)
|
||||||
return paginate(results, page, page_size, lambda q: q.to_dict(with_stats=True, with_last_modified_by=False))
|
return paginate(results, page, page_size, lambda q: q.to_dict(with_stats=True, with_last_modified_by=False))
|
||||||
|
|||||||
@@ -27,8 +27,8 @@ def calculate_metrics(response):
|
|||||||
response.content_type,
|
response.content_type,
|
||||||
response.content_length,
|
response.content_length,
|
||||||
request_duration,
|
request_duration,
|
||||||
db.database.query_count,
|
# XXX instrument SQLA for metrics
|
||||||
db.database.query_duration)
|
None, None)
|
||||||
|
|
||||||
statsd_client.timing('requests.{}.{}'.format(request.endpoint, request.method.lower()), request_duration)
|
statsd_client.timing('requests.{}.{}'.format(request.endpoint, request.method.lower()), request_duration)
|
||||||
|
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ class ConflictDetectedError(Exception):
|
|||||||
class BelongsToOrgMixin(object):
|
class BelongsToOrgMixin(object):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_by_id_and_org(cls, object_id, org):
|
def get_by_id_and_org(cls, object_id, org):
|
||||||
return cls.query.filter(cls.id == object_id, cls.org == org).one_or_none()
|
return db.session.query(cls).filter(cls.id == object_id, cls.org == org).one_or_none()
|
||||||
|
|
||||||
|
|
||||||
class PermissionsCheckMixin(object):
|
class PermissionsCheckMixin(object):
|
||||||
@@ -265,7 +265,7 @@ def create_group_hack(*a, **kw):
|
|||||||
class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin):
|
class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin):
|
||||||
id = Column(db.Integer, primary_key=True)
|
id = Column(db.Integer, primary_key=True)
|
||||||
org_id = Column(db.Integer, db.ForeignKey('organizations.id'))
|
org_id = Column(db.Integer, db.ForeignKey('organizations.id'))
|
||||||
org = db.relationship(Organization, backref="users")
|
org = db.relationship(Organization, backref=db.backref("users", lazy="dynamic"))
|
||||||
name = Column(db.String(320))
|
name = Column(db.String(320))
|
||||||
email = Column(db.String(320))
|
email = Column(db.String(320))
|
||||||
password_hash = Column(db.String(128), nullable=True)
|
password_hash = Column(db.String(128), nullable=True)
|
||||||
@@ -287,7 +287,7 @@ class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCh
|
|||||||
'name': self.name,
|
'name': self.name,
|
||||||
'email': self.email,
|
'email': self.email,
|
||||||
'gravatar_url': self.gravatar_url,
|
'gravatar_url': self.gravatar_url,
|
||||||
'groups': self.groups,
|
'groups': self.group_ids,
|
||||||
'updated_at': self.updated_at,
|
'updated_at': self.updated_at,
|
||||||
'created_at': self.created_at
|
'created_at': self.created_at
|
||||||
}
|
}
|
||||||
@@ -311,15 +311,15 @@ class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCh
|
|||||||
def permissions(self):
|
def permissions(self):
|
||||||
# TODO: this should be cached.
|
# TODO: this should be cached.
|
||||||
return list(itertools.chain(*[g.permissions for g in
|
return list(itertools.chain(*[g.permissions for g in
|
||||||
Group.select().where(Group.id << self.groups)]))
|
Group.query.filter(Group.id.in_(self.group_ids))]))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_by_email_and_org(cls, email, org):
|
def get_by_email_and_org(cls, email, org):
|
||||||
return cls.get(cls.email == email, cls.org == org)
|
return cls.query.filter(cls.email == email, cls.org == org).one()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_by_api_key_and_org(cls, api_key, org):
|
def get_by_api_key_and_org(cls, api_key, org):
|
||||||
return cls.get(cls.api_key == api_key, cls.org == org)
|
return cls.query.filter(cls.api_key == api_key, cls.org == org).one()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def all(cls, org):
|
def all(cls, org):
|
||||||
@@ -566,7 +566,7 @@ class QueryResult(db.Model, BelongsToOrgMixin):
|
|||||||
for q in queries:
|
for q in queries:
|
||||||
q.latest_query_data = query_result
|
q.latest_query_data = query_result
|
||||||
db.session.add(q)
|
db.session.add(q)
|
||||||
query_ids = [q.id for q in queries]
|
query_ids = [q.id for q in queries]
|
||||||
logging.info("Updated %s queries with result (%s).", len(query_ids), query_hash)
|
logging.info("Updated %s queries with result (%s).", len(query_ids), query_hash)
|
||||||
|
|
||||||
return query_result, query_ids
|
return query_result, query_ids
|
||||||
@@ -618,7 +618,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
|
|||||||
latest_query_data = db.relationship(QueryResult)
|
latest_query_data = db.relationship(QueryResult)
|
||||||
name = Column(db.String(255))
|
name = Column(db.String(255))
|
||||||
description = Column(db.String(4096), nullable=True)
|
description = Column(db.String(4096), nullable=True)
|
||||||
query = Column(db.Text)
|
query_text = Column("query", db.Text)
|
||||||
query_hash = Column(db.String(32))
|
query_hash = Column(db.String(32))
|
||||||
api_key = Column(db.String(40), default=generate_query_api_key)
|
api_key = Column(db.String(40), default=generate_query_api_key)
|
||||||
user_id = Column(db.Integer, db.ForeignKey("users.id"))
|
user_id = Column(db.Integer, db.ForeignKey("users.id"))
|
||||||
@@ -639,10 +639,10 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
|
|||||||
def to_dict(self, with_stats=False, with_visualizations=False, with_user=True, with_last_modified_by=True):
|
def to_dict(self, with_stats=False, with_visualizations=False, with_user=True, with_last_modified_by=True):
|
||||||
d = {
|
d = {
|
||||||
'id': self.id,
|
'id': self.id,
|
||||||
'latest_query_data_id': self._data.get('latest_query_data', None),
|
'latest_query_data_id': self.latest_query_data,
|
||||||
'name': self.name,
|
'name': self.name,
|
||||||
'description': self.description,
|
'description': self.description,
|
||||||
'query': self.query,
|
'query': self.query_text,
|
||||||
'query_hash': self.query_hash,
|
'query_hash': self.query_hash,
|
||||||
'schedule': self.schedule,
|
'schedule': self.schedule,
|
||||||
'api_key': self.api_key,
|
'api_key': self.api_key,
|
||||||
@@ -666,8 +666,12 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
|
|||||||
d['last_modified_by_id'] = self.last_modified_by_id
|
d['last_modified_by_id'] = self.last_modified_by_id
|
||||||
|
|
||||||
if with_stats:
|
if with_stats:
|
||||||
d['retrieved_at'] = self.retrieved_at
|
if self.latest_query_data is not None:
|
||||||
d['runtime'] = self.runtime
|
d['retrieved_at'] = self.retrieved_at
|
||||||
|
d['runtime'] = self.runtime
|
||||||
|
else:
|
||||||
|
d['retrieved_at'] = None
|
||||||
|
d['runtime'] = None
|
||||||
|
|
||||||
if with_visualizations:
|
if with_visualizations:
|
||||||
d['visualizations'] = [vis.to_dict(with_query=False)
|
d['visualizations'] = [vis.to_dict(with_query=False)
|
||||||
@@ -692,9 +696,8 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def all_queries(cls, groups, drafts=False):
|
def all_queries(cls, groups, drafts=False):
|
||||||
q = (db.session.query(Query)
|
q = (cls.query.join(User, Query.user_id == User.id)
|
||||||
.outerjoin(QueryResult)
|
.outerjoin(QueryResult)
|
||||||
.join(User, Query.user_id == User.id)
|
|
||||||
.join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id)
|
.join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id)
|
||||||
.filter(Query.is_archived == False)
|
.filter(Query.is_archived == False)
|
||||||
.filter(DataSourceGroup.group_id.in_([g.id for g in groups]))\
|
.filter(DataSourceGroup.group_id.in_([g.id for g in groups]))\
|
||||||
@@ -714,7 +717,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def outdated_queries(cls):
|
def outdated_queries(cls):
|
||||||
queries = (db.session.query(Query)
|
queries = (cls.query(Query)
|
||||||
.join(QueryResult)
|
.join(QueryResult)
|
||||||
.join(DataSource)
|
.join(DataSource)
|
||||||
.filter(Query.schedule != None))
|
.filter(Query.schedule != None))
|
||||||
@@ -740,7 +743,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
|
|||||||
where &= Query.is_archived == False
|
where &= Query.is_archived == False
|
||||||
where &= DataSourceGroup.group_id.in_([g.id for g in groups])
|
where &= DataSourceGroup.group_id.in_([g.id for g in groups])
|
||||||
query_ids = (
|
query_ids = (
|
||||||
db.session.query(Query.id).join(
|
cls.query(Query.id).join(
|
||||||
DataSourceGroup,
|
DataSourceGroup,
|
||||||
Query.data_source_id == DataSourceGroup.data_source_id)
|
Query.data_source_id == DataSourceGroup.data_source_id)
|
||||||
.filter(where)).distinct()
|
.filter(where)).distinct()
|
||||||
@@ -750,7 +753,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def recent(cls, groups, user_id=None, limit=20):
|
def recent(cls, groups, user_id=None, limit=20):
|
||||||
query = (db.session.query(Query).join(User, Query.user_id == User.id)
|
query = (cls.query(Query).join(User, Query.user_id == User.id)
|
||||||
.filter(Event.created_at > (db.func.current_date() - 7))
|
.filter(Event.created_at > (db.func.current_date() - 7))
|
||||||
.join(Event, Query.id == Event.object_id.cast(db.Integer))
|
.join(Event, Query.id == Event.object_id.cast(db.Integer))
|
||||||
.join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id)
|
.join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id)
|
||||||
@@ -852,7 +855,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
|
|||||||
def __unicode__(self):
|
def __unicode__(self):
|
||||||
return unicode(self.id)
|
return unicode(self.id)
|
||||||
|
|
||||||
@listens_for(Query.query, 'set')
|
@listens_for(Query.query_text, 'set')
|
||||||
def gen_query_hash(target, val, oldval, initiator):
|
def gen_query_hash(target, val, oldval, initiator):
|
||||||
target.query_hash = utils.gen_query_hash(val)
|
target.query_hash = utils.gen_query_hash(val)
|
||||||
|
|
||||||
@@ -1051,7 +1054,7 @@ class Alert(TimestampMixin, db.Model):
|
|||||||
def generate_slug(ctx):
|
def generate_slug(ctx):
|
||||||
slug = utils.slugify(ctx.current_parameters['name'])
|
slug = utils.slugify(ctx.current_parameters['name'])
|
||||||
tries = 1
|
tries = 1
|
||||||
while db.session.query(Dashboard).filter(Dashboard.slug == slug).first() is not None:
|
while Dashboard.query.filter(Dashboard.slug == slug).first() is not None:
|
||||||
slug = utils.slugify(ctx.current_parameters['name']) + "_" + str(tries)
|
slug = utils.slugify(ctx.current_parameters['name']) + "_" + str(tries)
|
||||||
tries += 1
|
tries += 1
|
||||||
return slug
|
return slug
|
||||||
@@ -1134,7 +1137,7 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model
|
|||||||
@classmethod
|
@classmethod
|
||||||
def all(cls, org, group_ids, user_id):
|
def all(cls, org, group_ids, user_id):
|
||||||
query = (
|
query = (
|
||||||
db.session.query(Dashboard)
|
Dashboard.query
|
||||||
.outerjoin(Widget)
|
.outerjoin(Widget)
|
||||||
.outerjoin(Visualization)
|
.outerjoin(Visualization)
|
||||||
.outerjoin(Query)
|
.outerjoin(Query)
|
||||||
@@ -1151,7 +1154,7 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def recent(cls, org, group_ids, user_id, for_user=False, limit=20):
|
def recent(cls, org, group_ids, user_id, for_user=False, limit=20):
|
||||||
query = (db.session.query(Dashboard)
|
query = (Dashboard.query
|
||||||
.outerjoin(Event, Dashboard.id == Event.object_id.cast(db.Integer))
|
.outerjoin(Event, Dashboard.id == Event.object_id.cast(db.Integer))
|
||||||
.outerjoin(Widget)
|
.outerjoin(Widget)
|
||||||
.outerjoin(Visualization)
|
.outerjoin(Visualization)
|
||||||
@@ -1331,7 +1334,7 @@ class ApiKey(TimestampMixin, GFKBase, db.Model):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_by_api_key(cls, api_key):
|
def get_by_api_key(cls, api_key):
|
||||||
return cls.get(cls.api_key==api_key, cls.active==True)
|
return cls.query.filter(cls.api_key==api_key, cls.active==True).one()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_by_object(cls, object):
|
def get_by_object(cls, object):
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ api_key_factory = ModelFactory(redash.models.ApiKey,
|
|||||||
query_factory = ModelFactory(redash.models.Query,
|
query_factory = ModelFactory(redash.models.Query,
|
||||||
name='Query',
|
name='Query',
|
||||||
description='',
|
description='',
|
||||||
query='SELECT 1',
|
query_text='SELECT 1',
|
||||||
user=user_factory.create,
|
user=user_factory.create,
|
||||||
is_archived=False,
|
is_archived=False,
|
||||||
is_draft=False,
|
is_draft=False,
|
||||||
@@ -75,7 +75,7 @@ query_factory = ModelFactory(redash.models.Query,
|
|||||||
query_with_params_factory = ModelFactory(redash.models.Query,
|
query_with_params_factory = ModelFactory(redash.models.Query,
|
||||||
name='New Query with Params',
|
name='New Query with Params',
|
||||||
description='',
|
description='',
|
||||||
query='SELECT {{param1}}',
|
query_text='SELECT {{param1}}',
|
||||||
user=user_factory.create,
|
user=user_factory.create,
|
||||||
is_archived=False,
|
is_archived=False,
|
||||||
is_draft=False,
|
is_draft=False,
|
||||||
@@ -100,14 +100,14 @@ query_result_factory = ModelFactory(redash.models.QueryResult,
|
|||||||
data='{"columns":{}, "rows":[]}',
|
data='{"columns":{}, "rows":[]}',
|
||||||
runtime=1,
|
runtime=1,
|
||||||
retrieved_at=utcnow,
|
retrieved_at=utcnow,
|
||||||
query="SELECT 1",
|
query_text="SELECT 1",
|
||||||
query_hash=gen_query_hash('SELECT 1'),
|
query_hash=gen_query_hash('SELECT 1'),
|
||||||
data_source=data_source_factory.create,
|
data_source=data_source_factory.create,
|
||||||
org_id=1)
|
org_id=1)
|
||||||
|
|
||||||
visualization_factory = ModelFactory(redash.models.Visualization,
|
visualization_factory = ModelFactory(redash.models.Visualization,
|
||||||
type='CHART',
|
type='CHART',
|
||||||
query=query_factory.create,
|
query_text=query_factory.create,
|
||||||
name='Chart',
|
name='Chart',
|
||||||
description='',
|
description='',
|
||||||
options='{}')
|
options='{}')
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import json
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
from tests.factories import user_factory
|
from tests.factories import user_factory
|
||||||
|
from redash.models import db
|
||||||
from redash.utils import json_dumps
|
from redash.utils import json_dumps
|
||||||
from redash.wsgi import app
|
from redash.wsgi import app
|
||||||
|
|
||||||
@@ -10,6 +11,8 @@ app.config['TESTING'] = True
|
|||||||
|
|
||||||
def authenticate_request(c, user):
|
def authenticate_request(c, user):
|
||||||
with c.session_transaction() as sess:
|
with c.session_transaction() as sess:
|
||||||
|
if user.id is None:
|
||||||
|
db.session.flush()
|
||||||
sess['user_id'] = user.id
|
sess['user_id'] = user.id
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,8 +16,9 @@ class TestApiKeyAuthentication(BaseTestCase):
|
|||||||
#
|
#
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(TestApiKeyAuthentication, self).setUp()
|
super(TestApiKeyAuthentication, self).setUp()
|
||||||
self.api_key = 10
|
self.api_key = '10'
|
||||||
self.query = self.factory.create_query(api_key=self.api_key)
|
self.query = self.factory.create_query(api_key=self.api_key)
|
||||||
|
models.db.session.flush()
|
||||||
self.query_url = '/{}/api/queries/{}'.format(self.factory.org.slug, self.query.id)
|
self.query_url = '/{}/api/queries/{}'.format(self.factory.org.slug, self.query.id)
|
||||||
self.queries_url = '/{}/api/queries'.format(self.factory.org.slug)
|
self.queries_url = '/{}/api/queries'.format(self.factory.org.slug)
|
||||||
|
|
||||||
@@ -43,6 +44,7 @@ class TestApiKeyAuthentication(BaseTestCase):
|
|||||||
|
|
||||||
def test_user_api_key(self):
|
def test_user_api_key(self):
|
||||||
user = self.factory.create_user(api_key="user_key")
|
user = self.factory.create_user(api_key="user_key")
|
||||||
|
models.db.session.flush()
|
||||||
with app.test_client() as c:
|
with app.test_client() as c:
|
||||||
rv = c.get(self.queries_url, query_string={'api_key': user.api_key})
|
rv = c.get(self.queries_url, query_string={'api_key': user.api_key})
|
||||||
self.assertEqual(user.id, api_key_load_user_from_request(request).id)
|
self.assertEqual(user.id, api_key_load_user_from_request(request).id)
|
||||||
@@ -71,8 +73,9 @@ class TestHMACAuthentication(BaseTestCase):
|
|||||||
#
|
#
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(TestHMACAuthentication, self).setUp()
|
super(TestHMACAuthentication, self).setUp()
|
||||||
self.api_key = 10
|
self.api_key = '10'
|
||||||
self.query = self.factory.create_query(api_key=self.api_key)
|
self.query = self.factory.create_query(api_key=self.api_key)
|
||||||
|
models.db.session.flush()
|
||||||
self.path = '/{}/api/queries/{}'.format(self.query.org.slug, self.query.id)
|
self.path = '/{}/api/queries/{}'.format(self.query.org.slug, self.query.id)
|
||||||
self.expires = time.time() + 1800
|
self.expires = time.time() + 1800
|
||||||
|
|
||||||
@@ -102,10 +105,11 @@ class TestHMACAuthentication(BaseTestCase):
|
|||||||
def test_user_api_key(self):
|
def test_user_api_key(self):
|
||||||
user = self.factory.create_user(api_key="user_key")
|
user = self.factory.create_user(api_key="user_key")
|
||||||
path = '/api/queries/'
|
path = '/api/queries/'
|
||||||
|
models.db.session.flush()
|
||||||
with app.test_client() as c:
|
with app.test_client() as c:
|
||||||
signature = sign(user.api_key, path, self.expires)
|
signature = sign(user.api_key, path, self.expires)
|
||||||
rv = c.get(path, query_string={'signature': signature, 'expires': self.expires, 'user_id': user.id})
|
rv = c.get(path, query_string={'signature': signature, 'expires': self.expires, 'user_id': user.id})
|
||||||
self.assertEqual(user.id, hmac_load_user_from_request(request).id)
|
self.assertEqual(user, hmac_load_user_from_request(request))
|
||||||
|
|
||||||
|
|
||||||
class TestCreateAndLoginUser(BaseTestCase):
|
class TestCreateAndLoginUser(BaseTestCase):
|
||||||
@@ -124,8 +128,8 @@ class TestCreateAndLoginUser(BaseTestCase):
|
|||||||
create_and_login_user(self.factory.org, name, email)
|
create_and_login_user(self.factory.org, name, email)
|
||||||
|
|
||||||
self.assertTrue(login_user_mock.called)
|
self.assertTrue(login_user_mock.called)
|
||||||
user = models.User.get(models.User.email == email)
|
user = models.User.query.filter(models.User.email == email).one()
|
||||||
|
self.assertEqual(user.email, email)
|
||||||
|
|
||||||
class TestVerifyProfile(BaseTestCase):
|
class TestVerifyProfile(BaseTestCase):
|
||||||
def test_no_domain_allowed_for_org(self):
|
def test_no_domain_allowed_for_org(self):
|
||||||
@@ -135,29 +139,24 @@ class TestVerifyProfile(BaseTestCase):
|
|||||||
def test_domain_not_in_org_domains_list(self):
|
def test_domain_not_in_org_domains_list(self):
|
||||||
profile = dict(email='arik@example.com')
|
profile = dict(email='arik@example.com')
|
||||||
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org']
|
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org']
|
||||||
self.factory.org.save()
|
|
||||||
self.assertFalse(verify_profile(self.factory.org, profile))
|
self.assertFalse(verify_profile(self.factory.org, profile))
|
||||||
|
|
||||||
def test_domain_in_org_domains_list(self):
|
def test_domain_in_org_domains_list(self):
|
||||||
profile = dict(email='arik@example.com')
|
profile = dict(email='arik@example.com')
|
||||||
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.com']
|
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.com']
|
||||||
self.factory.org.save()
|
|
||||||
self.assertTrue(verify_profile(self.factory.org, profile))
|
self.assertTrue(verify_profile(self.factory.org, profile))
|
||||||
|
|
||||||
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org', 'example.com']
|
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org', 'example.com']
|
||||||
self.factory.org.save()
|
|
||||||
self.assertTrue(verify_profile(self.factory.org, profile))
|
self.assertTrue(verify_profile(self.factory.org, profile))
|
||||||
|
|
||||||
def test_org_in_public_mode_accepts_any_domain(self):
|
def test_org_in_public_mode_accepts_any_domain(self):
|
||||||
profile = dict(email='arik@example.com')
|
profile = dict(email='arik@example.com')
|
||||||
self.factory.org.settings[models.Organization.SETTING_IS_PUBLIC] = True
|
self.factory.org.settings[models.Organization.SETTING_IS_PUBLIC] = True
|
||||||
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = []
|
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = []
|
||||||
self.factory.org.save()
|
|
||||||
self.assertTrue(verify_profile(self.factory.org, profile))
|
self.assertTrue(verify_profile(self.factory.org, profile))
|
||||||
|
|
||||||
def test_user_not_in_domain_but_account_exists(self):
|
def test_user_not_in_domain_but_account_exists(self):
|
||||||
profile = dict(email='arik@example.com')
|
profile = dict(email='arik@example.com')
|
||||||
self.factory.create_user(email='arik@example.com')
|
self.factory.create_user(email='arik@example.com')
|
||||||
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org']
|
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org']
|
||||||
self.factory.org.save()
|
|
||||||
self.assertTrue(verify_profile(self.factory.org, profile))
|
self.assertTrue(verify_profile(self.factory.org, profile))
|
||||||
|
|||||||
Reference in New Issue
Block a user