diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 7511c970a3..39b84d3869 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -9,7 +9,14 @@ from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from libs.login import current_account_with_tenant, login_required -from services.tag_service import TagService +from models.enums import TagType +from services.tag_service import ( + SaveTagPayload, + TagBindingCreatePayload, + TagBindingDeletePayload, + TagService, + UpdateTagPayload, +) dataset_tag_fields = { "id": fields.String, @@ -25,19 +32,19 @@ def build_dataset_tag_fields(api_or_ns: Namespace): class TagBasePayload(BaseModel): name: str = Field(description="Tag name", min_length=1, max_length=50) - type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + type: TagType = Field(description="Tag type") class TagBindingPayload(BaseModel): tag_ids: list[str] = Field(description="Tag IDs to bind") target_id: str = Field(description="Target ID to bind tags to") - type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + type: TagType = Field(description="Tag type") class TagBindingRemovePayload(BaseModel): tag_id: str = Field(description="Tag ID to remove") target_id: str = Field(description="Target ID to unbind tag from") - type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type") + type: TagType = Field(description="Tag type") class TagListQueryParam(BaseModel): @@ -82,7 +89,7 @@ class TagListApi(Resource): raise Forbidden() payload = TagBasePayload.model_validate(console_ns.payload or {}) - tag = TagService.save_tags(payload.model_dump()) + tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=payload.type)) response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} @@ -103,7 +110,7 @@ class TagUpdateDeleteApi(Resource): raise Forbidden() payload = TagBasePayload.model_validate(console_ns.payload or {}) - tag = TagService.update_tags(payload.model_dump(), tag_id) + tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=payload.type), tag_id) binding_count = TagService.get_tag_binding_count(tag_id) @@ -136,7 +143,9 @@ class TagBindingCreateApi(Resource): raise Forbidden() payload = TagBindingPayload.model_validate(console_ns.payload or {}) - TagService.save_tag_binding(payload.model_dump()) + TagService.save_tag_binding( + TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type) + ) return {"result": "success"}, 200 @@ -154,6 +163,8 @@ class TagBindingDeleteApi(Resource): raise Forbidden() payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) - TagService.delete_tag_binding(payload.model_dump()) + TagService.delete_tag_binding( + TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=payload.type) + ) return {"result": "success"}, 200 diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 80205b283b..fd954be6b1 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -22,10 +22,17 @@ from fields.tag_fields import DataSetTag from libs.login import current_user from models.account import Account from models.dataset import DatasetPermissionEnum +from models.enums import TagType from models.provider_ids import ModelProviderID from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import RetrievalModel -from services.tag_service import TagService +from services.tag_service import ( + SaveTagPayload, + TagBindingCreatePayload, + TagBindingDeletePayload, + TagService, + UpdateTagPayload, +) DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -513,7 +520,7 @@ class DatasetTagsApi(DatasetApiResource): raise Forbidden() payload = TagCreatePayload.model_validate(service_api_ns.payload or {}) - tag = TagService.save_tags({"name": payload.name, "type": "knowledge"}) + tag = TagService.save_tags(SaveTagPayload(name=payload.name, type=TagType.KNOWLEDGE)) response = DataSetTag.model_validate( {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} @@ -536,9 +543,8 @@ class DatasetTagsApi(DatasetApiResource): raise Forbidden() payload = TagUpdatePayload.model_validate(service_api_ns.payload or {}) - params = {"name": payload.name, "type": "knowledge"} tag_id = payload.tag_id - tag = TagService.update_tags(params, tag_id) + tag = TagService.update_tags(UpdateTagPayload(name=payload.name, type=TagType.KNOWLEDGE), tag_id) binding_count = TagService.get_tag_binding_count(tag_id) @@ -585,7 +591,9 @@ class DatasetTagBindingApi(DatasetApiResource): raise Forbidden() payload = TagBindingPayload.model_validate(service_api_ns.payload or {}) - TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"}) + TagService.save_tag_binding( + TagBindingCreatePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE) + ) return "", 204 @@ -609,7 +617,9 @@ class DatasetTagUnbindingApi(DatasetApiResource): raise Forbidden() payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {}) - TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"}) + TagService.delete_tag_binding( + TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE) + ) return "", 204 diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 194622bd86..1882c855ea 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -2,6 +2,7 @@ import uuid import sqlalchemy as sa from flask_login import current_user +from pydantic import BaseModel, Field from sqlalchemy import func, select from werkzeug.exceptions import NotFound @@ -11,6 +12,28 @@ from models.enums import TagType from models.model import App, Tag, TagBinding +class SaveTagPayload(BaseModel): + name: str = Field(min_length=1, max_length=50) + type: TagType + + +class UpdateTagPayload(BaseModel): + name: str = Field(min_length=1, max_length=50) + type: TagType + + +class TagBindingCreatePayload(BaseModel): + tag_ids: list[str] + target_id: str + type: TagType + + +class TagBindingDeletePayload(BaseModel): + tag_id: str + target_id: str + type: TagType + + class TagService: @staticmethod def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None): @@ -78,12 +101,12 @@ class TagService: return tags or [] @staticmethod - def save_tags(args: dict) -> Tag: - if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]): + def save_tags(payload: SaveTagPayload) -> Tag: + if TagService.get_tag_by_tag_name(payload.type, current_user.current_tenant_id, payload.name): raise ValueError("Tag name already exists") tag = Tag( - name=args["name"], - type=TagType(args["type"]), + name=payload.name, + type=TagType(payload.type), created_by=current_user.id, tenant_id=current_user.current_tenant_id, ) @@ -93,13 +116,24 @@ class TagService: return tag @staticmethod - def update_tags(args: dict, tag_id: str) -> Tag: - if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")): - raise ValueError("Tag name already exists") + def update_tags(payload: UpdateTagPayload, tag_id: str) -> Tag: tag = db.session.scalar(select(Tag).where(Tag.id == tag_id).limit(1)) if not tag: raise NotFound("Tag not found") - tag.name = args["name"] + if payload.name != tag.name: + existing = db.session.scalar( + select(Tag) + .where( + Tag.name == payload.name, + Tag.tenant_id == current_user.current_tenant_id, + Tag.type == tag.type, + Tag.id != tag_id, + ) + .limit(1) + ) + if existing: + raise ValueError("Tag name already exists") + tag.name = payload.name db.session.commit() return tag @@ -122,21 +156,19 @@ class TagService: db.session.commit() @staticmethod - def save_tag_binding(args): - # check if target exists - TagService.check_target_exists(args["type"], args["target_id"]) - # save tag binding - for tag_id in args["tag_ids"]: + def save_tag_binding(payload: TagBindingCreatePayload): + TagService.check_target_exists(payload.type, payload.target_id) + for tag_id in payload.tag_ids: tag_binding = db.session.scalar( select(TagBinding) - .where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) + .where(TagBinding.tag_id == tag_id, TagBinding.target_id == payload.target_id) .limit(1) ) if tag_binding: continue new_tag_binding = TagBinding( tag_id=tag_id, - target_id=args["target_id"], + target_id=payload.target_id, tenant_id=current_user.current_tenant_id, created_by=current_user.id, ) @@ -144,17 +176,15 @@ class TagService: db.session.commit() @staticmethod - def delete_tag_binding(args): - # check if target exists - TagService.check_target_exists(args["type"], args["target_id"]) - # delete tag binding - tag_bindings = db.session.scalar( + def delete_tag_binding(payload: TagBindingDeletePayload): + TagService.check_target_exists(payload.type, payload.target_id) + tag_binding = db.session.scalar( select(TagBinding) - .where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == args["tag_id"]) + .where(TagBinding.target_id == payload.target_id, TagBinding.tag_id == payload.tag_id) .limit(1) ) - if tag_bindings: - db.session.delete(tag_bindings) + if tag_binding: + db.session.delete(tag_binding) db.session.commit() @staticmethod diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py index 77a5730cf4..9b913d6d3d 100644 --- a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py @@ -970,8 +970,10 @@ class TestDatasetTagBindingApiPost: result = api.post(_=None) assert result == ("", 204) + from services.tag_service import TagBindingCreatePayload + mock_tag_svc.save_tag_binding.assert_called_once_with( - {"tag_ids": ["tag-1"], "target_id": "ds-1", "type": "knowledge"} + TagBindingCreatePayload(tag_ids=["tag-1"], target_id="ds-1", type="knowledge") ) @patch("controllers.service_api.dataset.dataset.current_user") @@ -1019,8 +1021,10 @@ class TestDatasetTagUnbindingApiPost: result = api.post(_=None) assert result == ("", 204) + from services.tag_service import TagBindingDeletePayload + mock_tag_svc.delete_tag_binding.assert_called_once_with( - {"tag_id": "tag-1", "target_id": "ds-1", "type": "knowledge"} + TagBindingDeletePayload(tag_id="tag-1", target_id="ds-1", type="knowledge") ) @patch("controllers.service_api.dataset.dataset.current_user") diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index f504f35589..5a6bf0466e 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -12,7 +12,13 @@ from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset from models.enums import DataSourceType, TagType from models.model import App, Tag, TagBinding -from services.tag_service import TagService +from services.tag_service import ( + SaveTagPayload, + TagBindingCreatePayload, + TagBindingDeletePayload, + TagService, + UpdateTagPayload, +) class TestTagService: @@ -685,7 +691,7 @@ class TestTagService: db_session_with_containers, mock_external_service_dependencies ) - tag_args = {"name": "test_tag_name", "type": "knowledge"} + tag_args = SaveTagPayload(name="test_tag_name", type="knowledge") # Act: Execute the method under test result = TagService.save_tags(tag_args) @@ -725,7 +731,7 @@ class TestTagService: ) # Create first tag - tag_args = {"name": "duplicate_tag", "type": "app"} + tag_args = SaveTagPayload(name="duplicate_tag", type="app") TagService.save_tags(tag_args) # Act & Assert: Verify proper error handling @@ -749,11 +755,11 @@ class TestTagService: ) # Create a tag to update - tag_args = {"name": "original_name", "type": "knowledge"} + tag_args = SaveTagPayload(name="original_name", type="knowledge") tag = TagService.save_tags(tag_args) # Update args - update_args = {"name": "updated_name", "type": "knowledge"} + update_args = UpdateTagPayload(name="updated_name", type="knowledge") # Act: Execute the method under test result = TagService.update_tags(update_args, tag.id) @@ -793,7 +799,7 @@ class TestTagService: non_existent_tag_id = str(uuid.uuid4()) - update_args = {"name": "updated_name", "type": "knowledge"} + update_args = UpdateTagPayload(name="updated_name", type="knowledge") # Act & Assert: Verify proper error handling with pytest.raises(NotFound) as exc_info: @@ -817,14 +823,14 @@ class TestTagService: ) # Create two tags - tag1_args = {"name": "first_tag", "type": "app"} + tag1_args = SaveTagPayload(name="first_tag", type="app") tag1 = TagService.save_tags(tag1_args) - tag2_args = {"name": "second_tag", "type": "app"} + tag2_args = SaveTagPayload(name="second_tag", type="app") tag2 = TagService.save_tags(tag2_args) # Try to update second tag with first tag's name - update_args = {"name": "first_tag", "type": "app"} + update_args = UpdateTagPayload(name="first_tag", type="app") # Act & Assert: Verify proper error handling with pytest.raises(ValueError) as exc_info: @@ -988,8 +994,10 @@ class TestTagService: dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) # Act: Execute the method under test - binding_args = {"type": "knowledge", "target_id": dataset.id, "tag_ids": [tag.id for tag in tags]} - TagService.save_tag_binding(binding_args) + binding_payload = TagBindingCreatePayload( + type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags] + ) + TagService.save_tag_binding(binding_payload) # Assert: Verify the expected outcomes @@ -1030,11 +1038,11 @@ class TestTagService: app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) # Create first binding - binding_args = {"type": "app", "target_id": app.id, "tag_ids": [tag.id]} - TagService.save_tag_binding(binding_args) + binding_payload = TagBindingCreatePayload(type="app", target_id=app.id, tag_ids=[tag.id]) + TagService.save_tag_binding(binding_payload) # Act: Try to create duplicate binding - TagService.save_tag_binding(binding_args) + TagService.save_tag_binding(binding_payload) # Assert: Verify the expected outcomes @@ -1071,11 +1079,10 @@ class TestTagService: non_existent_target_id = str(uuid.uuid4()) # Act & Assert: Verify proper error handling - binding_args = {"type": "invalid_type", "target_id": non_existent_target_id, "tag_ids": [tag.id]} + from pydantic import ValidationError - with pytest.raises(NotFound) as exc_info: - TagService.save_tag_binding(binding_args) - assert "Invalid binding type" in str(exc_info.value) + with pytest.raises(ValidationError): + TagBindingCreatePayload(type="invalid_type", target_id=non_existent_target_id, tag_ids=[tag.id]) def test_delete_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -1113,8 +1120,8 @@ class TestTagService: assert binding_before is not None # Act: Execute the method under test - delete_args = {"type": "knowledge", "target_id": dataset.id, "tag_id": tag.id} - TagService.delete_tag_binding(delete_args) + delete_payload = TagBindingDeletePayload(type="knowledge", target_id=dataset.id, tag_id=tag.id) + TagService.delete_tag_binding(delete_payload) # Assert: Verify the expected outcomes # Verify tag binding was deleted @@ -1149,8 +1156,8 @@ class TestTagService: app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) # Act: Try to delete non-existent binding - delete_args = {"type": "app", "target_id": app.id, "tag_id": tag.id} - TagService.delete_tag_binding(delete_args) + delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_id=tag.id) + TagService.delete_tag_binding(delete_payload) # Assert: Verify the expected outcomes # No error should be raised, and database state should remain unchanged