mirror of
https://github.com/langgenius/dify.git
synced 2025-12-19 17:27:16 -05:00
Fix json in md when use quesion classifier node (#26992)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -194,6 +195,8 @@ class QuestionClassifierNode(Node):
|
||||
|
||||
category_name = node_data.classes[0].name
|
||||
category_id = node_data.classes[0].id
|
||||
if "<think>" in result_text:
|
||||
result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE)
|
||||
result_text_json = parse_and_check_json_markdown(result_text, [])
|
||||
# result_text_json = json.loads(result_text.strip('```JSON\n'))
|
||||
if "category_name" in result_text_json and "category_id" in result_text_json:
|
||||
|
||||
@@ -6,22 +6,22 @@ from core.llm_generator.output_parser.errors import OutputParserError
|
||||
def parse_json_markdown(json_string: str):
|
||||
# Get json from the backticks/braces
|
||||
json_string = json_string.strip()
|
||||
starts = ["```json", "```", "``", "`", "{"]
|
||||
ends = ["```", "``", "`", "}"]
|
||||
starts = ["```json", "```", "``", "`", "{", "["]
|
||||
ends = ["```", "``", "`", "}", "]"]
|
||||
end_index = -1
|
||||
start_index = 0
|
||||
parsed: dict = {}
|
||||
for s in starts:
|
||||
start_index = json_string.find(s)
|
||||
if start_index != -1:
|
||||
if json_string[start_index] != "{":
|
||||
if json_string[start_index] not in ("{", "["):
|
||||
start_index += len(s)
|
||||
break
|
||||
if start_index != -1:
|
||||
for e in ends:
|
||||
end_index = json_string.rfind(e, start_index)
|
||||
if end_index != -1:
|
||||
if json_string[end_index] == "}":
|
||||
if json_string[end_index] in ("}", "]"):
|
||||
end_index += 1
|
||||
break
|
||||
if start_index != -1 and end_index != -1 and start_index < end_index:
|
||||
@@ -38,6 +38,12 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]):
|
||||
json_obj = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise OutputParserError(f"got invalid json object. error: {e}")
|
||||
|
||||
if isinstance(json_obj, list):
|
||||
if len(json_obj) == 1 and isinstance(json_obj[0], dict):
|
||||
json_obj = json_obj[0]
|
||||
else:
|
||||
raise OutputParserError(f"got invalid return object. obj:{json_obj}")
|
||||
for key in expected_keys:
|
||||
if key not in json_obj:
|
||||
raise OutputParserError(
|
||||
|
||||
@@ -86,3 +86,24 @@ def test_parse_and_check_json_markdown_multiple_blocks_fails():
|
||||
# opening fence to the last closing fence, causing JSON decode failure.
|
||||
with pytest.raises(OutputParserError):
|
||||
parse_and_check_json_markdown(src, [])
|
||||
|
||||
|
||||
def test_parse_and_check_json_markdown_handles_think_fenced_and_raw_variants():
|
||||
expected = {"keywords": ["2"], "category_id": "2", "category_name": "2"}
|
||||
cases = [
|
||||
"""
|
||||
```json
|
||||
[{"keywords": ["2"], "category_id": "2", "category_name": "2"}]
|
||||
```, error: Expecting value: line 1 column 1 (char 0)
|
||||
""",
|
||||
"""
|
||||
```json
|
||||
{"keywords": ["2"], "category_id": "2", "category_name": "2"}
|
||||
```, error: Extra data: line 2 column 5 (char 66)
|
||||
""",
|
||||
'{"keywords": ["2"], "category_id": "2", "category_name": "2"}',
|
||||
'[{"keywords": ["2"], "category_id": "2", "category_name": "2"}]',
|
||||
]
|
||||
for src in cases:
|
||||
obj = parse_and_check_json_markdown(src, ["keywords", "category_id", "category_name"])
|
||||
assert obj == expected
|
||||
|
||||
Reference in New Issue
Block a user