auth tests wip

This commit is contained in:
Allen Short
2016-11-28 08:41:29 -06:00
parent ea166665d3
commit f00d77dec4
9 changed files with 62 additions and 51 deletions

View File

@@ -5,6 +5,7 @@ import time
import logging import logging
from flask import redirect, request, jsonify, url_for from flask import redirect, request, jsonify, url_for
from sqlalchemy.orm.exc import NoResultFound
from redash import models, settings from redash import models, settings
from redash.authentication.org_resolving import current_org from redash.authentication.org_resolving import current_org
@@ -36,9 +37,10 @@ def sign(key, path, expires):
@login_manager.user_loader @login_manager.user_loader
def load_user(user_id): def load_user(user_id):
org = current_org._get_current_object()
try: try:
return models.User.get_by_id_and_org(user_id, current_org.id) return models.User.get_by_id_and_org(user_id, org)
except models.User.DoesNotExist: except NoResultFound:
return None return None
@@ -51,14 +53,14 @@ def hmac_load_user_from_request(request):
# TODO: 3600 should be a setting # TODO: 3600 should be a setting
if signature and time.time() < expires <= time.time() + 3600: if signature and time.time() < expires <= time.time() + 3600:
if user_id: if user_id:
user = models.User.get_by_id(user_id) user = models.User.query.get(user_id)
calculated_signature = sign(user.api_key, request.path, expires) calculated_signature = sign(user.api_key, request.path, expires)
if user.api_key and signature == calculated_signature: if user.api_key and signature == calculated_signature:
return user return user
if query_id: if query_id:
query = models.Query.get(models.Query.id == query_id) query = models.db.session.query(models.Query).filter(models.Query.id == query_id).one()
calculated_signature = sign(query.api_key, request.path, expires) calculated_signature = sign(query.api_key, request.path, expires)
if query.api_key and signature == calculated_signature: if query.api_key and signature == calculated_signature:
@@ -74,15 +76,16 @@ def get_user_from_api_key(api_key, query_id):
user = None user = None
# TODO: once we switch all api key storage into the ApiKey model, this code will be much simplified # TODO: once we switch all api key storage into the ApiKey model, this code will be much simplified
org = current_org._get_current_object()
try: try:
user = models.User.get_by_api_key_and_org(api_key, current_org.id) user = models.User.get_by_api_key_and_org(api_key, org)
except models.User.DoesNotExist: except NoResultFound:
try: try:
api_key = models.ApiKey.get_by_api_key(api_key) api_key = models.ApiKey.get_by_api_key(api_key)
user = models.ApiUser(api_key, api_key.org, []) user = models.ApiUser(api_key, api_key.org, [])
except models.ApiKey.DoesNotExist: except NoResultFound:
if query_id: if query_id:
query = models.Query.get_by_id_and_org(query_id, current_org.id) query = models.Query.get_by_id_and_org(query_id, org)
if query and query.api_key == api_key: if query and query.api_key == api_key:
user = models.ApiUser(api_key, query.org, query.groups.keys(), name="ApiKey: Query {}".format(query.id)) user = models.ApiUser(api_key, query.org, query.groups.keys(), name="ApiKey: Query {}".format(query.id))
@@ -105,7 +108,6 @@ def get_api_key_from_request(request):
def api_key_load_user_from_request(request): def api_key_load_user_from_request(request):
api_key = get_api_key_from_request(request) api_key = get_api_key_from_request(request)
query_id = request.view_args.get('query_id', None) query_id = request.view_args.get('query_id', None)
user = get_user_from_api_key(api_key, query_id) user = get_user_from_api_key(api_key, query_id)
return user return user

View File

