feat(knowledge_retrieval_node): Suppress exceptions thrown by DatasetRetrieval (#11728)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2024-12-17 13:12:29 +08:00 committed by GitHub
parent a399502ecd
commit 62b9e5a6f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 14 deletions

View File

@ -72,7 +72,11 @@ class BaseNode(Generic[GenericNodeData]):
result = self._run() result = self._run()
except Exception as e: except Exception as e:
logger.exception(f"Node {self.node_id} failed to run") logger.exception(f"Node {self.node_id} failed to run")
result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError") result = NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
error_type="WorkflowNodeError",
)
if isinstance(result, NodeRunResult): if isinstance(result, NodeRunResult):
yield RunCompletedEvent(run_result=result) yield RunCompletedEvent(run_result=result)

View File

@ -70,7 +70,20 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
except KnowledgeRetrievalNodeError as e: except KnowledgeRetrievalNodeError as e:
logger.warning("Error when running knowledge retrieval node") logger.warning("Error when running knowledge retrieval node")
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
# Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
)
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
available_datasets = [] available_datasets = []
@ -160,18 +173,18 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
reranking_model = None reranking_model = None
weights = None weights = None
all_documents = dataset_retrieval.multiple_retrieve( all_documents = dataset_retrieval.multiple_retrieve(
self.app_id, app_id=self.app_id,
self.tenant_id, tenant_id=self.tenant_id,
self.user_id, user_id=self.user_id,
self.user_from.value, user_from=self.user_from.value,
available_datasets, available_datasets=available_datasets,
query, query=query,
node_data.multiple_retrieval_config.top_k, top_k=node_data.multiple_retrieval_config.top_k,
node_data.multiple_retrieval_config.score_threshold, score_threshold=node_data.multiple_retrieval_config.score_threshold,
node_data.multiple_retrieval_config.reranking_mode, reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model, reranking_model=reranking_model,
weights, weights=weights,
node_data.multiple_retrieval_config.reranking_enable, reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
) )
dify_documents = [item for item in all_documents if item.provider == "dify"] dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"] external_documents = [item for item in all_documents if item.provider == "external"]