diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 3bf6c330db..baefca0c3f 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -684,7 +684,9 @@ class AdvancedChatAppGenerateTaskPipeline: ) elif isinstance(event, QueueMessageReplaceEvent): # published by moderation - yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text) + yield self._message_cycle_manager._message_replace_to_stream_response( + answer=event.text, reason=event.reason + ) elif isinstance(event, QueueAdvancedChatMessageEndEvent): if not graph_runtime_state: raise ValueError("graph runtime state not initialized.") @@ -695,7 +697,8 @@ class AdvancedChatAppGenerateTaskPipeline: if output_moderation_answer: self._task_state.answer = output_moderation_answer yield self._message_cycle_manager._message_replace_to_stream_response( - answer=output_moderation_answer + answer=output_moderation_answer, + reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, ) # Save message diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 3702326406..7228020e9b 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -264,8 +264,16 @@ class QueueMessageReplaceEvent(AppQueueEvent): QueueMessageReplaceEvent entity """ + class MessageReplaceReason(StrEnum): + """ + Reason for message replace event + """ + + OUTPUT_MODERATION = "output_moderation" + event: QueueEvent = QueueEvent.MESSAGE_REPLACE text: str + reason: str class QueueRetrieverResourcesEvent(AppQueueEvent): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index f23ee1b9fd..817699bd20 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -148,6 +148,7 @@ class MessageReplaceStreamResponse(StreamResponse): event: StreamEvent = StreamEvent.MESSAGE_REPLACE answer: str + reason: str class AgentThoughtStreamResponse(StreamResponse): diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index a2e06d4e1f..5331c0cc94 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -126,12 +126,12 @@ class BasedGenerateTaskPipeline: if self._output_moderation_handler: self._output_moderation_handler.stop_thread() - completion = self._output_moderation_handler.moderation_completion( + completion, flagged = self._output_moderation_handler.moderation_completion( completion=completion, public_event=False ) self._output_moderation_handler = None - - return completion + if flagged: + return completion return None diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index 6223b33b67..fde506639f 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -182,10 +182,12 @@ class MessageCycleManage: from_variable_selector=from_variable_selector, ) - def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse: + def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: """ Message replace to stream response. :param answer: answer :return: """ - return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer) + return MessageReplaceStreamResponse( + task_id=self._application_generate_entity.task_id, answer=answer, reason=reason + ) diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py index e595be126c..2ec315417f 100644 --- a/api/core/moderation/output_moderation.py +++ b/api/core/moderation/output_moderation.py @@ -46,14 +46,14 @@ class OutputModeration(BaseModel): if not self.thread: self.thread = self.start_thread() - def moderation_completion(self, completion: str, public_event: bool = False) -> str: + def moderation_completion(self, completion: str, public_event: bool = False) -> tuple[str, bool]: self.buffer = completion self.is_final_chunk = True result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion) if not result or not result.flagged: - return completion + return completion, False if result.action == ModerationAction.DIRECT_OUTPUT: final_output = result.preset_response @@ -61,9 +61,14 @@ class OutputModeration(BaseModel): final_output = result.text if public_event: - self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE) + self.queue_manager.publish( + QueueMessageReplaceEvent( + text=final_output, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION + ), + PublishFrom.TASK_PIPELINE, + ) - return final_output + return final_output, True def start_thread(self) -> threading.Thread: buffer_size = dify_config.MODERATION_BUFFER_SIZE @@ -112,7 +117,12 @@ class OutputModeration(BaseModel): # trigger replace event if self.thread_running: - self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE) + self.queue_manager.publish( + QueueMessageReplaceEvent( + text=final_output, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION + ), + PublishFrom.TASK_PIPELINE, + ) if result.action == ModerationAction.DIRECT_OUTPUT: break