diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index d001dfba64..0e91779b2c 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -606,63 +606,63 @@ class DatasetIndexingEstimateApi(Resource): # validate args DocumentService.estimate_args_validate(args) extract_settings = [] - if args["info_list"]["data_source_type"] == "upload_file": - file_ids = args["info_list"]["file_info_list"]["file_ids"] - file_details = db.session.scalars( - select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids)) - ).all() + match args["info_list"]["data_source_type"]: + case "upload_file": + file_ids = args["info_list"]["file_info_list"]["file_ids"] + file_details = db.session.scalars( + select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids)) + ).all() + if file_details is None: + raise NotFound("File not found.") - if file_details is None: - raise NotFound("File not found.") - - if file_details: - for file_detail in file_details: + if file_details: + for file_detail in file_details: + extract_setting = ExtractSetting( + datasource_type=DatasourceType.FILE, + upload_file=file_detail, + document_model=args["doc_form"], + ) + extract_settings.append(extract_setting) + case "notion_import": + notion_info_list = args["info_list"]["notion_info_list"] + for notion_info in notion_info_list: + workspace_id = notion_info["workspace_id"] + credential_id = notion_info.get("credential_id") + for page in notion_info["pages"]: + extract_setting = ExtractSetting( + datasource_type=DatasourceType.NOTION, + notion_info=NotionInfo.model_validate( + { + "credential_id": credential_id, + "notion_workspace_id": workspace_id, + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_tenant_id, + } + ), + document_model=args["doc_form"], + ) + extract_settings.append(extract_setting) + case "website_crawl": + website_info_list = args["info_list"]["website_info_list"] + for url in website_info_list["urls"]: extract_setting = ExtractSetting( - datasource_type=DatasourceType.FILE, - upload_file=file_detail, - document_model=args["doc_form"], - ) - extract_settings.append(extract_setting) - elif args["info_list"]["data_source_type"] == "notion_import": - notion_info_list = args["info_list"]["notion_info_list"] - for notion_info in notion_info_list: - workspace_id = notion_info["workspace_id"] - credential_id = notion_info.get("credential_id") - for page in notion_info["pages"]: - extract_setting = ExtractSetting( - datasource_type=DatasourceType.NOTION, - notion_info=NotionInfo.model_validate( + datasource_type=DatasourceType.WEBSITE, + website_info=WebsiteInfo.model_validate( { - "credential_id": credential_id, - "notion_workspace_id": workspace_id, - "notion_obj_id": page["page_id"], - "notion_page_type": page["type"], + "provider": website_info_list["provider"], + "job_id": website_info_list["job_id"], + "url": url, "tenant_id": current_tenant_id, + "mode": "crawl", + "only_main_content": website_info_list["only_main_content"], } ), document_model=args["doc_form"], ) extract_settings.append(extract_setting) - elif args["info_list"]["data_source_type"] == "website_crawl": - website_info_list = args["info_list"]["website_info_list"] - for url in website_info_list["urls"]: - extract_setting = ExtractSetting( - datasource_type=DatasourceType.WEBSITE, - website_info=WebsiteInfo.model_validate( - { - "provider": website_info_list["provider"], - "job_id": website_info_list["job_id"], - "url": url, - "tenant_id": current_tenant_id, - "mode": "crawl", - "only_main_content": website_info_list["only_main_content"], - } - ), - document_model=args["doc_form"], - ) - extract_settings.append(extract_setting) - else: - raise ValueError("Data source type not support") + case _: + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: response = indexing_runner.indexing_estimate( diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 3372a967d9..c4e13c41a5 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -369,28 +369,31 @@ class DatasetDocumentListApi(Resource): else: sort_logic = asc - if sort == "hit_count": - sub_query = ( - sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count")) - .where(DocumentSegment.dataset_id == str(dataset_id)) - .group_by(DocumentSegment.document_id) - .subquery() - ) + match sort: + case "hit_count": + sub_query = ( + sa.select( + DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count") + ) + .where(DocumentSegment.dataset_id == str(dataset_id)) + .group_by(DocumentSegment.document_id) + .subquery() + ) - query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( - sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)), - sort_logic(Document.position), - ) - elif sort == "created_at": - query = query.order_by( - sort_logic(Document.created_at), - sort_logic(Document.position), - ) - else: - query = query.order_by( - desc(Document.created_at), - desc(Document.position), - ) + query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( + sort_logic(sa.func.coalesce(sub_query.c.total_hit_count, 0)), + sort_logic(Document.position), + ) + case "created_at": + query = query.order_by( + sort_logic(Document.created_at), + sort_logic(Document.position), + ) + case _: + query = query.order_by( + desc(Document.created_at), + desc(Document.position), + ) paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 1665bdeb52..e836554ca0 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -123,12 +123,15 @@ class SimplePromptTransform(PromptTransform): for v in special_variable_keys: # support #context#, #query# and #histories# - if v == "#context#": - variables["#context#"] = context or "" - elif v == "#query#": - variables["#query#"] = query or "" - elif v == "#histories#": - variables["#histories#"] = histories or "" + match v: + case "#context#": + variables["#context#"] = context or "" + case "#query#": + variables["#query#"] = query or "" + case "#histories#": + variables["#histories#"] = histories or "" + case _: + pass prompt_template = prompt_template_config["prompt_template"] if not isinstance(prompt_template, PromptTemplateParser): diff --git a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py index f73ba01c8b..be9d64ae01 100644 --- a/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py +++ b/api/providers/trace/trace-langsmith/src/dify_trace_langsmith/entities/langsmith_trace_entity.py @@ -65,35 +65,18 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): } file_list = values.get("file_list", []) if isinstance(v, str): - if field_name == "inputs": - return { - "messages": { - "role": "user", - "content": v, - "usage_metadata": usage_metadata, - "file_list": file_list, - }, - } - elif field_name == "outputs": - return { - "choices": { - "role": "ai", - "content": v, - "usage_metadata": usage_metadata, - "file_list": file_list, - }, - } - elif isinstance(v, list): - data = {} - if len(v) > 0 and isinstance(v[0], dict): - # rename text to content - v = replace_text_with_content(data=v) - if field_name == "inputs": - data = { - "messages": v, + match field_name: + case "inputs": + return { + "messages": { + "role": "user", + "content": v, + "usage_metadata": usage_metadata, + "file_list": file_list, + }, } - elif field_name == "outputs": - data = { + case "outputs": + return { "choices": { "role": "ai", "content": v, @@ -101,6 +84,29 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): "file_list": file_list, }, } + case _: + pass + elif isinstance(v, list): + data = {} + if len(v) > 0 and isinstance(v[0], dict): + # rename text to content + v = replace_text_with_content(data=v) + match field_name: + case "inputs": + data = { + "messages": v, + } + case "outputs": + data = { + "choices": { + "role": "ai", + "content": v, + "usage_metadata": usage_metadata, + "file_list": file_list, + }, + } + case _: + pass return data else: return { diff --git a/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py b/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py index 843c495d82..d6998f6672 100644 --- a/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py +++ b/api/providers/vdb/vdb-opensearch/src/dify_vdb_opensearch/opensearch_vector.py @@ -81,14 +81,15 @@ class OpenSearchConfig(BaseModel): pool_maxsize=20, ) - if self.auth_method == "basic": - logger.info("Using basic authentication for OpenSearch Vector DB") + match self.auth_method: + case AuthMethod.BASIC: + logger.info("Using basic authentication for OpenSearch Vector DB") - params["http_auth"] = (self.user, self.password) - elif self.auth_method == "aws_managed_iam": - logger.info("Using AWS managed IAM role for OpenSearch Vector DB") + params["http_auth"] = (self.user, self.password) + case AuthMethod.AWS_MANAGED_IAM: + logger.info("Using AWS managed IAM role for OpenSearch Vector DB") - params["http_auth"] = self.create_aws_managed_iam_auth() + params["http_auth"] = self.create_aws_managed_iam_auth() return params