@@ -3,6 +3,8 @@ import requests
from flask import redirect, url_for, Blueprint, flash, request, session from flask import redirect, url_for, Blueprint, flash, request, session
from flask_login import login_user from flask_login import login_user
from flask_oauthlib.client import OAuth from flask_oauthlib.client import OAuth
from sqlalchemy.orm.exc import NoResultFound
from redash import models, settings from redash import models, settings
from redash.authentication.org_resolving import current_org from redash.authentication.org_resolving import current_org
@@ -63,9 +65,10 @@ def create_and_login_user(org, name, email):
logger.debug("Updating user name (%r -> %r)", user_object.name, name) logger.debug("Updating user name (%r -> %r)", user_object.name, name)
user_object.name = name user_object.name = name
user_object.save() user_object.save()
except models.User.DoesNotExist: except NoResultFound:
logger.debug("Creating user object (%r)", name) logger.debug("Creating user object (%r)", name)
user_object = models.User.create(org=org, name=name, email=email, group_ids=[org.default_group.id]) user_object = models.User(org=org, name=name, email=email, group_ids=[org.default_group.id])
models.db.session.add(user_object)
login_user(user_object, remember=True) login_user(user_object, remember=True)

View File

@@ -93,7 +93,7 @@ def paginate(query_set, page, page_size, serializer):
'count': count, 'count': count,
'page': page, 'page': page,
'page_size': page_size, 'page_size': page_size,
'results': [serializer(result) for result in results], 'results': [serializer(result) for result in results.items],
} }

View File

@@ -76,7 +76,8 @@ class QueryListResource(BaseResource):
@require_permission('view_query') @require_permission('view_query')
def get(self): def get(self):
results = models.Query.all_queries(self.current_user.groups) results = models.Query.all_queries([models.Group.query.get(g_id)
for g_id in self.current_user.group_ids])
page = request.args.get('page', 1, type=int) page = request.args.get('page', 1, type=int)
page_size = request.args.get('page_size', 25, type=int) page_size = request.args.get('page_size', 25, type=int)
return paginate(results, page, page_size, lambda q: q.to_dict(with_stats=True, with_last_modified_by=False)) return paginate(results, page, page_size, lambda q: q.to_dict(with_stats=True, with_last_modified_by=False))

View File

@@ -27,8 +27,8 @@ def calculate_metrics(response):
response.content_type, response.content_type,
response.content_length, response.content_length,
request_duration, request_duration,
db.database.query_count, # XXX instrument SQLA for metrics
db.database.query_duration) None, None)
statsd_client.timing('requests.{}.{}'.format(request.endpoint, request.method.lower()), request_duration) statsd_client.timing('requests.{}.{}'.format(request.endpoint, request.method.lower()), request_duration)

View File

