issue: #17056 : Add a reason field to the message_replace event (#17195)

Co-authored-by: 聂政 <niezheng@pjlab.org.cn>
This commit is contained in:
just2gooo 2025-04-25 10:08:06 +08:00 committed by GitHub
parent 37e2f73909
commit 5e2b3b34e5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 36 additions and 12 deletions

View File

@ -684,7 +684,9 @@ class AdvancedChatAppGenerateTaskPipeline:
) )
elif isinstance(event, QueueMessageReplaceEvent): elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation # published by moderation
yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text) yield self._message_cycle_manager._message_replace_to_stream_response(
answer=event.text, reason=event.reason
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent): elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
@ -695,7 +697,8 @@ class AdvancedChatAppGenerateTaskPipeline:
if output_moderation_answer: if output_moderation_answer:
self._task_state.answer = output_moderation_answer self._task_state.answer = output_moderation_answer
yield self._message_cycle_manager._message_replace_to_stream_response( yield self._message_cycle_manager._message_replace_to_stream_response(
answer=output_moderation_answer answer=output_moderation_answer,
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
) )
# Save message # Save message

View File

@ -264,8 +264,16 @@ class QueueMessageReplaceEvent(AppQueueEvent):
QueueMessageReplaceEvent entity QueueMessageReplaceEvent entity
""" """
class MessageReplaceReason(StrEnum):
"""
Reason for message replace event
"""
OUTPUT_MODERATION = "output_moderation"
event: QueueEvent = QueueEvent.MESSAGE_REPLACE event: QueueEvent = QueueEvent.MESSAGE_REPLACE
text: str text: str
reason: str
class QueueRetrieverResourcesEvent(AppQueueEvent): class QueueRetrieverResourcesEvent(AppQueueEvent):

View File

@ -148,6 +148,7 @@ class MessageReplaceStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE_REPLACE event: StreamEvent = StreamEvent.MESSAGE_REPLACE
answer: str answer: str
reason: str
class AgentThoughtStreamResponse(StreamResponse): class AgentThoughtStreamResponse(StreamResponse):

View File

@ -126,12 +126,12 @@ class BasedGenerateTaskPipeline:
if self._output_moderation_handler: if self._output_moderation_handler:
self._output_moderation_handler.stop_thread() self._output_moderation_handler.stop_thread()
completion = self._output_moderation_handler.moderation_completion( completion, flagged = self._output_moderation_handler.moderation_completion(
completion=completion, public_event=False completion=completion, public_event=False
) )
self._output_moderation_handler = None self._output_moderation_handler = None
if flagged:
return completion return completion
return None return None

View File

@ -182,10 +182,12 @@ class MessageCycleManage:
from_variable_selector=from_variable_selector, from_variable_selector=from_variable_selector,
) )
def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse: def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
""" """
Message replace to stream response. Message replace to stream response.
:param answer: answer :param answer: answer
:return: :return:
""" """
return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer) return MessageReplaceStreamResponse(
task_id=self._application_generate_entity.task_id, answer=answer, reason=reason
)

View File

@ -46,14 +46,14 @@ class OutputModeration(BaseModel):
if not self.thread: if not self.thread:
self.thread = self.start_thread() self.thread = self.start_thread()
def moderation_completion(self, completion: str, public_event: bool = False) -> str: def moderation_completion(self, completion: str, public_event: bool = False) -> tuple[str, bool]:
self.buffer = completion self.buffer = completion
self.is_final_chunk = True self.is_final_chunk = True
result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion) result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion)
if not result or not result.flagged: if not result or not result.flagged:
return completion return completion, False
if result.action == ModerationAction.DIRECT_OUTPUT: if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response final_output = result.preset_response
@ -61,9 +61,14 @@ class OutputModeration(BaseModel):
final_output = result.text final_output = result.text
if public_event: if public_event:
self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE) self.queue_manager.publish(
QueueMessageReplaceEvent(
text=final_output, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION
),
PublishFrom.TASK_PIPELINE,
)
return final_output return final_output, True
def start_thread(self) -> threading.Thread: def start_thread(self) -> threading.Thread:
buffer_size = dify_config.MODERATION_BUFFER_SIZE buffer_size = dify_config.MODERATION_BUFFER_SIZE
@ -112,7 +117,12 @@ class OutputModeration(BaseModel):
# trigger replace event # trigger replace event
if self.thread_running: if self.thread_running:
self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE) self.queue_manager.publish(
QueueMessageReplaceEvent(
text=final_output, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION
),
PublishFrom.TASK_PIPELINE,
)
if result.action == ModerationAction.DIRECT_OUTPUT: if result.action == ModerationAction.DIRECT_OUTPUT:
break break