mirror of
https://github.com/langgenius/dify.git
synced 2026-04-04 12:00:30 -04:00
refactor: select in 13 small service files (#34371)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -132,7 +132,7 @@ class AudioService:
|
||||
uuid.UUID(message_id)
|
||||
except ValueError:
|
||||
return None
|
||||
message = db.session.query(Message).where(Message.id == message_id).first()
|
||||
message = db.session.get(Message, message_id)
|
||||
if message is None:
|
||||
return None
|
||||
if message.answer == "" and message.status in {MessageStatus.NORMAL, MessageStatus.PAUSED}:
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
|
||||
from typing_extensions import TypedDict
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
@@ -158,10 +159,10 @@ class BillingService:
|
||||
def is_tenant_owner_or_admin(current_user: Account):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
join: TenantAccountJoin | None = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
join: TenantAccountJoin | None = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not join:
|
||||
|
||||
@@ -137,11 +137,11 @@ class ConversationService:
|
||||
@classmethod
|
||||
def auto_generate_name(cls, app_model: App, conversation: Conversation):
|
||||
# get conversation first message
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
message = db.session.scalar(
|
||||
select(Message)
|
||||
.where(Message.app_id == app_model.id, Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not message:
|
||||
@@ -160,8 +160,8 @@ class ConversationService:
|
||||
|
||||
@classmethod
|
||||
def get_conversation(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
conversation = db.session.scalar(
|
||||
select(Conversation)
|
||||
.where(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
@@ -170,7 +170,7 @@ class ConversationService:
|
||||
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
Conversation.is_deleted == False,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
@@ -29,13 +29,13 @@ class CreditPoolService:
|
||||
@classmethod
|
||||
def get_pool(cls, tenant_id: str, pool_type: str = "trial") -> TenantCreditPool | None:
|
||||
"""get tenant credit pool"""
|
||||
return (
|
||||
db.session.query(TenantCreditPool)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
pool_type=pool_type,
|
||||
return db.session.scalar(
|
||||
select(TenantCreditPool)
|
||||
.where(
|
||||
TenantCreditPool.tenant_id == tenant_id,
|
||||
TenantCreditPool.pool_type == pool_type,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -4,6 +4,7 @@ import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from redis import RedisError
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
@@ -104,7 +105,9 @@ def sync_account_deletion(account_id: str, *, source: str) -> bool:
|
||||
return True
|
||||
|
||||
# Fetch all workspaces the account belongs to
|
||||
workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all()
|
||||
workspace_joins = db.session.scalars(
|
||||
select(TenantAccountJoin).where(TenantAccountJoin.account_id == account_id)
|
||||
).all()
|
||||
|
||||
# Queue sync task for each workspace
|
||||
success = True
|
||||
|
||||
@@ -110,7 +110,7 @@ class PipelineGenerateService:
|
||||
Update document status to waiting
|
||||
:param document_id: document id
|
||||
"""
|
||||
document = db.session.query(Document).where(Document.id == document_id).first()
|
||||
document = db.session.get(Document, document_id)
|
||||
if document:
|
||||
document.indexing_status = IndexingStatus.WAITING
|
||||
db.session.add(document)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import yaml
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
@@ -32,12 +33,11 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
pipeline_customized_templates = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
pipeline_customized_templates = db.session.scalars(
|
||||
select(PipelineCustomizedTemplate)
|
||||
.where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
|
||||
.order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
).all()
|
||||
recommended_pipelines_results = []
|
||||
for pipeline_customized_template in pipeline_customized_templates:
|
||||
recommended_pipeline_result = {
|
||||
@@ -59,9 +59,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
:param template_id: Template ID
|
||||
:return:
|
||||
"""
|
||||
pipeline_template = (
|
||||
db.session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
|
||||
)
|
||||
pipeline_template = db.session.get(PipelineCustomizedTemplate, template_id)
|
||||
if not pipeline_template:
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import yaml
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import PipelineBuiltInTemplate
|
||||
@@ -30,8 +31,10 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
:return:
|
||||
"""
|
||||
|
||||
pipeline_built_in_templates: list[PipelineBuiltInTemplate] = (
|
||||
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language).all()
|
||||
pipeline_built_in_templates = list(
|
||||
db.session.scalars(
|
||||
select(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language)
|
||||
).all()
|
||||
)
|
||||
|
||||
recommended_pipelines_results = []
|
||||
@@ -58,9 +61,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
||||
:return:
|
||||
"""
|
||||
# is in public recommended list
|
||||
pipeline_template = (
|
||||
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.id == template_id).first()
|
||||
)
|
||||
pipeline_template = db.session.get(PipelineBuiltInTemplate, template_id)
|
||||
|
||||
if not pipeline_template:
|
||||
return None
|
||||
|
||||
@@ -77,17 +77,15 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
:return:
|
||||
"""
|
||||
# is in public recommended list
|
||||
recommended_app = (
|
||||
db.session.query(RecommendedApp)
|
||||
.where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
|
||||
.first()
|
||||
recommended_app = db.session.scalar(
|
||||
select(RecommendedApp).where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id).limit(1)
|
||||
)
|
||||
|
||||
if not recommended_app:
|
||||
return None
|
||||
|
||||
# get app detail
|
||||
app_model = db.session.query(App).where(App.id == app_id).first()
|
||||
app_model = db.session.get(App, app_id)
|
||||
if not app_model or not app_model.is_public:
|
||||
return None
|
||||
|
||||
|
||||
@@ -64,15 +64,15 @@ class WebConversationService:
|
||||
def pin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
|
||||
if not user:
|
||||
return
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
pinned_conversation = db.session.scalar(
|
||||
select(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.conversation_id == conversation_id,
|
||||
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
PinnedConversation.created_by == user.id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if pinned_conversation:
|
||||
@@ -96,15 +96,15 @@ class WebConversationService:
|
||||
def unpin(cls, app_model: App, conversation_id: str, user: Union[Account, EndUser] | None):
|
||||
if not user:
|
||||
return
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
pinned_conversation = db.session.scalar(
|
||||
select(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.conversation_id == conversation_id,
|
||||
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
PinnedConversation.created_by == user.id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not pinned_conversation:
|
||||
|
||||
@@ -3,6 +3,7 @@ import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
@@ -92,10 +93,10 @@ class WebAppAuthService:
|
||||
|
||||
@classmethod
|
||||
def create_end_user(cls, app_code, email) -> EndUser:
|
||||
site = db.session.query(Site).where(Site.code == app_code).first()
|
||||
site = db.session.scalar(select(Site).where(Site.code == app_code).limit(1))
|
||||
if not site:
|
||||
raise NotFound("Site not found.")
|
||||
app_model = db.session.query(App).where(App.id == site.app_id).first()
|
||||
app_model = db.session.get(App, site.app_id)
|
||||
if not app_model:
|
||||
raise NotFound("App not found.")
|
||||
end_user = EndUser(
|
||||
|
||||
@@ -6,6 +6,7 @@ from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.variables.input_entities import VariableEntity
|
||||
from sqlalchemy import select
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
@@ -648,10 +649,10 @@ class WorkflowConverter:
|
||||
:param api_based_extension_id: api based extension id
|
||||
:return:
|
||||
"""
|
||||
api_based_extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
api_based_extension = db.session.scalar(
|
||||
select(APIBasedExtension)
|
||||
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not api_based_extension:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from enums.cloud_plan import CloudPlan
|
||||
@@ -24,10 +25,10 @@ class WorkspaceService:
|
||||
}
|
||||
|
||||
# Get role of user
|
||||
tenant_account_join = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
tenant_account_join = db.session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
assert tenant_account_join is not None, "TenantAccountJoin not found"
|
||||
tenant_info["role"] = tenant_account_join.role
|
||||
|
||||
@@ -421,11 +421,8 @@ class TestAudioServiceTTS:
|
||||
answer="Message answer text",
|
||||
)
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = message
|
||||
# Mock database lookup
|
||||
mock_db_session.get.return_value = message
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
@@ -568,11 +565,8 @@ class TestAudioServiceTTS:
|
||||
# Arrange
|
||||
app = factory.create_app_mock()
|
||||
|
||||
# Mock database query returning None
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = None
|
||||
# Mock database lookup returning None
|
||||
mock_db_session.get.return_value = None
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
@@ -594,11 +588,8 @@ class TestAudioServiceTTS:
|
||||
status=MessageStatus.NORMAL,
|
||||
)
|
||||
|
||||
# Mock database query
|
||||
mock_query = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
mock_query.first.return_value = message
|
||||
# Mock database lookup
|
||||
mock_db_session.get.return_value = message
|
||||
|
||||
# Act
|
||||
result = AudioService.transcript_tts(
|
||||
|
||||
@@ -865,16 +865,11 @@ class TestBillingServiceAccountManagement:
|
||||
mock_join = MagicMock(spec=TenantAccountJoin)
|
||||
mock_join.role = TenantAccountRole.OWNER
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_join
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.scalar.return_value = mock_join
|
||||
|
||||
# Act - should not raise exception
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
|
||||
# Assert
|
||||
mock_db_session.query.assert_called_once()
|
||||
|
||||
def test_is_tenant_owner_or_admin_admin(self, mock_db_session):
|
||||
"""Test tenant owner/admin check for admin role."""
|
||||
# Arrange
|
||||
@@ -885,16 +880,11 @@ class TestBillingServiceAccountManagement:
|
||||
mock_join = MagicMock(spec=TenantAccountJoin)
|
||||
mock_join.role = TenantAccountRole.ADMIN
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_join
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.scalar.return_value = mock_join
|
||||
|
||||
# Act - should not raise exception
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
|
||||
# Assert
|
||||
mock_db_session.query.assert_called_once()
|
||||
|
||||
def test_is_tenant_owner_or_admin_normal_user_raises_error(self, mock_db_session):
|
||||
"""Test tenant owner/admin check raises error for normal user."""
|
||||
# Arrange
|
||||
@@ -905,9 +895,7 @@ class TestBillingServiceAccountManagement:
|
||||
mock_join = MagicMock(spec=TenantAccountJoin)
|
||||
mock_join.role = TenantAccountRole.NORMAL
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_join
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.scalar.return_value = mock_join
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@@ -921,9 +909,7 @@ class TestBillingServiceAccountManagement:
|
||||
current_user.id = "account-123"
|
||||
current_user.current_tenant_id = "tenant-456"
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@@ -1135,9 +1121,7 @@ class TestBillingServiceEdgeCases:
|
||||
mock_join.role = TenantAccountRole.EDITOR # Editor is not privileged
|
||||
|
||||
with patch("services.billing_service.db.session") as mock_session:
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_join
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_session.scalar.return_value = mock_join
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
@@ -1155,9 +1139,7 @@ class TestBillingServiceEdgeCases:
|
||||
mock_join.role = TenantAccountRole.DATASET_OPERATOR # Dataset operator is not privileged
|
||||
|
||||
with patch("services.billing_service.db.session") as mock_session:
|
||||
mock_query = MagicMock()
|
||||
mock_query.where.return_value.first.return_value = mock_join
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_session.scalar.return_value = mock_join
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
|
||||
@@ -355,15 +355,13 @@ class TestConversationServiceGetConversation:
|
||||
from_account_id=user.id, from_source=ConversationFromSource.CONSOLE
|
||||
)
|
||||
|
||||
mock_query = mock_db_session.query.return_value
|
||||
mock_query.where.return_value.first.return_value = conversation
|
||||
mock_db_session.scalar.return_value = conversation
|
||||
|
||||
# Act
|
||||
result = ConversationService.get_conversation(app_model, "conv-123", user)
|
||||
|
||||
# Assert
|
||||
assert result == conversation
|
||||
mock_db_session.query.assert_called_once_with(Conversation)
|
||||
|
||||
@patch("services.conversation_service.db.session")
|
||||
def test_get_conversation_success_with_end_user(self, mock_db_session):
|
||||
@@ -379,8 +377,7 @@ class TestConversationServiceGetConversation:
|
||||
from_end_user_id=user.id, from_source=ConversationFromSource.API
|
||||
)
|
||||
|
||||
mock_query = mock_db_session.query.return_value
|
||||
mock_query.where.return_value.first.return_value = conversation
|
||||
mock_db_session.scalar.return_value = conversation
|
||||
|
||||
# Act
|
||||
result = ConversationService.get_conversation(app_model, "conv-123", user)
|
||||
@@ -399,8 +396,7 @@ class TestConversationServiceGetConversation:
|
||||
app_model = ConversationServiceTestDataFactory.create_app_mock()
|
||||
user = ConversationServiceTestDataFactory.create_account_mock()
|
||||
|
||||
mock_query = mock_db_session.query.return_value
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ConversationNotExistsError):
|
||||
@@ -489,8 +485,7 @@ class TestConversationServiceAutoGenerateName:
|
||||
)
|
||||
|
||||
# Mock database query to return message
|
||||
mock_query = mock_db_session.query.return_value
|
||||
mock_query.where.return_value.order_by.return_value.first.return_value = message
|
||||
mock_db_session.scalar.return_value = message
|
||||
|
||||
# Mock LLM generator
|
||||
mock_llm_generator.generate_conversation_name.return_value = "Generated Name"
|
||||
@@ -518,8 +513,7 @@ class TestConversationServiceAutoGenerateName:
|
||||
conversation = ConversationServiceTestDataFactory.create_conversation_mock()
|
||||
|
||||
# Mock database query to return None
|
||||
mock_query = mock_db_session.query.return_value
|
||||
mock_query.where.return_value.order_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
@@ -541,8 +535,7 @@ class TestConversationServiceAutoGenerateName:
|
||||
)
|
||||
|
||||
# Mock database query to return message
|
||||
mock_query = mock_db_session.query.return_value
|
||||
mock_query.where.return_value.order_by.return_value.first.return_value = message
|
||||
mock_db_session.scalar.return_value = message
|
||||
|
||||
# Mock LLM generator to raise exception
|
||||
mock_llm_generator.generate_conversation_name.side_effect = Exception("LLM Error")
|
||||
|
||||
Reference in New Issue
Block a user