feat: Enhance knowledge base node validation by adding checks for embedding and reranking models (#27241)

This commit is contained in:
Wu Tianwei
2025-10-22 10:49:49 +08:00
committed by GitHub
parent 845adb664a
commit f909040567
7 changed files with 54 additions and 9 deletions

View File

@@ -42,6 +42,9 @@ import { fetchDatasets } from '@/service/datasets'
import { MAX_TREE_DEPTH } from '@/config'
import useNodesAvailableVarList, { useGetNodesAvailableVarList } from './use-nodes-available-var-list'
import { getNodeUsedVars, isSpecialVar } from '../nodes/_base/components/variable/utils'
import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { KnowledgeBaseNodeType } from '../nodes/knowledge-base/types'
export const useChecklist = (nodes: Node[], edges: Edge[]) => {
const { t } = useTranslation()
@@ -57,6 +60,8 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
const getToolIcon = useGetToolIcon()
const map = useNodesAvailableVarList(nodes)
const { data: embeddingModelList } = useModelList(ModelTypeEnum.textEmbedding)
const { data: rerankModelList } = useModelList(ModelTypeEnum.rerank)
const getCheckData = useCallback((data: CommonNodeType<{}>) => {
let checkData = data
@@ -72,8 +77,15 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
_datasets,
} as CommonNodeType<KnowledgeRetrievalNodeType>
}
else if (data.type === BlockEnum.KnowledgeBase) {
checkData = {
...data,
_embeddingModelList: embeddingModelList,
_rerankModelList: rerankModelList,
} as CommonNodeType<KnowledgeBaseNodeType>
}
return checkData
}, [datasetsDetail])
}, [datasetsDetail, embeddingModelList, rerankModelList])
const needWarningNodes = useMemo(() => {
const list = []

View File

@@ -57,6 +57,7 @@ const EmbeddingModel = ({
modelList={embeddingModelList}
onSelect={handleEmbeddingModelChange}
readonly={readonly}
showDeprecatedWarnIcon
/>
</Field>
)

View File

@@ -44,6 +44,7 @@ const RerankingModelSelector = ({
modelList={rerankModelList}
onSelect={handleRerankingModelChange}
readonly={readonly}
showDeprecatedWarnIcon
/>
)
}

View File

@@ -31,6 +31,8 @@ const nodeDefault: NodeDefault<KnowledgeBaseNodeType> = {
embedding_model,
embedding_model_provider,
index_chunk_variable_selector,
_embeddingModelList,
_rerankModelList,
} = payload
const {
@@ -39,6 +41,12 @@ const nodeDefault: NodeDefault<KnowledgeBaseNodeType> = {
reranking_model,
} = retrieval_model || {}
const currentEmbeddingModelProvider = _embeddingModelList?.find(provider => provider.provider === embedding_model_provider)
const currentEmbeddingModel = currentEmbeddingModelProvider?.models.find(model => model.model === embedding_model)
const currentRerankingModelProvider = _rerankModelList?.find(provider => provider.provider === reranking_model?.reranking_provider_name)
const currentRerankingModel = currentRerankingModelProvider?.models.find(model => model.model === reranking_model?.reranking_model_name)
if (!chunk_structure) {
return {
isValid: false,
@@ -60,10 +68,18 @@ const nodeDefault: NodeDefault<KnowledgeBaseNodeType> = {
}
}
if (indexing_technique === IndexingType.QUALIFIED && (!embedding_model || !embedding_model_provider)) {
return {
isValid: false,
errorMessage: t('workflow.nodes.knowledgeBase.embeddingModelIsRequired'),
if (indexing_technique === IndexingType.QUALIFIED) {
if (!embedding_model || !embedding_model_provider) {
return {
isValid: false,
errorMessage: t('workflow.nodes.knowledgeBase.embeddingModelIsRequired'),
}
}
else if (!currentEmbeddingModel) {
return {
isValid: false,
errorMessage: t('workflow.nodes.knowledgeBase.embeddingModelIsInvalid'),
}
}
}
@@ -74,10 +90,18 @@ const nodeDefault: NodeDefault<KnowledgeBaseNodeType> = {
}
}
if (reranking_enable && (!reranking_model || !reranking_model.reranking_provider_name || !reranking_model.reranking_model_name)) {
return {
isValid: false,
errorMessage: t('workflow.nodes.knowledgeBase.rerankingModelIsRequired'),
if (reranking_enable) {
if (!reranking_model || !reranking_model.reranking_provider_name || !reranking_model.reranking_model_name) {
return {
isValid: false,
errorMessage: t('workflow.nodes.knowledgeBase.rerankingModelIsRequired'),
}
}
else if (!currentRerankingModel) {
return {
isValid: false,
errorMessage: t('workflow.nodes.knowledgeBase.rerankingModelIsInvalid'),
}
}
}

View File

@@ -3,6 +3,7 @@ import type { IndexingType } from '@/app/components/datasets/create/step-two'
import type { RETRIEVE_METHOD } from '@/types/app'
import type { WeightedScoreEnum } from '@/models/datasets'
import type { RerankingModeEnum } from '@/models/datasets'
import type { Model } from '@/app/components/header/account-setting/model-provider-page/declarations'
export { WeightedScoreEnum } from '@/models/datasets'
export { IndexingType as IndexMethodEnum } from '@/app/components/datasets/create/step-two'
export { RETRIEVE_METHOD as RetrievalSearchMethodEnum } from '@/types/app'
@@ -49,4 +50,6 @@ export type KnowledgeBaseNodeType = CommonNodeType & {
embedding_model_provider?: string
keyword_number: number
retrieval_model: RetrievalSetting
_embeddingModelList?: Model[]
_rerankModelList?: Model[]
}

View File

@@ -959,8 +959,10 @@ const translation = {
indexMethodIsRequired: 'Index method is required',
chunksVariableIsRequired: 'Chunks variable is required',
embeddingModelIsRequired: 'Embedding model is required',
embeddingModelIsInvalid: 'Embedding model is invalid',
retrievalSettingIsRequired: 'Retrieval setting is required',
rerankingModelIsRequired: 'Reranking model is required',
rerankingModelIsInvalid: 'Reranking model is invalid',
},
},
tracing: {

View File

@@ -959,8 +959,10 @@ const translation = {
indexMethodIsRequired: '索引方法是必需的',
chunksVariableIsRequired: 'Chunks 变量是必需的',
embeddingModelIsRequired: 'Embedding 模型是必需的',
embeddingModelIsInvalid: '无效的 Embedding 模型',
retrievalSettingIsRequired: '检索设置是必需的',
rerankingModelIsRequired: 'Reranking 模型是必需的',
rerankingModelIsInvalid: '无效的 Reranking 模型',
},
},
tracing: {