@@ -128,7 +128,7 @@ class ConflictDetectedError(Exception):
class BelongsToOrgMixin(object): class BelongsToOrgMixin(object):
@classmethod @classmethod
def get_by_id_and_org(cls, object_id, org): def get_by_id_and_org(cls, object_id, org):
return cls.query.filter(cls.id == object_id, cls.org == org).one_or_none() return db.session.query(cls).filter(cls.id == object_id, cls.org == org).one_or_none()
class PermissionsCheckMixin(object): class PermissionsCheckMixin(object):
@@ -265,7 +265,7 @@ def create_group_hack(*a, **kw):
class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin): class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin):
id = Column(db.Integer, primary_key=True) id = Column(db.Integer, primary_key=True)
org_id = Column(db.Integer, db.ForeignKey('organizations.id')) org_id = Column(db.Integer, db.ForeignKey('organizations.id'))
org = db.relationship(Organization, backref="users") org = db.relationship(Organization, backref=db.backref("users", lazy="dynamic"))
name = Column(db.String(320)) name = Column(db.String(320))
email = Column(db.String(320)) email = Column(db.String(320))
password_hash = Column(db.String(128), nullable=True) password_hash = Column(db.String(128), nullable=True)
@@ -287,7 +287,7 @@ class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCh
'name': self.name, 'name': self.name,
'email': self.email, 'email': self.email,
'gravatar_url': self.gravatar_url, 'gravatar_url': self.gravatar_url,
'groups': self.groups, 'groups': self.group_ids,
'updated_at': self.updated_at, 'updated_at': self.updated_at,
'created_at': self.created_at 'created_at': self.created_at
} }
@@ -311,15 +311,15 @@ class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCh
def permissions(self): def permissions(self):
# TODO: this should be cached. # TODO: this should be cached.
return list(itertools.chain(*[g.permissions for g in return list(itertools.chain(*[g.permissions for g in
Group.select().where(Group.id << self.groups)])) Group.query.filter(Group.id.in_(self.group_ids))]))
@classmethod @classmethod
def get_by_email_and_org(cls, email, org): def get_by_email_and_org(cls, email, org):
return cls.get(cls.email == email, cls.org == org) return cls.query.filter(cls.email == email, cls.org == org).one()
@classmethod @classmethod
def get_by_api_key_and_org(cls, api_key, org): def get_by_api_key_and_org(cls, api_key, org):
return cls.get(cls.api_key == api_key, cls.org == org) return cls.query.filter(cls.api_key == api_key, cls.org == org).one()
@classmethod @classmethod
def all(cls, org): def all(cls, org):
@@ -566,7 +566,7 @@ class QueryResult(db.Model, BelongsToOrgMixin):
for q in queries: for q in queries:
q.latest_query_data = query_result q.latest_query_data = query_result
db.session.add(q) db.session.add(q)
query_ids = [q.id for q in queries] query_ids = [q.id for q in queries]
logging.info("Updated %s queries with result (%s).", len(query_ids), query_hash) logging.info("Updated %s queries with result (%s).", len(query_ids), query_hash)
return query_result, query_ids return query_result, query_ids
@@ -618,7 +618,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
latest_query_data = db.relationship(QueryResult) latest_query_data = db.relationship(QueryResult)
name = Column(db.String(255)) name = Column(db.String(255))
description = Column(db.String(4096), nullable=True) description = Column(db.String(4096), nullable=True)
query = Column(db.Text) query_text = Column("query", db.Text)
query_hash = Column(db.String(32)) query_hash = Column(db.String(32))
api_key = Column(db.String(40), default=generate_query_api_key) api_key = Column(db.String(40), default=generate_query_api_key)
user_id = Column(db.Integer, db.ForeignKey("users.id")) user_id = Column(db.Integer, db.ForeignKey("users.id"))
@@ -639,10 +639,10 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
def to_dict(self, with_stats=False, with_visualizations=False, with_user=True, with_last_modified_by=True): def to_dict(self, with_stats=False, with_visualizations=False, with_user=True, with_last_modified_by=True):
d = { d = {
'id': self.id, 'id': self.id,
'latest_query_data_id': self._data.get('latest_query_data', None), 'latest_query_data_id': self.latest_query_data,
'name': self.name, 'name': self.name,
'description': self.description, 'description': self.description,
'query': self.query, 'query': self.query_text,
'query_hash': self.query_hash, 'query_hash': self.query_hash,
'schedule': self.schedule, 'schedule': self.schedule,
'api_key': self.api_key, 'api_key': self.api_key,
@@ -666,8 +666,12 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
d['last_modified_by_id'] = self.last_modified_by_id d['last_modified_by_id'] = self.last_modified_by_id
if with_stats: if with_stats:
d['retrieved_at'] = self.retrieved_at if self.latest_query_data is not None:
d['runtime'] = self.runtime d['retrieved_at'] = self.retrieved_at
d['runtime'] = self.runtime
else:
d['retrieved_at'] = None
d['runtime'] = None
if with_visualizations: if with_visualizations:
d['visualizations'] = [vis.to_dict(with_query=False) d['visualizations'] = [vis.to_dict(with_query=False)
@@ -692,9 +696,8 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
@classmethod @classmethod
def all_queries(cls, groups, drafts=False): def all_queries(cls, groups, drafts=False):
q = (db.session.query(Query) q = (cls.query.join(User, Query.user_id == User.id)
.outerjoin(QueryResult) .outerjoin(QueryResult)
.join(User, Query.user_id == User.id)
.join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) .join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id)
.filter(Query.is_archived == False) .filter(Query.is_archived == False)
.filter(DataSourceGroup.group_id.in_([g.id for g in groups]))\ .filter(DataSourceGroup.group_id.in_([g.id for g in groups]))\
@@ -714,7 +717,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
@classmethod @classmethod
def outdated_queries(cls): def outdated_queries(cls):
queries = (db.session.query(Query) queries = (cls.query(Query)
.join(QueryResult) .join(QueryResult)
.join(DataSource) .join(DataSource)
.filter(Query.schedule != None)) .filter(Query.schedule != None))
@@ -740,7 +743,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
where &= Query.is_archived == False where &= Query.is_archived == False
where &= DataSourceGroup.group_id.in_([g.id for g in groups]) where &= DataSourceGroup.group_id.in_([g.id for g in groups])
query_ids = ( query_ids = (
db.session.query(Query.id).join( cls.query(Query.id).join(
DataSourceGroup, DataSourceGroup,
Query.data_source_id == DataSourceGroup.data_source_id) Query.data_source_id == DataSourceGroup.data_source_id)
.filter(where)).distinct() .filter(where)).distinct()
@@ -750,7 +753,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
@classmethod @classmethod
def recent(cls, groups, user_id=None, limit=20): def recent(cls, groups, user_id=None, limit=20):
query = (db.session.query(Query).join(User, Query.user_id == User.id) query = (cls.query(Query).join(User, Query.user_id == User.id)
.filter(Event.created_at > (db.func.current_date() - 7)) .filter(Event.created_at > (db.func.current_date() - 7))
.join(Event, Query.id == Event.object_id.cast(db.Integer)) .join(Event, Query.id == Event.object_id.cast(db.Integer))
.join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) .join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id)
@@ -852,7 +855,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model):
def __unicode__(self): def __unicode__(self):
return unicode(self.id) return unicode(self.id)
@listens_for(Query.query, 'set') @listens_for(Query.query_text, 'set')
def gen_query_hash(target, val, oldval, initiator): def gen_query_hash(target, val, oldval, initiator):
target.query_hash = utils.gen_query_hash(val) target.query_hash = utils.gen_query_hash(val)
@@ -1051,7 +1054,7 @@ class Alert(TimestampMixin, db.Model):
def generate_slug(ctx): def generate_slug(ctx):
slug = utils.slugify(ctx.current_parameters['name']) slug = utils.slugify(ctx.current_parameters['name'])
tries = 1 tries = 1
while db.session.query(Dashboard).filter(Dashboard.slug == slug).first() is not None: while Dashboard.query.filter(Dashboard.slug == slug).first() is not None:
slug = utils.slugify(ctx.current_parameters['name']) + "_" + str(tries) slug = utils.slugify(ctx.current_parameters['name']) + "_" + str(tries)
tries += 1 tries += 1
return slug return slug
@@ -1134,7 +1137,7 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model
@classmethod @classmethod
def all(cls, org, group_ids, user_id): def all(cls, org, group_ids, user_id):
query = ( query = (
db.session.query(Dashboard) Dashboard.query
.outerjoin(Widget) .outerjoin(Widget)
.outerjoin(Visualization) .outerjoin(Visualization)
.outerjoin(Query) .outerjoin(Query)
@@ -1151,7 +1154,7 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model
@classmethod @classmethod
def recent(cls, org, group_ids, user_id, for_user=False, limit=20): def recent(cls, org, group_ids, user_id, for_user=False, limit=20):
query = (db.session.query(Dashboard) query = (Dashboard.query
.outerjoin(Event, Dashboard.id == Event.object_id.cast(db.Integer)) .outerjoin(Event, Dashboard.id == Event.object_id.cast(db.Integer))
.outerjoin(Widget) .outerjoin(Widget)
.outerjoin(Visualization) .outerjoin(Visualization)
@@ -1331,7 +1334,7 @@ class ApiKey(TimestampMixin, GFKBase, db.Model):
@classmethod @classmethod
def get_by_api_key(cls, api_key): def get_by_api_key(cls, api_key):
return cls.get(cls.api_key==api_key, cls.active==True) return cls.query.filter(cls.api_key==api_key, cls.active==True).one()
@classmethod @classmethod
def get_by_object(cls, object): def get_by_object(cls, object):

View File

@@ -64,7 +64,7 @@ api_key_factory = ModelFactory(redash.models.ApiKey,
query_factory = ModelFactory(redash.models.Query, query_factory = ModelFactory(redash.models.Query,
name='Query', name='Query',
description='', description='',
query='SELECT 1', query_text='SELECT 1',
user=user_factory.create, user=user_factory.create,
is_archived=False, is_archived=False,
is_draft=False, is_draft=False,
@@ -75,7 +75,7 @@ query_factory = ModelFactory(redash.models.Query,
query_with_params_factory = ModelFactory(redash.models.Query, query_with_params_factory = ModelFactory(redash.models.Query,
name='New Query with Params', name='New Query with Params',
description='', description='',
query='SELECT {{param1}}', query_text='SELECT {{param1}}',
user=user_factory.create, user=user_factory.create,
is_archived=False, is_archived=False,
is_draft=False, is_draft=False,
@@ -100,14 +100,14 @@ query_result_factory = ModelFactory(redash.models.QueryResult,
data='{"columns":{}, "rows":[]}', data='{"columns":{}, "rows":[]}',
runtime=1, runtime=1,
retrieved_at=utcnow, retrieved_at=utcnow,
query="SELECT 1", query_text="SELECT 1",
query_hash=gen_query_hash('SELECT 1'), query_hash=gen_query_hash('SELECT 1'),
data_source=data_source_factory.create, data_source=data_source_factory.create,
org_id=1) org_id=1)
visualization_factory = ModelFactory(redash.models.Visualization, visualization_factory = ModelFactory(redash.models.Visualization,
type='CHART', type='CHART',
query=query_factory.create, query_text=query_factory.create,
name='Chart', name='Chart',
description='', description='',
options='{}') options='{}')

View File

@@ -2,6 +2,7 @@ import json
from contextlib import contextmanager from contextlib import contextmanager
from tests.factories import user_factory from tests.factories import user_factory
from redash.models import db
from redash.utils import json_dumps from redash.utils import json_dumps
from redash.wsgi import app from redash.wsgi import app
@@ -10,6 +11,8 @@ app.config['TESTING'] = True
def authenticate_request(c, user): def authenticate_request(c, user):
with c.session_transaction() as sess: with c.session_transaction() as sess:
if user.id is None:
db.session.flush()
sess['user_id'] = user.id sess['user_id'] = user.id

View File

@@ -16,8 +16,9 @@ class TestApiKeyAuthentication(BaseTestCase):
# #
def setUp(self): def setUp(self):
super(TestApiKeyAuthentication, self).setUp() super(TestApiKeyAuthentication, self).setUp()
self.api_key = 10 self.api_key = '10'
self.query = self.factory.create_query(api_key=self.api_key) self.query = self.factory.create_query(api_key=self.api_key)
models.db.session.flush()
self.query_url = '/{}/api/queries/{}'.format(self.factory.org.slug, self.query.id) self.query_url = '/{}/api/queries/{}'.format(self.factory.org.slug, self.query.id)
self.queries_url = '/{}/api/queries'.format(self.factory.org.slug) self.queries_url = '/{}/api/queries'.format(self.factory.org.slug)
@@ -43,6 +44,7 @@ class TestApiKeyAuthentication(BaseTestCase):
def test_user_api_key(self): def test_user_api_key(self):
user = self.factory.create_user(api_key="user_key") user = self.factory.create_user(api_key="user_key")
models.db.session.flush()
with app.test_client() as c: with app.test_client() as c:
rv = c.get(self.queries_url, query_string={'api_key': user.api_key}) rv = c.get(self.queries_url, query_string={'api_key': user.api_key})
self.assertEqual(user.id, api_key_load_user_from_request(request).id) self.assertEqual(user.id, api_key_load_user_from_request(request).id)
@@ -71,8 +73,9 @@ class TestHMACAuthentication(BaseTestCase):
# #
def setUp(self): def setUp(self):
super(TestHMACAuthentication, self).setUp() super(TestHMACAuthentication, self).setUp()
self.api_key = 10 self.api_key = '10'
self.query = self.factory.create_query(api_key=self.api_key) self.query = self.factory.create_query(api_key=self.api_key)
models.db.session.flush()
self.path = '/{}/api/queries/{}'.format(self.query.org.slug, self.query.id) self.path = '/{}/api/queries/{}'.format(self.query.org.slug, self.query.id)
self.expires = time.time() + 1800 self.expires = time.time() + 1800
@@ -102,10 +105,11 @@ class TestHMACAuthentication(BaseTestCase):
def test_user_api_key(self): def test_user_api_key(self):
user = self.factory.create_user(api_key="user_key") user = self.factory.create_user(api_key="user_key")
path = '/api/queries/' path = '/api/queries/'
models.db.session.flush()
with app.test_client() as c: with app.test_client() as c:
signature = sign(user.api_key, path, self.expires) signature = sign(user.api_key, path, self.expires)
rv = c.get(path, query_string={'signature': signature, 'expires': self.expires, 'user_id': user.id}) rv = c.get(path, query_string={'signature': signature, 'expires': self.expires, 'user_id': user.id})
self.assertEqual(user.id, hmac_load_user_from_request(request).id) self.assertEqual(user, hmac_load_user_from_request(request))
class TestCreateAndLoginUser(BaseTestCase): class TestCreateAndLoginUser(BaseTestCase):
@@ -124,8 +128,8 @@ class TestCreateAndLoginUser(BaseTestCase):
create_and_login_user(self.factory.org, name, email) create_and_login_user(self.factory.org, name, email)
self.assertTrue(login_user_mock.called) self.assertTrue(login_user_mock.called)
user = models.User.get(models.User.email == email) user = models.User.query.filter(models.User.email == email).one()
self.assertEqual(user.email, email)
class TestVerifyProfile(BaseTestCase): class TestVerifyProfile(BaseTestCase):
def test_no_domain_allowed_for_org(self): def test_no_domain_allowed_for_org(self):
@@ -135,29 +139,24 @@ class TestVerifyProfile(BaseTestCase):
def test_domain_not_in_org_domains_list(self): def test_domain_not_in_org_domains_list(self):
profile = dict(email='arik@example.com') profile = dict(email='arik@example.com')
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org'] self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org']
self.factory.org.save()
self.assertFalse(verify_profile(self.factory.org, profile)) self.assertFalse(verify_profile(self.factory.org, profile))
def test_domain_in_org_domains_list(self): def test_domain_in_org_domains_list(self):
profile = dict(email='arik@example.com') profile = dict(email='arik@example.com')
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.com'] self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.com']
self.factory.org.save()
self.assertTrue(verify_profile(self.factory.org, profile)) self.assertTrue(verify_profile(self.factory.org, profile))
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org', 'example.com'] self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org', 'example.com']
self.factory.org.save()
self.assertTrue(verify_profile(self.factory.org, profile)) self.assertTrue(verify_profile(self.factory.org, profile))
def test_org_in_public_mode_accepts_any_domain(self): def test_org_in_public_mode_accepts_any_domain(self):
profile = dict(email='arik@example.com') profile = dict(email='arik@example.com')
self.factory.org.settings[models.Organization.SETTING_IS_PUBLIC] = True self.factory.org.settings[models.Organization.SETTING_IS_PUBLIC] = True
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = [] self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = []
self.factory.org.save()
self.assertTrue(verify_profile(self.factory.org, profile)) self.assertTrue(verify_profile(self.factory.org, profile))
def test_user_not_in_domain_but_account_exists(self): def test_user_not_in_domain_but_account_exists(self):
profile = dict(email='arik@example.com') profile = dict(email='arik@example.com')
self.factory.create_user(email='arik@example.com') self.factory.create_user(email='arik@example.com')
self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org'] self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org']
self.factory.org.save()
self.assertTrue(verify_profile(self.factory.org, profile)) self.assertTrue(verify_profile(self.factory.org, profile))