diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index c6930a76cb..ffb07916ec 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -21,13 +21,14 @@ from controllers.console.explore.error import ( NotCompletionAppError, ) from controllers.console.explore.wraps import InstalledAppResource +from controllers.console.wraps import with_current_user from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse from graphon.model_runtime.errors.invoke import InvokeError from libs import helper -from libs.login import current_account_with_tenant +from models import Account from models.enums import FeedbackRating from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -59,8 +60,8 @@ register_response_schema_models(console_ns, ResultResponse, SuggestedQuestionsRe ) class MessageListApi(InstalledAppResource): @console_ns.expect(console_ns.models[MessageListQuery.__name__]) - def get(self, installed_app): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) @@ -96,8 +97,8 @@ class MessageListApi(InstalledAppResource): class MessageFeedbackApi(InstalledAppResource): @console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__]) @console_ns.response(200, "Feedback submitted successfully", console_ns.models[ResultResponse.__name__]) - def post(self, installed_app, message_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account, installed_app, message_id: UUID): app_model = installed_app.app message_id_str = str(message_id) @@ -124,8 +125,8 @@ class MessageFeedbackApi(InstalledAppResource): ) class MessageMoreLikeThisApi(InstalledAppResource): @console_ns.expect(console_ns.models[MoreLikeThisQuery.__name__]) - def get(self, installed_app, message_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account, installed_app, message_id: UUID): app_model = installed_app.app if app_model.mode != "completion": raise NotCompletionAppError() @@ -170,8 +171,8 @@ class MessageMoreLikeThisApi(InstalledAppResource): ) class MessageSuggestedQuestionApi(InstalledAppResource): @console_ns.response(200, "Success", console_ns.models[SuggestedQuestionsResponse.__name__]) - def get(self, installed_app, message_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account, installed_app, message_id: UUID): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index 3d41489435..cb63a52075 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -73,14 +73,13 @@ class TestMessageListApi: "/", query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "pagination_by_first_id", return_value=pagination, ), ): - result = method(installed_app) + result = method(MagicMock(), installed_app) assert result["limit"] == 20 assert result["has_more"] is False @@ -93,9 +92,8 @@ class TestMessageListApi: installed_app = MagicMock() installed_app.app = MagicMock(mode="completion") - with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): - with pytest.raises(NotChatAppError): - method(installed_app) + with pytest.raises(NotChatAppError): + method(MagicMock(), installed_app) def test_conversation_not_exists(self, app: Flask): api = module.MessageListApi() @@ -109,7 +107,6 @@ class TestMessageListApi: "/", query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "pagination_by_first_id", @@ -117,7 +114,7 @@ class TestMessageListApi: ), ): with pytest.raises(NotFound): - method(installed_app) + method(MagicMock(), installed_app) def test_first_message_not_exists(self, app: Flask): api = module.MessageListApi() @@ -131,7 +128,6 @@ class TestMessageListApi: "/", query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "pagination_by_first_id", @@ -139,7 +135,7 @@ class TestMessageListApi: ), ): with pytest.raises(NotFound): - method(installed_app) + method(MagicMock(), installed_app) class TestMessageFeedbackApi: @@ -152,13 +148,12 @@ class TestMessageFeedbackApi: with ( app.test_request_context("/", json={"rating": "like"}), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "create_feedback", ), ): - result = method(installed_app, "mid") + result = method(MagicMock(), installed_app, "mid") assert result["result"] == "success" @@ -171,7 +166,6 @@ class TestMessageFeedbackApi: with ( app.test_request_context("/", json={}), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "create_feedback", @@ -179,7 +173,7 @@ class TestMessageFeedbackApi: ), ): with pytest.raises(NotFound): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") class TestMessageMoreLikeThisApi: @@ -195,7 +189,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -207,7 +200,7 @@ class TestMessageMoreLikeThisApi: return_value=("ok", 200), ), ): - resp = method(installed_app, "mid") + resp = method(MagicMock(), installed_app, "mid") assert resp == ("ok", 200) @@ -218,9 +211,8 @@ class TestMessageMoreLikeThisApi: installed_app = MagicMock() installed_app.app = MagicMock(mode="chat") - with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): - with pytest.raises(NotCompletionAppError): - method(installed_app, "mid") + with pytest.raises(NotCompletionAppError): + method(MagicMock(), installed_app, "mid") def test_more_like_this_disabled(self, app: Flask): api = module.MessageMoreLikeThisApi() @@ -234,7 +226,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -242,7 +233,7 @@ class TestMessageMoreLikeThisApi: ), ): with pytest.raises(AppMoreLikeThisDisabledError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_message_not_exists_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() @@ -256,7 +247,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -264,7 +254,7 @@ class TestMessageMoreLikeThisApi: ), ): with pytest.raises(NotFound): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_provider_not_init_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() @@ -278,7 +268,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -286,7 +275,7 @@ class TestMessageMoreLikeThisApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_quota_exceeded_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() @@ -300,7 +289,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -308,7 +296,7 @@ class TestMessageMoreLikeThisApi: ), ): with pytest.raises(ProviderQuotaExceededError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_model_not_support_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() @@ -322,7 +310,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -330,7 +317,7 @@ class TestMessageMoreLikeThisApi: ), ): with pytest.raises(ProviderModelCurrentlyNotSupportError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_invoke_error_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() @@ -344,7 +331,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -352,7 +338,7 @@ class TestMessageMoreLikeThisApi: ), ): with pytest.raises(CompletionRequestError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_unexpected_error_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() @@ -366,7 +352,6 @@ class TestMessageMoreLikeThisApi: "/", query_string={"response_mode": "blocking"}, ), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.AppGenerateService, "generate_more_like_this", @@ -374,7 +359,7 @@ class TestMessageMoreLikeThisApi: ), ): with pytest.raises(InternalServerError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") class TestMessageSuggestedQuestionApi: @@ -386,14 +371,13 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", return_value=["q1", "q2"], ), ): - result = method(installed_app, "mid") + result = method(MagicMock(), installed_app, "mid") assert result["data"] == ["q1", "q2"] @@ -404,9 +388,8 @@ class TestMessageSuggestedQuestionApi: installed_app = MagicMock() installed_app.app = MagicMock(mode="completion") - with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): - with pytest.raises(NotChatAppError): - method(installed_app, "mid") + with pytest.raises(NotChatAppError): + method(MagicMock(), installed_app, "mid") def test_disabled(self): api = module.MessageSuggestedQuestionApi() @@ -416,7 +399,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -424,7 +406,7 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(AppSuggestedQuestionsAfterAnswerDisabledError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_message_not_exists_suggested_question(self): api = module.MessageSuggestedQuestionApi() @@ -434,7 +416,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -442,7 +423,7 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(NotFound): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_conversation_not_exists_suggested_question(self): api = module.MessageSuggestedQuestionApi() @@ -452,7 +433,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -460,7 +440,7 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(NotFound): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_provider_not_init_suggested_question(self): api = module.MessageSuggestedQuestionApi() @@ -470,7 +450,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -478,7 +457,7 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_quota_exceeded_suggested_question(self): api = module.MessageSuggestedQuestionApi() @@ -488,7 +467,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -496,7 +474,7 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(ProviderQuotaExceededError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_model_not_support_suggested_question(self): api = module.MessageSuggestedQuestionApi() @@ -506,7 +484,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -514,7 +491,7 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(ProviderModelCurrentlyNotSupportError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_invoke_error_suggested_question(self): api = module.MessageSuggestedQuestionApi() @@ -524,7 +501,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -532,7 +508,7 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(CompletionRequestError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid") def test_unexpected_error_suggested_question(self): api = module.MessageSuggestedQuestionApi() @@ -542,7 +518,6 @@ class TestMessageSuggestedQuestionApi: installed_app.app = MagicMock(mode="chat") with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -550,4 +525,4 @@ class TestMessageSuggestedQuestionApi: ), ): with pytest.raises(InternalServerError): - method(installed_app, "mid") + method(MagicMock(), installed_app, "mid")