Fix/refactor invoke result handling in question classifier node (#12015)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-12-23 17:54:08 +08:00 committed by GitHub
parent af2888d394
commit c3c85276d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,10 +1,8 @@
import json
import logging
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.llm_generator.output_parser.errors import OutputParserError
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
@ -96,27 +94,28 @@ class QuestionClassifierNode(LLMNode):
jinja2_variables=[],
)
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break
category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
try:
# handle invoke result
generator = self._invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
)
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
break
category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
result_text_json = parse_and_check_json_markdown(result_text, [])
# result_text_json = json.loads(result_text.strip('```JSON\n'))
if "category_name" in result_text_json and "category_id" in result_text_json:
@ -127,10 +126,6 @@ class QuestionClassifierNode(LLMNode):
if category_id_result in category_ids:
category_name = classes_map[category_id_result]
category_id = category_id_result
except OutputParserError:
logging.exception(f"Failed to parse result text: {result_text}")
try:
process_data = {
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
@ -154,7 +149,7 @@ class QuestionClassifierNode(LLMNode):
},
llm_usage=usage,
)
except Exception as e:
except ValueError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,