From 92606fbd4c42dd6803a65d9f367c163a876474f8 Mon Sep 17 00:00:00 2001
From: He Tao
Date: Tue, 15 Apr 2025 16:36:02 +0800
Subject: [PATCH] feat: support interrupt protocol

---
 docs/sse_integration_test.md | 14 ++++----------
 src/graph/nodes.py           |  6 +++---
 src/server/app.py            | 31 ++++++++++++++++++++++++++-----
 src/server/chat_request.py   |  8 ++++----
 4 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/docs/sse_integration_test.md b/docs/sse_integration_test.md
index e7635a3..1d00f84 100644
--- a/docs/sse_integration_test.md
+++ b/docs/sse_integration_test.md
@@ -49,26 +49,20 @@ curl --location 'http://localhost:8000/api/chat/stream' \
     "messages": [
         {
             "role": "user",
-            "content": "what is mcp?"
+            "content": "make the last step be comprehensive"
         }
     ],
     "thread_id": "test_thread_2",
     "auto_accepted_plan": false,
-    "feedback": "[EDIT PLAN] make the last step be comprehensive"
+    "interrupt_feedback": "edit_plan"
 }
 ```
 
-### Accepted the plan
+### Accept the plan
 
 ```
 {
-    "messages": [
-        {
-            "role": "user",
-            "content": "what is mcp?"
-        }
-    ],
     "thread_id": "test_thread_2",
     "auto_accepted_plan": false,
-    "feedback": "[ACCEPTED]"
+    "interrupt_feedback": "accepted"
 }
 ```
\ No newline at end of file
diff --git a/src/graph/nodes.py b/src/graph/nodes.py
index 6cfa7c2..74b852d 100644
--- a/src/graph/nodes.py
+++ b/src/graph/nodes.py
@@ -75,10 +75,10 @@ def human_feedback_node(
     # check if the plan is auto accepted
     auto_accepted_plan = state.get("auto_accepted_plan", False)
     if not auto_accepted_plan:
-        feedback = interrupt(current_plan)
+        feedback = interrupt("Please Review the Plan.")
 
         # if the feedback is not accepted, return the planner node
-        if feedback and str(feedback).upper() != "[ACCEPTED]":
+        if feedback and str(feedback).upper().startswith("[EDIT_PLAN]"):
             return Command(
                 update={
                     "messages": [
@@ -87,7 +87,7 @@ def human_feedback_node(
                 },
                 goto="planner",
             )
-        elif feedback and str(feedback).upper() == "[ACCEPTED]":
+        elif feedback and str(feedback).upper().startswith("[ACCEPTED]"):
             logger.info("Plan is accepted by user.")
         else:
             raise TypeError(f"Interrupt value of {feedback} is not supported.")
diff --git a/src/server/app.py b/src/server/app.py
index c1de5ab..5a1d3f6 100644
--- a/src/server/app.py
+++ b/src/server/app.py
@@ -44,7 +44,7 @@ async def chat_stream(request: ChatRequest):
             request.max_plan_iterations,
             request.max_step_num,
             request.auto_accepted_plan,
-            request.feedback,
+            request.interrupt_feedback,
         ),
         media_type="text/event-stream",
     )
@@ -56,11 +56,15 @@ async def _astream_workflow_generator(
     max_plan_iterations: int,
     max_step_num: int,
     auto_accepted_plan: bool,
-    feedback: str,
+    interrupt_feedback: str,
 ):
     input_ = {"messages": messages, "auto_accepted_plan": auto_accepted_plan}
-    if not auto_accepted_plan and feedback:
-        input_ = Command(resume=feedback)
+    if not auto_accepted_plan and interrupt_feedback:
+        resume_msg = f"[{interrupt_feedback}]"
+        # add the last message to the resume message
+        if messages:
+            resume_msg += f" {messages[-1]['content']}"
+        input_ = Command(resume=resume_msg)
     async for agent, _, event_data in graph.astream(
         input_,
         config={
@@ -68,9 +72,26 @@ async def _astream_workflow_generator(
             "max_plan_iterations": max_plan_iterations,
             "max_step_num": max_step_num,
         },
-        stream_mode=["messages"],
+        stream_mode=["messages", "updates"],
         subgraphs=True,
     ):
+        if isinstance(event_data, dict):
+            if "__interrupt__" in event_data:
+                yield _make_event(
+                    "interrupt",
+                    {
+                        "thread_id": thread_id,
+                        "id": event_data["__interrupt__"][0].ns[0],
+                        "role": "assistant",
"content": event_data["__interrupt__"][0].value, + "finish_reason": "interrupt", + "options": [ + {"text": "Accept", "value": "accepted"}, + {"text": "Edit", "value": "edit_plan"}, + ], + }, + ) + continue message_chunk, message_metadata = cast( tuple[AIMessageChunk, dict[str, any]], event_data ) diff --git a/src/server/chat_request.py b/src/server/chat_request.py index 0d1cb79..8a44ee3 100644 --- a/src/server/chat_request.py +++ b/src/server/chat_request.py @@ -22,8 +22,8 @@ class ChatMessage(BaseModel): class ChatRequest(BaseModel): - messages: List[ChatMessage] = Field( - ..., description="History of messages between the user and the assistant" + messages: Optional[List[ChatMessage]] = Field( + [], description="History of messages between the user and the assistant" ) debug: Optional[bool] = Field(False, description="Whether to enable debug logging") thread_id: Optional[str] = Field( @@ -38,6 +38,6 @@ class ChatRequest(BaseModel): auto_accepted_plan: Optional[bool] = Field( False, description="Whether to automatically accept the plan" ) - feedback: Optional[str] = Field( - None, description="Feedback from the user on the plan" + interrupt_feedback: Optional[str] = Field( + None, description="Interrupt feedback from the user on the plan" )